diff --git a/Project.toml b/Project.toml index 4f2495d05..cd04fd278 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.6" +version = "0.7.7" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 7f7562dc6..04bb97bb2 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -134,9 +134,14 @@ let @scalar_rule x + y (One(), One()) @scalar_rule x - y (One(), -1) @scalar_rule x / y (inv(y), -((x / y) / y)) - #log(complex(x)) is require so it give correct complex answer for x<0 + #log(complex(x)) is required so it gives correct complex answer for x<0 @scalar_rule(x ^ y, - (ifelse(iszero(y), zero(Ω), y * x ^ (y - 1)), Ω * log(complex(x))), + (ifelse(iszero(x), zero(Ω), y * Ω / x), Ω * log(complex(x))), + ) + # x^y for x < 0 errors when y is not an integer, but then derivative wrt y + # is undefined, so we adopt subgradient convention and set derivative to 0. + @scalar_rule(x::Real ^ y::Real, + (ifelse(iszero(x), zero(Ω), y * Ω / x), Ω * log(ifelse(x ≤ 0, one(x), x))), ) @scalar_rule( rem(x, y), diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 0c19b8277..97b977585 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -117,17 +117,6 @@ rrule_test(mod, Δz, (x, x̄), (y, ȳ)) end - @testset "^(x::$T, n::$T)" for T in (Float64, ComplexF64) - # for real x and n, x must be >0 - x = T <: Real ? 15rand() : 15randn(ComplexF64) - Δx, x̄ = 10rand(T, 2) - y, Δy, ȳ = rand(T, 3) - Δz = rand(T) - - frule_test(^, (x, Δx), (y, Δy)) - rrule_test(^, Δz, (x, x̄), (y, ȳ)) - end - @testset "identity" for T in (Float64, ComplexF64) frule_test(identity, (randn(T), randn(T))) frule_test(identity, (randn(T, 4), randn(T, 4))) diff --git a/test/rulesets/Base/fastmath_able.jl b/test/rulesets/Base/fastmath_able.jl index 36dd4c2ea..8ccdbe67d 100644 --- a/test/rulesets/Base/fastmath_able.jl +++ b/test/rulesets/Base/fastmath_able.jl @@ -139,6 +139,34 @@ const FASTABLE_AST = quote frule_test(f, (x, Δx), (y, Δy)) rrule_test(f, Δz, (x, x̄), (y, ȳ)) end + + @testset "^(x::$T, n::$T)" for T in (Float64, ComplexF64) + # for real x and n, x must be >0 + x = T <: Real ? 15rand() : 15randn(ComplexF64) + Δx, x̄ = 10rand(T, 2) + y, Δy, ȳ = rand(T, 3) + Δz = rand(T) + + frule_test(^, (x, Δx), (y, Δy)) + rrule_test(^, Δz, (x, x̄), (y, ȳ)) + + T <: Real && @testset "discontinuity for ^(x::Real, n::Int) when x ≤ 0" begin + # finite differences doesn't work for x < 0, so we check manually + x, y = -10rand(T), 3 + Δx = randn(T) + Δy = randn(T) + Δz = randn(T) + + @test frule((Zero(), Δx, Δy), ^, x, y)[2] ≈ Δx * y * x^(y - 1) + @test frule((Zero(), Δx, Δy), ^, zero(x), y)[2] ≈ 0 + _, ∂x, ∂y = rrule(^, x, y)[2](Δz) + @test ∂x ≈ Δz * y * x^(y - 1) + @test ∂y ≈ 0 + _, ∂x, ∂y = rrule(^, zero(x), y)[2](Δz) + @test ∂x ≈ 0 + @test ∂y ≈ 0 + end + end end @testset "sign" begin