@@ -137,7 +137,7 @@ const FASTABLE_AST = quote
137137 test_rrule (f, 10 rand (T), rand (T))
138138 end
139139
140- @testset " $f (x::$T , y::$T ) type check" for f in (/ , + , - ,\ , hypot, ^ ), T in (Float32, Float64)
140+ @testset " $f (x::$T , y::$T ) type check" for f in (/ , + , - ,\ , hypot), T in (Float32, Float64)
141141 x, Δx, x̄ = 10 rand (T, 3 )
142142 y, Δy, ȳ = rand (T, 3 )
143143 @assert T == typeof (f (x, y))
@@ -159,28 +159,78 @@ const FASTABLE_AST = quote
159159 end
160160 end
161161
162- @testset " ^(x::$T , n::$T )" for T in (Float64, ComplexF64)
163- # for real x and n, x must be >0
164- test_frule (^ , rand (T) + 3 , rand (T) + 3 )
165- test_rrule (^ , rand (T) + 3 , rand (T) + 3 )
166-
167- T <: Real && @testset " discontinuity for ^(x::Real, n::Int) when x ≤ 0" begin
168- # finite differences doesn't work for x < 0, so we check manually
169- x = - rand (T) .- 3
170- y = 3
171- Δx = randn (T)
172- Δy = randn (T)
173- Δz = randn (T)
174-
175- @test frule ((ZeroTangent (), Δx, Δy), ^ , x, y)[2 ] ≈ Δx * y * x^ (y - 1 )
176- @test frule ((ZeroTangent (), Δx, Δy), ^ , zero (x), y)[2 ] ≈ 0
177- _, ∂x, ∂y = rrule (^ , x, y)[2 ](Δz)
178- @test ∂x ≈ Δz * y * x^ (y - 1 )
179- @test ∂y ≈ 0
180- _, ∂x, ∂y = rrule (^ , zero (x), y)[2 ](Δz)
181- @test ∂x ≈ 0
182- @test ∂y ≈ 0
162+ @testset " ^(x::$T , p::$S )" for T in (Float64, ComplexF64), S in (Float64, ComplexF64)
163+ test_frule (^ , rand (T) + 3 , rand (S) + 3 )
164+ test_rrule (^ , rand (T) + 3 , rand (S) + 3 )
165+
166+ # When both x & p are Real, and !(isinteger(p)),
167+ # then x must be positive to avoid a DomainError
168+ T <: Real && S <: Real && continue
169+ # In other cases, we can test values near zero:
170+
171+ test_frule (^ , randn (T), rand (S))
172+ test_rrule (^ , rand (T), rand (S))
173+ end
174+
175+ # Tests for power functions, at values near to zero.
176+ POWERGRADS = [ # (x,p) => (dx,dp)
177+ # Some regular points, as sanity checks:
178+ (1.0 , 2 ) => (2.0 , 0.0 ),
179+ (2.0 , 2 ) => (4.0 , 2.772588722239781 ),
180+ # At x=0, gradients for x seem clear,
181+ # for p less certain what's best.
182+ (0.0 , 2 ) => (0.0 , 0.0 ),
183+ (- 0.0 , 2 ) => (0.0 , 0.0 ), # probably (-0.0, 0.0) would be ideal
184+ (0.0 , 1 ) => (1.0 , 0.0 ),
185+ (- 0.0 , 1 ) => (1.0 , 0.0 ),
186+ (0.0 , 0 ) => (0.0 , NaN ),
187+ (- 0.0 , 0 ) => (0.0 , NaN ),
188+ (0.0 , - 1 ) => (- Inf , NaN ),
189+ (- 0.0 , - 1 ) => (- Inf , NaN ),
190+ (0.0 , - 2 ) => (- Inf , NaN ),
191+ (- 0.0 , - 2 ) => (Inf , NaN ),
192+ # Integer x & p, check no InexactErrors
193+ (0 , 2 ) => (0.0 , 0.0 ),
194+ (0 , 1 ) => (1.0 , 0.0 ),
195+ (0 , 0 ) => (0.0 , NaN ),
196+ (0 , - 1 ) => (- Inf , NaN ),
197+ (0 , - 2 ) => (- Inf , NaN ),
198+ # Non-integer powers:
199+ (0.0 , 0.5 ) => (Inf , 0.0 ),
200+ (0.0 , 3.5 ) => (0.0 , 0.0 ),
201+ (0.0 , - 1.5 ) => (- Inf , NaN ),
202+ ]
203+
204+ @testset " $x ^ $p " for ((x,p), (∂x, ∂p)) in POWERGRADS
205+ if x isa Integer && p isa Integer && p < 0
206+ @test_throws DomainError x^ p
207+ continue
183208 end
209+ y = x^ p
210+
211+ # Forward
212+ y_fwd = frule ((1 ,1 ,1 ), ^ , x, p)[1 ]
213+ @test isequal (y, y_fwd)
214+
215+ ∂x_fwd = frule ((0 ,1 ,0 ), ^ , x, p)[2 ]
216+ ∂p_fwd = frule ((0 ,0 ,1 ), ^ , x, p)[2 ]
217+ @test isequal (∂x, ∂x_fwd)
218+ if x=== 0.0 && p=== 0.5
219+ @test_broken isequal (∂p, ∂p_fwd)
220+ else
221+ @test isequal (∂p, ∂p_fwd)
222+ end
223+
224+ ∂x_fwd = frule ((0 ,1 ,ZeroTangent ()), ^ , x, p)[2 ] # easier, strong zero
225+ @test isequal (∂x, ∂x_fwd)
226+
227+ # Reverse
228+ y_rev = rrule (^ , x, p)[1 ]
229+ @test isequal (y, y_rev)
230+
231+ ∂x_rev, ∂p_rev = unthunk .(rrule (^ , x, p)[2 ](1 ))[2 : 3 ]
232+ @test isequal (∂x, ∂x_rev)
233+ @test isequal (∂p, ∂p_rev)
184234 end
185235 end
186236
0 commit comments