Description
I encountered an error with the adjoint of getindex for a special matrix. The current getindex adjoint tries to assign to an element of the special matrix, which is not always allowed. This points to a more general problem, and opportunity, for adjoints of functions of special matrices. Using Diagonal as an example, this is the error:
using LinearAlgebra
using Zygote
using Random
rng = MersenneTwister(54754)
n = 3
A = rand(rng,n,n)
y,B_getindex = Zygote.pullback(x->getindex(x,2,1),Diagonal(A))
bA = B_getindex(1)[1] |> display
# ArgumentError: cannot set off-diagonal entry (2, 1) to a nonzero value (1)
Two possible solutions are:
Solution A : adjoint is a tuple with an unconstrained array
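# Copy of getindex(D::Diagonal, ...) from LinearAlgebra, so that Zygote
# differentiates through the code instead of using the generic getindex adjoint
using LinearAlgebra: diagzero
@inline function getindex_duplicate(D::Diagonal, i::Int, j::Int)
    @boundscheck checkbounds(D, i, j)
    if i == j
        @inbounds r = D.diag[i]
    else
        r = diagzero(D, i, j)
    end
    r
end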
yA,∂getindexA = Zygote.pullback(x->getindex_duplicate(x,3,3),Diagonal(A))
D̄A = ∂getindexA(1)[1]
display(D̄A)
# (diag = [0.0, 0.0, 1.0],)
Pro: Simple.
Con: The adjoint is not returned as a Diagonal, which reduces the efficiency of downstream adjoint code. This could serve as a fallback while solution B is being implemented.
Solution B : Insert orthogonal projection and return a special matrix
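# Forward to getindex on the argument D (not the global A)
getindexB(D::Diagonal,i::Int,j::Int) = getindex(D,i,j)
Zygote.@adjoint getindexB(D::Diagonal,i::Int,j::Int) = begin
    # Differentiate through the duplicated code to get the unconstrained adjoint
    y,∂U = Zygote.pullback(x->getindex_duplicate(x,i,j),D)
    y, function(Ȳ)
        D̄U = ∂U(Ȳ)[1]
        # Wrap the result back into a Diagonal
        D̄ = Diagonal(D̄U.diag)
        return D̄, nothing, nothing
    end
end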
yB,∂getindexB = Zygote.pullback(x->getindexB(x,3,3),Diagonal(A))
D̄B = ∂getindexB(1)[1]
display(D̄B)
# 3×3 Diagonal{Float64,Array{Float64,1}}:
# 0.0 ⋅ ⋅
# ⋅ 0.0 ⋅
# ⋅ ⋅ 1.0
Pro: Returns a special matrix, so the adjoint code becomes more efficient.
Con: Some work to write the adjoint.
Both solutions provide the same gradients:
ftestA(x) = getindex_duplicate(Diagonal(x.^2),3,3)
ftestB(x) = getindexB(Diagonal(x.^2),3,3)
x = [0.3,-2.5,4.3]
gA = Zygote.gradient(ftestA,x)[1]
gB = Zygote.gradient(ftestB,x)[1]
display(gA==gB)
# true
Mathematical background for solution B
Suppose we have a special matrix S ∈ 𝕊. For example, for Diagonal, 𝕊 is the space of diagonal matrices. 𝕊 is a subspace of the space of general matrices ℤ: 𝕊 ⊂ ℤ. The original function maps from 𝕊 to 𝕐:
Y = f(S) : 𝕊 → 𝕐
For example, for getindex(S,i::Int,j::Int), the output is a scalar, so 𝕐 = ℝ. If we take the adjoint through the code of f, we currently get an unconstrained matrix S̄ ∈ ℤ, as shown above for getindex. To ensure that S̄ ∈ 𝕊, we insert an orthogonal projection P from ℤ to 𝕊:
S = P(Z) : ℤ → 𝕊
Note that:
1. For S ∈ 𝕊, P is the identity: P(S) = S
2. Because P is an orthogonal projection, it is self-adjoint: P' = P. Therefore, the adjoint also maps ℤ → 𝕊.
Instead of computing the adjoint of f, we compute it for the function g
g = f∘P
whose adjoint is:
g' = P∘f'
Because P is the identity on 𝕊 (the first property), g(S) = f(S) for every special matrix S ∈ 𝕊.
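The step from g to g' can be written out explicitly. This is a sketch under the assumptions that P is linear, that ⟨·,·⟩ is the Frobenius inner product, and that f' denotes the adjoint of the derivative Df at S = P(Z); the symbols Df, δZ and Ȳ are notation introduced here for illustration:

```latex
\begin{aligned}
\langle \bar{Y},\, Dg(Z)[\delta Z] \rangle
  &= \langle \bar{Y},\, Df(S)[P\,\delta Z] \rangle
   = \langle f'(\bar{Y}),\, P\,\delta Z \rangle \\
  &= \langle P' f'(\bar{Y}),\, \delta Z \rangle
   = \langle P f'(\bar{Y}),\, \delta Z \rangle
  \quad\Longrightarrow\quad g' = P' \circ f' = P \circ f'
\end{aligned}
```

The last equality uses P' = P, which is what guarantees that the returned adjoint lies in 𝕊.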
This methodology applies to, among others: Diagonal, Symmetric, UpperTriangular, LowerTriangular, I (UniformScaling)(?). Often, P is simply the constructor (Diagonal, UpperTriangular, LowerTriangular). For Symmetric, the constructor only copies one triangle, so the projection is instead P(A) = (A+transpose(A))/2
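As a minimal sketch of the projections described above, the following checks the two properties of P numerically using only LinearAlgebra. The helper names project_diagonal, project_symmetric and project_upper are hypothetical and not part of LinearAlgebra or Zygote, and the inner product is assumed to be the Frobenius one, as computed by dot on matrices.

```julia
using LinearAlgebra

# Hypothetical helpers: candidate orthogonal projections from general square
# matrices onto each structured subspace (Frobenius inner product assumed).
project_diagonal(Z::AbstractMatrix) = Diagonal(Z)              # keeps only the diagonal
project_symmetric(Z::AbstractMatrix) = Symmetric((Z + transpose(Z)) / 2)
project_upper(Z::AbstractMatrix) = UpperTriangular(Z)          # zeroes the strict lower part

Z = [1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0]
W = [0.5 -1.0 2.0; 3.0 0.0 -2.5; 1.5 4.0 -0.5]

for P in (project_diagonal, project_symmetric, project_upper)
    S = Matrix(P(Z))
    # P is the identity on the subspace: projecting twice changes nothing
    @assert Matrix(P(S)) ≈ S
    # P is self-adjoint: <P(Z), W> == <Z, P(W)>
    @assert dot(S, W) ≈ dot(Z, Matrix(P(W)))
end
```

Both assertions are what make P usable as the final step of the adjoint: the first keeps the forward pass unchanged on 𝕊, the second lets P stand in for its own adjoint.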
Beyond getindex
In general, this approach could be used for adjoints of functions that map from the constrained to the unconstrained space, where the adjoint is not guaranteed to map back to 𝕊, notably collect and Array(). As with getindex, restricting the current adjoints to Array will probably give solution A for these. Note that the adjoints for these functions already work, unlike getindex, so coding them is less urgent.
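A sketch of how the same pattern might look for Array(), following the structure of solution B. The wrapper to_dense is a hypothetical name introduced only for illustration, so that the example does not override any existing adjoint; it is not part of Zygote or LinearAlgebra.

```julia
using LinearAlgebra
using Zygote

# Hypothetical wrapper around Array(::Diagonal), used to attach a custom adjoint
to_dense(D::Diagonal) = Array(D)

# The incoming cotangent Ȳ is a dense matrix; Diagonal(Ȳ) projects it onto 𝕊
Zygote.@adjoint to_dense(D::Diagonal) = Array(D), Ȳ -> (Diagonal(Ȳ),)

W = [1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0]
y, back = Zygote.pullback(to_dense, Diagonal([0.3, -2.5, 4.3]))
D̄ = back(W)[1]   # cotangent with respect to the Diagonal input
display(D̄)
```

Here the projection is the Diagonal constructor itself, so the off-diagonal entries of the dense cotangent are simply discarded.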
Open questions
- Is solution A or solution B preferred? I think B, with A as a fallback.
- Do we need to restrict the current adjoint for getindex to apply only to Array instead of AbstractArray? (I think it was like this in the past.) This would have the benefit that solution A is used automatically for special matrices, but I don't know what the repercussions would be. With the current signature of the getindex adjoint, I don't see how to circumvent it and autodiff through the original getindex code in solution B, apart from creating a duplicate, which is of course undesirable.
If there is an interest in this, I am happy to start a branch to start implementing some of this. I would appreciate guidance and help from others to get it right!