Open
JuliaDiff/ChainRules.jl
#599
Description
After a long time hunting a bug with @facusapienza21, we have realized that Zygote fails to provide a gradient for the basic sqrt function. This has been discussed at length in this Discourse thread.
Here's a MWE to reproduce the issue:
using Zygote
using Flux
A₀ = [[1,0] [0,3]]
A₁ = [[0,0] [0,0]]
function loss(θ)
    A = A₀.^θ
    A = sqrt.(A)
    return sqrt(Flux.Losses.mse(A, A₀; agg=sum))
end
θ = 4.0
loss_θ, back_θ = Zygote.pullback(loss, θ)
In this case, the value of back_θ(1.0) is NaN. However, if we avoid the use of sqrt.() on A by defining the loss function as
function loss(θ)
    A = A₀.^(θ/2)
    return sqrt(Flux.Losses.mse(A, A₀; agg=sum))
end
then Zygote provides the right gradient.
According to @mcabbott, "the reason we get NaN is that the slope of sqrt at zero is infinite. That infinity multiplies the slope of 0^x at 4, which is zero. Whereas with the 0^(x/2) version, the slope is simply zero".
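The explanation quoted above can be written out with the chain rule for a single matrix entry. This is only a sketch: the symbols f, g, h and a are introduced here for illustration (a stands for one entry of A₀, and the problematic entries are those with a = 0), and the value g'(4) = 0 is taken from the quote rather than derived.

```latex
% One entry of sqrt.(A0 .^ theta), with base a (an entry of A0).
\[
f(\theta) = \sqrt{g(\theta)}, \qquad g(\theta) = a^{\theta}, \qquad
f'(\theta) = \frac{1}{2\sqrt{g(\theta)}} \, g'(\theta).
\]
% For a = 0 and theta = 4, the two chain-rule factors are infinity and zero.
\[
g(4) = 0, \quad g'(4) = 0
\;\Longrightarrow\;
f'(4) = \underbrace{\frac{1}{2\sqrt{0}}}_{\infty} \cdot \underbrace{0}_{g'(4)}
\quad \text{(indeterminate; NaN in IEEE 754 arithmetic)}.
\]
% The rewritten loss uses a single power, so no infinite factor appears.
\[
h(\theta) = a^{\theta/2}, \qquad h'(4) = 0 \quad \text{for } a = 0 .
\]
```

Both f and h are identically zero for a = 0 and θ > 0, so the mathematically correct derivative is 0 in both cases. The NaN appears only because reverse mode multiplies the two local derivatives separately, and a single NaN entry then propagates through the sum in the loss.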
Because sqrt is such a basic function, this bug can potentially affect a large number of users.