Update test cases of activation functions #162

Merged (4 commits, Feb 8, 2020)
Project.toml: 3 changes (2 additions, 1 deletion)

@@ -16,6 +16,7 @@ julia = "1"
 
 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Test"]
+test = ["Test", "Zygote"]
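Zygote is listed only under [extras] and the test target, so it is pulled in for the test suite rather than as a runtime dependency. As a rough sketch (assuming the working directory is a local NNlib checkout), the new test-only dependency is resolved automatically when the tests are run:

using Pkg
Pkg.activate(".")   # assumes the current directory is the NNlib checkout
Pkg.test()          # installs the [targets] test deps (Test, Zygote) and runs test/runtests.jl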
src/activation.jl: 16 changes (9 additions, 7 deletions)

@@ -72,7 +72,8 @@ elu(x, α = one(x)) = ifelse(x ≥ 0, x/one(x), α * (exp(x) - one(x)))
 activation function.
 """
 function gelu(x::Real)
-    λ = oftype(x/1, √(2/π))
+    p = oftype(x/1, π)
+    λ = oftype(x/1, √(2/p))
     α = oftype(x/1, 0.044715)
     h = oftype(x/1, 0.5)
     h * x * (one(x) + tanh(λ * (x + α * x^3)))
@@ -126,12 +127,6 @@ Return `log(cosh(x))` which is computed in a numerically stable way.
 """
 logcosh(x::T) where T = x + softplus(-2x) - log(convert(T, 2))
 
-# Provide an informative error message if activation functions are called with an array
-for f in (:σ, :σ_stable, :logσ, :relu, :leakyrelu, :elu, :gelu, :swish, :selu, :softsign, :softplus, :logcosh)
-    @eval $(f)(x::AbstractArray, args...) =
-        error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.")
-end
-
 
 """
     mish(x) = x * tanh(softplus(x))
@@ -140,3 +135,10 @@ Self Regularized Non-Monotonic Neural Activation Function
 See [Mish: A Self Regularized Non-Monotonic Neural Activation Function](https://arxiv.org/abs/1908.08681).
 """
 mish(x::Real) = x * tanh(softplus(x))
+
+
+# Provide an informative error message if activation functions are called with an array
+for f in (:σ, :σ_stable, :logσ, :relu, :leakyrelu, :elu, :gelu, :swish, :selu, :softsign, :softplus, :logcosh, :mish)
+    @eval $(f)(x::AbstractArray, args...) =
+        error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.")
+end
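A brief usage sketch of the convention the relocated block above enforces, now covering mish as well; this assumes NNlib is loaded and only illustrates the scalar-versus-array behaviour:

using NNlib

mish(1.2f0)            # scalar call, returns a Float32 matching the input precision
mish.([0.5, -0.3])     # arrays are handled via broadcasting
# mish([0.5, -0.3])    # with :mish added to the loop above, this now throws the informative
#                      # "Use broadcasting (`mish.(x)`) ..." error instead of a MethodError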
test/activation.jl: 19 changes (17 additions, 2 deletions)

@@ -1,6 +1,6 @@
-using NNlib, Test
+using NNlib, Test, Zygote
 
-ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logcosh];
+ACTIVATION_FUNCTIONS = [σ, relu, leakyrelu, elu, gelu, swish, selu, softplus, softsign, logcosh, mish];
 
 function test_value_float_precision_preserving(a)
     @testset "$(a): " begin
@@ -24,6 +24,17 @@ function test_value_int_input_forces_float64(a)
     end
 end
 
+function test_gradient_float_precision_preserving(a)
+    @testset "$(a): " begin
+        for T in [Float32, Float64]
+            for val in [-10, -1, 0, 1, 10]
+                val = @inferred a'(T(val))
+                @test typeof(val) == T
+            end
+        end
+    end
+end
+
 @testset "Activation Functions" begin
     @test σ(0.0) == 0.5
     @test relu(0.0) == 0.0
@@ -83,6 +94,10 @@ end
         @test typeof(relu(Int32(1))) == Int32
     end
 end
 
+@testset "Float gradient inference" begin
+    test_gradient_float_precision_preserving.(ACTIVATION_FUNCTIONS)
+end
+
 @testset "softmax" begin
     xs = rand(5,5)
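For reference, the new helper relies on Zygote's adjoint syntax for functions, where f'(x) returns the derivative of a scalar function at x; that is what the call a'(T(val)) in test_gradient_float_precision_preserving does. A minimal sketch of what the added testset exercises, assuming NNlib and Zygote are loaded:

using NNlib, Zygote

dg = gelu'(1.0f0)        # derivative of gelu at 1.0f0 via Zygote's f' syntax
typeof(dg) == Float32    # the property under test: gradients keep the input's float type

dm = mish'(2.0)
typeof(dm) == Float64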