5252 # exponents
5353 @scalar_rule cbrt (x) inv (3 * Ω ^ 2 )
5454 @scalar_rule inv (x) - (Ω ^ 2 )
55- @scalar_rule sqrt (x) inv (2 Ω)
55+ @scalar_rule sqrt (x) inv (2 Ω) # gradient +Inf at x==0
5656 @scalar_rule exp (x) Ω
5757 @scalar_rule exp10 (x) Ω * log (oftype (x, 10 ))
5858 @scalar_rule exp2 (x) Ω * log (oftype (x, 2 ))
137137
138138 # Binary functions
139139
140- # `hypot`
141-
140+ # # `hypot`
142141 function frule (
143142 (_, Δx, Δy),
144143 :: typeof (hypot),
@@ -163,17 +162,52 @@ let
163162 @scalar_rule x + y (true , true )
164163 @scalar_rule x - y (true , - 1 )
165164 @scalar_rule x / y (one (x) / y, - (Ω / y))
166- # log(complex(x)) is required so it gives correct complex answer for x<0
167- @scalar_rule (x ^ y, (
168- ifelse (iszero (x), ifelse (isone (y), one (Ω), zero (Ω)), y * Ω / x),
169- Ω * log (complex (x)),
170- ))
171- # x^y for x < 0 errors when y is not an integer, but then derivative wrt y
172- # is undefined, so we adopt subgradient convention and set derivative to 0.
173- @scalar_rule (x:: Real ^ y:: Real , (
174- ifelse (iszero (x), ifelse (isone (y), one (Ω), zero (Ω)), y * Ω / x),
175- Ω * log (oftype (Ω, ifelse (x ≤ 0 , one (x), x))),
176- ))
165+
166+ # # power
167+ # literal_pow is in base.jl
168+ function frule ((_, Δx, Δp), :: typeof (^ ), x:: Number , p:: Number )
169+ yox = x ^ (p- 1 )
170+ y = yox * x
171+ thelog = if Δp isa AbstractZero
172+ # Then don't waste time computing log
173+ NoTangent ()
174+ elseif x isa Real && p isa Real
175+ # For positive x we'd like a real answer, including any Δp.
176+ # For negative x, this is a DomainError unless isinteger(p)...
177+ # could decide that implues that p is non-differentiable:
178+ # log(ifelse(x<0, one(x), x))
179+
180+ # or we could match what the rrule with ProjectTo gives:
181+ real (log (complex (x)))
182+ #=
183+
184+ julia> frule((0,0,1), ^, -4, 3.0), unthunk.(rrule(^, -4, 3.0)[2](1))
185+ ((-64.0, 0.0), (NoTangent(), 48.0, -88.722839111673))
186+
187+ julia> frule((0,0,1), ^, 4, 3.0), unthunk.(rrule(^, 4, 3.0)[2](1))
188+ ((64.0, 88.722839111673), (NoTangent(), 48.0, 88.722839111673))
189+ =#
190+ else
191+ # This promotion handles e.g. real x & complex p
192+ log (oftype (y, x))
193+ end
194+ return y, muladd (y * thelog, Δp, p * yox * Δx)
195+ end
196+ function rrule (:: typeof (^ ), x:: Number , p:: Number )
197+ yox = x ^ (p- 1 )
198+ project_x, project_p = ProjectTo (x), ProjectTo (p)
199+ @inline function power_pullback (dy)
200+ dx = project_x (conj (p * yox) * dy)
201+ dp = @thunk if x isa Real && p isa Real
202+ project_p (conj (yox * x * log (complex (x))) * dy)
203+ else
204+ project_p (conj (yox * x * log (oftype (yox, x))) * dy)
205+ end
206+ return (NoTangent (), dx, dp)
207+ end
208+ return yox * x, power_pullback
209+ end
210+
177211 @scalar_rule (
178212 rem (x, y),
179213 @setup ((u, nan) = promote (x / y, NaN16 ), isint = isinteger (x / y)),
232266 non_transformed_definitions = intersect (fastable_ast. args, fast_ast. args)
233267 filter! (expr-> ! (expr isa LineNumberNode), non_transformed_definitions)
234268 if ! isempty (non_transformed_definitions)
235- error (
236- " Non-FastMath compatible rules defined in fastmath_able.jl. \n Definitions:\n " *
237- join (non_transformed_definitions, " \n " )
269+ @warn (
270+ " Non-FastMath compatible rules defined in fastmath_able.jl." , # \n Definitions:\n" *
271+ # join(non_transformed_definitions, "\n")
272+ non_transformed_definitions
238273 )
239274 end
240275
0 commit comments