Description
The issue occurs when the following are all true:
- We have a struct with two differentiable fields
- We need to accumulate the gradients with respect to both fields
- The gradients that need to be accumulated both originate from an rrule
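The conditions above can be illustrated without Zygote. This is a minimal sketch. It assumes, as the immutable branch of `literal_getfield` shown further down does, that the cotangent of a struct is a NamedTuple in which `nothing` stands for a zero gradient. `accum_sketch` is a hypothetical helper, not Zygote's internal `accum`. Each field access yields a NamedTuple with only its own entry filled in, and the two must be summed before the result reaches the `rrule` that produced the struct:

```julia
# Hypothetical helper: a simplified stand-in for Zygote's internal `accum`,
# with `nothing` representing a zero (absent) gradient.
accum_sketch(::Nothing, ::Nothing) = nothing
accum_sketch(x, ::Nothing) = x
accum_sketch(::Nothing, y) = y
accum_sketch(x, y) = x + y

# Structural cotangents of an immutable struct are NamedTuples; combine them
# field by field.
accum_sketch(x::NamedTuple{K}, y::NamedTuple{K}) where {K} =
    NamedTuple{K}(map(accum_sketch, values(x), values(y)))

# Reading `F.U` contributes only to the `U` entry, reading `F.V` only to `V`.
ΔU = [1.0 2.0; 3.0 4.0]
ΔV = [0.5 0.0; 0.0 0.5]
from_U = (U = ΔU, V = nothing)
from_V = (U = nothing, V = ΔV)

# Both contributions must be summed before reaching the rrule that built the
# struct; dropping one would make that field look like it has zero gradient.
total = accum_sketch(from_U, from_V)
```

If either contribution were lost, the `rrule` would see `nothing` for that field and treat its gradient as zero. A wrong but non-crashing result like this is consistent with the failure shown below.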
An example is the SVD factorisation (`svd`):
```julia
using Zygote
using LinearAlgebra
using FiniteDifferences
using Test

# function to test AD against finite differences. Can't wait for https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/114
function test_ad(test_function, Δoutput, inputs...; atol=1e-7, rtol=1e-7)
    # Verify that the forwards-pass produces the correct answer.
    output, pb = Zygote.pullback(test_function, inputs...)
    @test output ≈ test_function(inputs...)

    # Compute the adjoints using AD and FiniteDifferences.
    dW_ad = pb(Δoutput)
    dW_fd = FiniteDifferences.j′vp(central_fdm(5, 1), test_function, Δoutput, inputs...)

    # Compare AD and FiniteDifferences results.
    @testset "$(typeof(test_function)) argument $n" for n in eachindex(inputs)
        @test dW_ad[n] ≈ dW_fd[n] atol=atol rtol=rtol
    end
end

function two_svds(X::StridedMatrix{<:Union{Real, Complex}})
    return svd(X).U * svd(X).V'
end

function one_svd(X::StridedMatrix{<:Union{Real, Complex}})
    F = svd(X)
    return F.U * F.V'
end

Δoutput = randn(3,2)
X = randn(3,2)
test_ad(two_svds, Δoutput, X) # works
test_ad(one_svd, Δoutput, X) # fails
```
The failing `one_svd` test reports:

```
Expression: ≈(dW_ad[n], dW_fd[n], atol = atol, rtol = rtol)
Evaluated: [0.10086994903629046 0.017821972969422787; -0.025950275563480844 -0.20012541965315742; 0.1827211721197507 0.16897788019829582] ≈ [-1.5134220464481536 -1.9386355487247564; 0.4546889986237118 0.8662734137317394; 0.9222144395826724 1.4004778288991626] (atol=1.0e-7, rtol=1.0e-7)
```
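For context, and possibly relevant to which adjoint path each access takes: in `LinearAlgebra`, the `SVD` object stores `U`, `S` and `Vt` as fields, while `V` is a computed property returned by an overloaded `getproperty`. So in `one_svd`, `F.U` reads a stored field while `F.V` is derived from `F.Vt`. A quick check using only `LinearAlgebra`:

```julia
using LinearAlgebra

X = randn(3, 2)
F = svd(X)

# U, S and Vt are stored fields of the factorisation object.
@show fieldnames(typeof(F))
# V is not a stored field ...
@show hasfield(typeof(F), :V)
# ... it is computed by `getproperty` as the adjoint of the stored Vt.
@show F.V == F.Vt'
```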
I think the cause is that gradient accumulation is defined for `literal_getfield` but not for `literal_getproperty`:
Lines 213 to 226 in 890b6f5

```julia
@adjoint function literal_getfield(x, ::Val{f}) where f
  val = getfield(x, f)
  function back(Δ)
    accum_param(__context__, val, Δ) === nothing && return
    if isimmutable(x)
      ((; nt_nothing(x)..., pair(Val(f), Δ, x)...), nothing)
    else
      dx = grad_mut(__context__, x)
      dx[] = (; dx[]..., pair(Val(f), accum(getfield(dx[], f), Δ))...)
      return (dx,nothing)
    end
  end
  unwrap(val), back
end
```
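In the `else` branch of this adjoint, the gradient of a mutable object is kept in the container returned by `grad_mut`, which holds a NamedTuple, and each field access adds its cotangent into the matching entry rather than overwriting it. The following plain-Julia sketch shows that update. `accum_sketch` and `accum_field!` are hypothetical stand-ins for Zygote's internal `accum` and the `dx[] = ...` line, not Zygote API. Using a `Ref{Any}` as the container is also an assumption made for this sketch:

```julia
# Hypothetical helpers mirroring the mutable-struct branch above; these are
# not Zygote API. `nothing` represents a zero (absent) gradient.
accum_sketch(::Nothing, ::Nothing) = nothing
accum_sketch(x, ::Nothing) = x
accum_sketch(::Nothing, y) = y
accum_sketch(x, y) = x + y

# Add the cotangent Δ for field `f` into the NamedTuple stored in `dx`,
# summing with any earlier contribution instead of overwriting it.
function accum_field!(dx::Ref, f::Symbol, Δ)
    current = getfield(dx[], f)
    dx[] = merge(dx[], NamedTuple{(f,)}((accum_sketch(current, Δ),)))
    return dx
end

# `Ref{Any}` so the stored NamedTuple may change type as gradients arrive.
dx = Ref{Any}((U = nothing, V = nothing))
accum_field!(dx, :U, [1.0, 2.0])
accum_field!(dx, :V, [3.0])
accum_field!(dx, :U, [0.5, 0.5])  # second read of U: summed, not replaced
```

`merge` with a single-entry NamedTuple replaces just that field and keeps the field order, which matches what the splatting expression `(; dx[]..., pair(...)...)` does in the adjoint.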
Will submit a PR later today