Open
JuliaDiff/ChainRules.jl
#599
Description
After a long time hunting a bug with @facusapienza21, we have realized that Zygote fails to provide a gradient for the basic sqrt function. This has been discussed at length in this Discourse thread.
Here's an MWE to reproduce the issue:
using Zygote
using Flux
A₀ = [[1,0] [0,3]]
A₁ = [[0,0] [0,0]]
function loss(θ)
    A = A₀.^θ
    A = sqrt.(A)
    return sqrt(Flux.Losses.mse(A, A₀; agg=sum))
end
θ = 4.0
loss_θ, back_θ = Zygote.pullback(loss, θ)
For this last case, the value of back_θ(1.0) is NaN. However, if we avoid the use of sqrt() by defining the loss function as
function loss(θ)
    A = A₀.^(θ/2)
    return sqrt(Flux.Losses.mse(A, A₀; agg=sum))
end
then Zygote provides the right gradient.
According to @mcabbott, "the reason we get NaN is that the slope of sqrt at zero is infinite. That infinity multiplies the slope of 0^x at 4, which is zero. Whereas with the 0^(x/2) version, the slope is simply zero".
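To make that explanation concrete, here is a minimal scalar sketch of the same arithmetic for the zero entry of A₀, using only Base Julia. It does not call Zygote or ChainRules; the derivatives are written by hand, the variable names (x, u, du_dθ, dsqrt_du, grad_composed, grad_fused) are made up for illustration, and taking the slope of 0^θ to be zero for θ > 0 is an assumption that follows the quote above.

```julia
# Scalar sketch of the chain rule at the zero entry of A₀.
# Hand-written derivatives for illustration, not the rules Zygote actually uses.
x = 0.0   # the zero entry of A₀
θ = 4.0

# Inner function u(θ) = x^θ. For x = 0 and θ > 0 it is constantly 0,
# so its slope in θ is taken to be 0 (assumption, following the quote).
u = x^θ
du_dθ = 0.0

# Outer function sqrt(u): its derivative 1 / (2 * sqrt(u)) is Inf at u = 0.
dsqrt_du = 1 / (2 * sqrt(u))

# Chain rule for sqrt(x^θ): Inf * 0.0 evaluates to NaN in floating point.
grad_composed = dsqrt_du * du_dθ

# Fused form x^(θ/2): also constantly 0 for θ > 0, so the slope is 0 * 1/2.
grad_fused = 0.0 * (1 / 2)

println((dsqrt_du, grad_composed, grad_fused))
```

The NaN comes from the product of an infinite and a zero factor, so any pullback that multiplies the two local slopes separately will produce it, while the fused exponent never forms the infinite factor.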
Because sqrt is such a basic function, this bug can potentially impact a large number of users.