
accumulation of gradients #922

@mzgubic


The issue occurs when the following are all true:

  • We have a struct with two differentiable fields
  • We need to accumulate the gradients with respect to both fields
  • The gradients that need to be accumulated both originate from an rrule

An example is the SVD factorisation:

using Zygote
using LinearAlgebra
using FiniteDifferences
using Test

# function to test AD against finite differences. Can't wait for https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/114
function test_ad(test_function, Δoutput, inputs...; atol=1e-7, rtol=1e-7)
    # Verify that the forwards-pass produces the correct answer.
    output, pb = Zygote.pullback(test_function, inputs...)
    @test output ≈ test_function(inputs...)

    # Compute the adjoints using AD and FiniteDifferences.
    dW_ad = pb(Δoutput)
    dW_fd = FiniteDifferences.j′vp(central_fdm(5, 1), test_function, Δoutput, inputs...)

    # Compare AD and FiniteDifferences results.
    @testset "$(typeof(test_function)) argument $n" for n in eachindex(inputs)
        @test dW_ad[n] ≈ dW_fd[n] atol=atol rtol=rtol
    end
end

function two_svds(X::StridedMatrix{<:Union{Real, Complex}})
    return svd(X).U * svd(X).V'
end

function one_svd(X::StridedMatrix{<:Union{Real, Complex}})
    F = svd(X)
    return F.U * F.V'
end

Δoutput = randn(3,2)
X = randn(3,2)

test_ad(two_svds, Δoutput, X) # works

test_ad(one_svd, Δoutput, X) # fails

  Expression: ≈(dW_ad[n], dW_fd[n], atol = atol, rtol = rtol)
   Evaluated: [0.10086994903629046 0.017821972969422787; -0.025950275563480844 -0.20012541965315742; 0.1827211721197507 0.16897788019829582] ≈ [-1.5134220464481536 -1.9386355487247564; 0.4546889986237118 0.8662734137317394; 0.9222144395826724 1.4004778288991626] (atol=1.0e-7, rtol=1.0e-7)

I think the reason is that gradient accumulation is defined for literal_getfield but not for literal_getproperty:

Zygote.jl/src/lib/lib.jl, lines 213 to 226 at commit 890b6f5:

@adjoint function literal_getfield(x, ::Val{f}) where f
  val = getfield(x, f)
  function back(Δ)
    accum_param(__context__, val, Δ) === nothing && return
    if isimmutable(x)
      ((; nt_nothing(x)..., pair(Val(f), Δ, x)...), nothing)
    else
      dx = grad_mut(__context__, x)
      dx[] = (; dx[]..., pair(Val(f), accum(getfield(dx[], f), Δ))...)
      return (dx,nothing)
    end
  end
  unwrap(val), back
end
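To make concrete what "accumulation" has to do in this case: the pullback through F.U yields a partial gradient for F with only the U entry populated, and the pullback through F.V yields one with only the V entry populated. These partial structural gradients must be summed entry by entry, with nothing treated as a structural zero; one plausible failure mode is that one contribution replaces the other instead of being added. The sketch below shows that field-wise accumulation on NamedTuples, assuming a simplified struct with fields U and V as a stand-in for the SVD object. The function accum_sketch is hypothetical and is not Zygote's API (Zygote has its own internal accum).

```julia
# Hypothetical helper (NOT Zygote's API): add two partial gradients,
# treating `nothing` as a structural zero.
accum_sketch(::Nothing, ::Nothing) = nothing
accum_sketch(x, ::Nothing) = x
accum_sketch(::Nothing, y) = y
accum_sketch(x, y) = x + y

# Field-wise accumulation of two NamedTuple gradients with the same field names.
function accum_sketch(a::NamedTuple{names}, b::NamedTuple{names}) where {names}
    return NamedTuple{names}(map(accum_sketch, values(a), values(b)))
end

# Partial gradients for a simplified struct with fields U and V:
# one arising from the F.U access, one from the F.V access.
dU = (U = [1.0 2.0], V = nothing)
dV = (U = nothing, V = [3.0 4.0])
dF = accum_sketch(dU, dV)  # both entries kept: U = [1.0 2.0], V = [3.0 4.0]

# Two contributions to the same field are summed, not overwritten.
dUU = accum_sketch((U = [1.0 2.0], V = nothing), (U = [1.0 1.0], V = nothing))
```

The point of the sketch is only the combination rule: whichever code path handles the property access has to merge each new contribution into the existing gradient for F, as the literal_getfield adjoint above does via accum, rather than returning a fresh gradient that drops the other field.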

I will submit a PR later today.
