Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/rulesets/Base/fastmath_able.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,14 @@ let
@scalar_rule x + y (One(), One())
@scalar_rule x - y (One(), -1)
@scalar_rule x / y (inv(y), -((x / y) / y))
#log(complex(x)) is require so it give correct complex answer for x<0
#log(complex(x)) is required so it gives correct complex answer for x<0
@scalar_rule(x ^ y,
(ifelse(iszero(y), zero(Ω), y * x ^ (y - 1)), Ω * log(complex(x))),
(ifelse(iszero(x), zero(Ω), y * Ω / x), Ω * log(complex(x))),
)
# x^y for x < 0 errors when y is not an integer, but then derivative wrt y
# is undefined, so we adopt subgradient convention and set derivative to 0.
@scalar_rule(x::Real ^ y::Real,
(ifelse(iszero(x), zero(Ω), y * Ω / x), Ω * log(ifelse(x ≤ 0, one(x), x))),
)
@scalar_rule(
rem(x, y),
Expand Down
11 changes: 0 additions & 11 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,17 +117,6 @@
rrule_test(mod, Δz, (x, x̄), (y, ȳ))
end

@testset "^(x::$T, n::$T)" for T in (Float64, ComplexF64)
# for real x and n, x must be >0
x = T <: Real ? 15rand() : 15randn(ComplexF64)
Δx, x̄ = 10rand(T, 2)
y, Δy, ȳ = rand(T, 3)
Δz = rand(T)

frule_test(^, (x, Δx), (y, Δy))
rrule_test(^, Δz, (x, x̄), (y, ȳ))
end

@testset "identity" for T in (Float64, ComplexF64)
frule_test(identity, (randn(T), randn(T)))
frule_test(identity, (randn(T, 4), randn(T, 4)))
Expand Down
28 changes: 28 additions & 0 deletions test/rulesets/Base/fastmath_able.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,34 @@ const FASTABLE_AST = quote
frule_test(f, (x, Δx), (y, Δy))
rrule_test(f, Δz, (x, x̄), (y, ȳ))
end

@testset "^(x::$T, n::$T)" for T in (Float64, ComplexF64)
# for real x and n, x must be >0
x = T <: Real ? 15rand() : 15randn(ComplexF64)
Δx, x̄ = 10rand(T, 2)
y, Δy, ȳ = rand(T, 3)
Δz = rand(T)

frule_test(^, (x, Δx), (y, Δy))
rrule_test(^, Δz, (x, x̄), (y, ȳ))

T <: Real && @testset "discontinuity for ^(x::Real, n::Int) when x ≤ 0" begin
# finite differences doesn't work for x < 0, so we check manually
x, y = -10rand(T), 3
Δx = randn(T)
Δy = randn(T)
Δz = randn(T)

@test frule((Zero(), Δx, Δy), ^, x, y)[2] ≈ Δx * y * x^(y - 1)
@test frule((Zero(), Δx, Δy), ^, zero(x), y)[2] ≈ 0
_, ∂x, ∂y = rrule(^, x, y)[2](Δz)
@test ∂x ≈ Δz * y * x^(y - 1)
@test ∂y ≈ 0
_, ∂x, ∂y = rrule(^, zero(x), y)[2](Δz)
@test ∂x ≈ 0
@test ∂y ≈ 0
end
end
end

@testset "sign" begin
Expand Down