From a6647659763ba3e398e72d83b5b4fa11b3ce2e20 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Thu, 2 Dec 2021 21:58:09 -0500
Subject: [PATCH 1/3] improve activation functions

---
 src/activations.jl  | 35 +++++++++++++++++++++++++++--------
 test/activations.jl |  4 ++--
 2 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/src/activations.jl b/src/activations.jl
index 3f49978a5..291ea888d 100644
--- a/src/activations.jl
+++ b/src/activations.jl
@@ -292,7 +292,7 @@ julia> elu(-10f0, 2)
 -1.9999092f0
 ```
 """
-elu(x, α=1) = ifelse(x ≥ 0, float(x), α * (exp(x) - 1))
+elu(x, α=1) = ifelse(x ≥ 0, float(x), @fastmath α * (exp(x) - 1))
 
 deriv_elu(Ω, α=1) = ifelse(Ω ≥ 0, one(Ω), Ω + α)
 
@@ -318,11 +318,14 @@ julia> lineplot(gelu, -2, 2, height=7)
 """
 function gelu(x)
     α = oftf(x, 0.044715)
-    λ = oftf(x, gelu_λ)
-    x/2 * (1 + tanh(λ * (x + α * x^3)))
+    # λ = oftf(x, gelu_λ)
+    # x/2 * (1 + tanh(λ * (x + α * x^3)))
+    λλ = oftf(x, gelu_2λ)
+    x * sigmoid_fast(λλ * x * muladd(x^2, α, one(x)))  # This is faster & more accurate
 end
 
 const gelu_λ = √(2 / π)
+const gelu_2λ = √(8 / π)
 
 """
     swish(x) = x * σ(x)
@@ -345,7 +348,7 @@ julia> lineplot(swish, -2, 2, height=7)
   ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
 ```
 """
-swish(x) = x * σ(x)
+@inline swish(x) = x * sigmoid_fast(x)
 
 """
     lisht(x) = x * tanh(x)
@@ -368,7 +371,7 @@ julia> lineplot(lisht, -2, 2, height=7)
   ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀
 ```
 """
-lisht(x) = x * tanh(x)
+lisht(x) = x * tanh_fast(x)
 
 """
     selu(x) = λ * (x ≥ 0 ? x : α * (exp(x) - 1))
@@ -400,7 +403,7 @@ julia> selu(-10f0)
 function selu(x)
     λ = oftf(x, selu_λ)
     α = oftf(x, selu_α)
-    λ * ifelse(x > 0, x, α * (exp(x) - 1))
+    λ * ifelse(x > 0, x, @fastmath α * (exp(x) - 1))
 end
 
 const selu_λ = 1.0507009873554804934193349852946
@@ -610,7 +613,7 @@ julia> tanhshrink.((-10f0, 10f0))
 (-9.0f0, 9.0f0)
 ```
 """
-tanhshrink(x) = x - tanh(x)
+tanhshrink(x) = x - tanh_fast(x)
 
 """
     softshrink(x, λ=0.5) =
@@ -649,7 +652,11 @@ julia> softshrink.((-10f0, 10f0))
 (-9.5f0, 9.5f0)
 ```
 """
-softshrink(x, λ=oftf(x, 0.5)) = min(max(0, x - λ), x + λ)
+function softshrink(x, λ=oftf(x, 0.5))
+    lo = x - λ
+    hi = x + λ
+    ifelse(hi > 0, ifelse(lo < 0, zero(hi), lo), hi)
+end
 
 # Provide an informative error message if activation functions are called with an array
 for f in ACTIVATIONS
@@ -734,6 +741,18 @@ end
 
 sigmoid_fast(x::Float16) = sigmoid(x)  # sigmoid_fast is extremely badly behaved at large x
 
+"""
+    NNlib.fast_act(f, [x::AbstractArray])
+
+Replaces `f == tanh` with [`tanh_fast`](@ref), etc.
+
+Takes an optional 2nd argument, so that you can disable
+this replacement for some array or element types.
+"""
+@inline fast_act(f::F, ::AbstractArray = 1:0) where {F<:Function} = f
+@inline fast_act(::typeof(tanh), ::AbstractArray = 1:0) = tanh_fast
+@inline fast_act(::typeof(sigmoid), ::AbstractArray = 1:0) = sigmoid_fast
+
 ## Define rrules for some activation functions, along with the
 ## broadcasted rrule activation functions.
 ## TODO: add to the lists below all activations.
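Not part of the patch itself, but for context on the `gelu` hunk above: the rewrite relies on the identity `tanh(z) = 2σ(2z) - 1`, so `x/2 * (1 + tanh(λ*y)) == x * σ(2λ*y)` with `y = x * (1 + α*x^2)`, which is why the new constant is `gelu_2λ = √(8/π) = 2 * gelu_λ`. A minimal standalone sketch of that equivalence (plain Julia, independent of NNlib; the helper names are illustrative only):

```julia
# Sketch: the two gelu formulations agree up to floating-point rounding.
# σ here is a local stand-in for the logistic sigmoid, not NNlib's.
σ(z) = 1 / (1 + exp(-z))

gelu_tanh(x; α = 0.044715) = x/2 * (1 + tanh(√(2/π) * (x + α * x^3)))  # old form
gelu_sig(x; α = 0.044715)  = x * σ(√(8/π) * x * (1 + α * x^2))         # new form

xs = -5:0.01:5
maximum(abs, gelu_tanh.(xs) .- gelu_sig.(xs))  # zero up to rounding error
```

The remaining speed and accuracy gain then comes from evaluating the sigmoid via `sigmoid_fast` and grouping the polynomial with `muladd`.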
diff --git a/test/activations.jl b/test/activations.jl
index 0a0c5a2f5..0e8fb80a3 100644
--- a/test/activations.jl
+++ b/test/activations.jl
@@ -89,12 +89,12 @@ end
 @test relu6(-1.0) == 0.0
 @test -1/3.0 <= rrelu(-1.0) <= -1/8.0
 @test elu(-1.0) == exp(-1.0) - 1.0
-@test gelu(-1.0) == -0.15880800939172324
+@test gelu(-1.0) ≈ -0.15880800939172324
 @test swish(-1.0) == -sigmoid(-1.0)
 @test lisht(-1.0) ≈ -1.0 * tanh(-1.0)
 @test softplus(-1.0) ≈ log(exp(-1.0) + 1.0)
 @test softsign(-1.0) == -0.5
-@test selu(-1.0) == 1.0507009873554804934193349852946 * 1.6732632423543772848170429916717 * (exp(-1.0) - 1.0)
+@test selu(-1.0) ≈ 1.0507009873554804934193349852946 * 1.6732632423543772848170429916717 * (exp(-1.0) - 1.0)
 @test celu(-1.0) == exp(-1.0) - 1
 @test trelu(-1.0) == 0.0
 @test log(cosh(-1.0)) ≈ log(cosh(-1.0))

From 3edf26cfd46f472715500e5fb36cc8d9c4272eea Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Mon, 10 Jan 2022 23:19:18 -0500
Subject: [PATCH 2/3] comments

---
 src/activations.jl | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/activations.jl b/src/activations.jl
index 291ea888d..3a4bd2776 100644
--- a/src/activations.jl
+++ b/src/activations.jl
@@ -319,7 +319,7 @@ function gelu(x)
     α = oftf(x, 0.044715)
     # λ = oftf(x, gelu_λ)
-    # x/2 * (1 + tanh(λ * (x + α * x^3)))
+    # x/2 * (1 + tanh(λ * (x + α * x^3)))  # Standard implementation, for reference
     λλ = oftf(x, gelu_2λ)
     x * sigmoid_fast(λλ * x * muladd(x^2, α, one(x)))  # This is faster & more accurate
 end
 
@@ -692,6 +692,8 @@ julia> hard_tanh(0.5f0)
 ```
 """
 @inline function tanh_fast(x::Float32)
+    # This method added in NNlib.jl#345 by @mcabbott and @oscardssmith,
+    # with coefficients found using Remez.jl
     x2 = abs2(x)
     n = evalpoly(x2, (1.0f0, 0.1346604f0, 0.0035974074f0, 2.2332108f-5, 1.587199f-8))
     d = evalpoly(x2, (1.0f0, 0.4679937f0, 0.026262015f0, 0.0003453992f0, 8.7767893f-7))

From ed5cf941d1b5c9059e0926915300dba6b47c0454 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Mon, 10 Jan 2022 23:22:17 -0500
Subject: [PATCH 3/3] version

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 1f6438e64..c27ad08d9 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "NNlib"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.7.30"
+version = "0.7.31"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
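A usage note on the `NNlib.fast_act` hook added in the first patch (not something these patches themselves contain): one way to use it is at the point where an activation gets broadcast, so that code written against plain `tanh` or `sigmoid` picks up `tanh_fast` / `sigmoid_fast`, while the optional array argument leaves room to opt out for element or array types where that replacement is unwanted. A rough sketch, with a made-up `ToyDense` layer standing in for a real one:

```julia
using NNlib  # provides fast_act, tanh_fast, sigmoid_fast (NNlib ≥ 0.7.31)

# Hypothetical minimal dense layer, only to show where fast_act could be called.
struct ToyDense{M<:AbstractMatrix, V<:AbstractVector, F}
    weight::M
    bias::V
    activation::F
end

function (d::ToyDense)(x::AbstractVecOrMat)
    act = NNlib.fast_act(d.activation, x)  # tanh -> tanh_fast, sigmoid -> sigmoid_fast
    act.(d.weight * x .+ d.bias)           # any other activation passes through unchanged
end

layer = ToyDense(randn(Float32, 4, 3), zeros(Float32, 4), tanh)
layer(randn(Float32, 3, 8))  # 4×8 Matrix{Float32}, computed with tanh_fast
```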