NaN gradients for sqrt #1101

@JordiBolibar

After a long bug hunt with @facusapienza21, we have realized that Zygote returns NaN instead of a usable gradient for the basic sqrt function. This has been discussed at length in this Discourse thread.

Here's a MWE to reproduce the issue:

using Zygote 
using Flux

A₀ = [[1,0] [0,3]]
A₁ = [[0,0] [0,0]]

function loss(θ)
    A = A₀.^θ
    A = sqrt.(A)
    return sqrt(Flux.Losses.mse(A, A₀; agg=sum))
end

θ = 4.0
loss_θ, back_θ = Zygote.pullback(loss, θ) 

In this case, the value of back_θ(1.0) is NaN. However, if we avoid the use of sqrt() by defining the loss function as

function loss(θ)
    A = A₀.^(θ/2)
    return sqrt(Flux.Losses.mse(A, A₀; agg=sum))
end

then Zygote provides the right gradient.

According to @mcabbott, "the reason we get NaN is that the slope of sqrt at zero is infinite. That infinity multiplies the slope of 0^x at 4, which is zero. Whereas with the 0^(x/2) version, the slope is simply zero".
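
The arithmetic behind that explanation can be sketched with plain Float64 values, without calling Zygote at all. This is a hand-written chain rule for the zero entries of A₀, not Zygote's actual internals; the variable names are illustrative, and the slope of 0^θ at θ = 4 is taken to be zero, as stated in the quote above.

```julia
# Hand-written chain rule for f(θ) = sqrt(0^θ) at θ = 4 (illustrative only).
θ = 4.0
inner = 0.0^θ                   # value passed to sqrt: 0.0
dsqrt = 1 / (2 * sqrt(inner))   # slope of sqrt at 0: Inf
dpow  = 0.0                     # slope of θ -> 0^θ at θ = 4, assumed zero per the quote
grad_sqrt = dsqrt * dpow        # Inf * 0.0 is NaN in IEEE arithmetic

# Rewritten form g(θ) = 0^(θ/2): no sqrt, so no infinite factor appears.
grad_half = 0.5 * 0.0           # (1/2) * slope of u -> 0^u at u = 2, also zero

println(grad_sqrt)
println(grad_half)
```

The first product is undefined in floating point because an infinite factor meets a zero factor, while the second form never produces the infinite factor.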

Since sqrt is such a basic function, this bug can potentially affect a large number of users.
