Adjoints for functions of specialized matrices #402

@sdewaele

I encountered an error with the adjoint of getindex of a special matrix. The issue is that the current getindex adjoint tries to assign to an element in the special matrix, which is not always allowed. It points to a more general problem/opportunity for adjoints of functions of special matrices. Using Diagonal as an example, this is the error:

using LinearAlgebra
using Zygote
using Random
rng = MersenneTwister(54754)
n = 3
A = rand(rng,n,n)

y,B_getindex = Zygote.pullback(x->getindex(x,2,1),Diagonal(A))
bA = B_getindex(1)[1] |> display
#ArgumentError: cannot set off-diagonal entry (2, 1) to a nonzero value (1)
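
The failing assignment can be reproduced without Zygote. The following is a minimal sketch, assuming only that the adjoint writes the incoming sensitivity into a zero-filled matrix of the same structured type; the variable names are illustrative and not taken from Zygote's source.

```julia
using LinearAlgebra

# Stand-in for the gradient container: a zero matrix that is itself a Diagonal.
D̄ = Diagonal(zeros(3))

# Writing to a diagonal entry is allowed.
D̄[3, 3] = 1

# Writing a nonzero value off the diagonal is rejected by Diagonal's setindex!.
err = try
    D̄[2, 1] = 1
    nothing
catch e
    e
end

println(typeof(err))
```

Any adjoint that stores the gradient in the structured type itself can therefore only represent sensitivities that lie inside the structure.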

Two potential solutions are:

Solution A : adjoint is a tuple with an unconstrained array

# Copy of getindex(D::Diagonal,...)
using LinearAlgebra:diagzero
@inline function getindex_duplicate(D::Diagonal, i::Int, j::Int)
  @boundscheck checkbounds(D, i, j)
  if i == j
      @inbounds r = D.diag[i]
  else
      r = diagzero(D, i, j)
  end
  r
end

yA,∂getindexA = Zygote.pullback(x->getindex_duplicate(x,3,3),Diagonal(A))
D̄A = ∂getindexA(1)[1]
display(D̄A)
# (diag = [0.0, 0.0, 1.0],)

Pro: Simple.
Con: The adjoint does not return a Diagonal, which reduces the efficiency of the adjoint code. This could perhaps serve as a fallback while we implement solution B.

Solution B : Insert orthogonal projection and return a special matrix

getindexB(D::Diagonal,i::Int,j::Int) = getindex(D,i,j)
Zygote.@adjoint getindexB(D::Diagonal,i::Int,j::Int) = begin
  y,∂U = Zygote.pullback(x->getindex_duplicate(x,i,j),D)
  y, function(Ȳ)
    D̄U = ∂U(Ȳ)[1]
    D̄ = Diagonal(D̄U.diag)
    return D̄, nothing, nothing
  end
end

yB,∂getindexB = Zygote.pullback(x->getindexB(x,3,3),Diagonal(A))
D̄B = ∂getindexB(1)[1]
display(D̄B)
# 3×3 Diagonal{Float64,Array{Float64,1}}:
# 0.0   ⋅    ⋅
#  ⋅   0.0   ⋅
#  ⋅    ⋅   1.0

Pro: Returns a special matrix, so the adjoint code becomes more efficient.
Con: Some work to write the adjoint.

Both solutions provide the same gradients:

ftestA(x) = getindex_duplicate(Diagonal(x.^2),3,3)
ftestB(x) = getindexB(Diagonal(x.^2),3,3)
x = [0.3,-2.5,4.3]
gA = Zygote.gradient(ftestA,x)[1]
gB = Zygote.gradient(ftestB,x)[1]
display(gA==gB)
# true

Mathematical background solution B
Suppose we have a special matrix S ∈ 𝕊. For example, for Diagonal, 𝕊 is the space of diagonal matrices. 𝕊 is a subspace of the space of the regular matrices ℤ: 𝕊 ⊂ ℤ. The original function maps from 𝕊 to 𝕐:
Y = f(S) : 𝕊 → 𝕐
For example, for getindex(S,i::Int,j::Int), 𝕐 is scalar, so 𝕐 = ℝ. If we take the adjoint through the code of f, we currently get an unconstrained matrix S̄ ∈ ℤ, as shown above for getindex. To ensure that S̄ ∈ 𝕊, we insert an orthogonal projection P from ℤ to 𝕊:
S = P(Z) : ℤ → 𝕊
Note that:

  1. For S ∈ 𝕊, P is the identity: P(S) = S
  2. Because P is an orthogonal projection, it is self-adjoint: P' = P. Therefore, the adjoint also maps ℤ → 𝕊.

Instead of computing the adjoint for f, we compute it for a function g
g = f∘P
which is:
g' = P∘f'
Because of property 1, g = f for every special matrix S ∈ 𝕊.
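
The identity g' = P∘f' can be spelled out step by step. This sketch assumes that P is linear and that ℤ carries the Frobenius inner product ⟨A, B⟩ = tr(AᵀB); the inner product is an assumption, since it is not stated above. Df(S) denotes the derivative of f at S, obtained by differentiating through the code of f with unconstrained entries, and ' denotes the adjoint with respect to the inner products on ℤ and 𝕐.

```latex
\begin{aligned}
g(Z) &= f(P(Z)), \\
Dg(Z)[\dot Z] &= Df(P(Z))\big[P(\dot Z)\big]
  && \text{(chain rule, $P$ linear)}, \\
\big\langle \bar Y,\; Dg(Z)[\dot Z] \big\rangle_{\mathbb{Y}}
  &= \big\langle Df(P(Z))'[\bar Y],\; P(\dot Z) \big\rangle_{\mathbb{Z}}
   = \big\langle P'\big(Df(P(Z))'[\bar Y]\big),\; \dot Z \big\rangle_{\mathbb{Z}}, \\
\bar Z &= P'\big(Df(P(Z))'[\bar Y]\big)
        = P\big(Df(P(Z))'[\bar Y]\big) \in \mathbb{S}
  && \text{(since $P' = P$)}.
\end{aligned}
```

For Diagonal, P(Z)ᵢⱼ = δᵢⱼ Zᵢⱼ, so the last line says: take the unconstrained sensitivity and keep only its diagonal.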

This methodology applies to, among others: Diagonal, Symmetric, UpperTriangular, LowerTriangular, I (UniformScaling) (?). Often, P is simply the constructor (Diagonal, UpperTriangular, LowerTriangular). For Symmetric, it is P(A) = (A+transpose(A))/2
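
The two properties required of P (identity on 𝕊, and self-adjointness) can be checked numerically for these projections. This is a minimal sketch: the `proj_*` names and the `frob`, `is_idempotent` and `is_selfadjoint` helpers are hypothetical, not part of Zygote or LinearAlgebra, and the Frobenius inner product is assumed.

```julia
using LinearAlgebra

# Hypothetical projection helpers: each maps an unconstrained matrix onto a structured subspace.
proj_diagonal(Z)  = Diagonal(Z)                         # keep the diagonal only
proj_upper(Z)     = UpperTriangular(Z)                  # keep the upper triangle
proj_lower(Z)     = LowerTriangular(Z)                  # keep the lower triangle
proj_symmetric(Z) = Symmetric((Z + transpose(Z)) / 2)   # symmetric part

# Frobenius inner product, computed on dense copies to stay independent of the structure.
frob(X, Y) = sum(Matrix(X) .* Matrix(Y))

# Property 1: P(P(A)) == P(A), i.e. P is the identity on its range.
function is_idempotent(P, A)
    PA = Matrix(P(A))
    return Matrix(P(PA)) ≈ PA
end

# Property 2: <P(A), B> == <A, P(B)>, i.e. P is self-adjoint.
is_selfadjoint(P, A, B) = frob(P(A), B) ≈ frob(A, P(B))

A = [1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 10.0]
B = [0.5 -1.0 2.0; 3.0 0.0 1.0; -2.0 4.0 1.5]

projections = (proj_diagonal, proj_upper, proj_lower, proj_symmetric)
checks = [(P, is_idempotent(P, A), is_selfadjoint(P, A, B)) for P in projections]

for (P, idem, selfadj) in checks
    println(P, ": idempotent = ", idem, ", self-adjoint = ", selfadj)
end
```

A check on one pair of matrices is a sanity test rather than a proof, but it catches a candidate P that is not an orthogonal projection.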

Beyond getindex
In general, this approach could be used for adjoints of functions mapping from the constrained to the unconstrained space, where the adjoint is not guaranteed to map to 𝕊, notably collect and Array(). As with getindex, specializing the current adjoint to Array will probably give you solution A for these. It should be noted that the adjoints for these functions do work now, unlike getindex, so coding these adjoints is less urgent.
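
For Array(), the projected adjoint is especially simple: f is the inclusion of 𝕊 into ℤ, so g' = P∘f' reduces to P itself. The following hand-written sketch illustrates this without Zygote; `array_with_pullback` is a hypothetical helper, not an existing API, and it only mimics the (value, pullback) pair that an adjoint definition would return.

```julia
using LinearAlgebra

# Hypothetical helper: forward value of Array(D) together with a projected pullback.
function array_with_pullback(D::Diagonal)
    Y = Array(D)
    # The incoming sensitivity Ȳ is an unconstrained matrix;
    # projecting with the Diagonal constructor keeps only its diagonal.
    pullback(Ȳ) = Diagonal(Ȳ)
    return Y, pullback
end

D = Diagonal([1.0, 2.0, 3.0])
Y, back = array_with_pullback(D)

# A dense sensitivity comes in, a Diagonal sensitivity goes out.
D̄ = back(ones(3, 3))
display(D̄)
```

The off-diagonal entries of the incoming sensitivity are discarded, which is consistent with the fact that perturbing a Diagonal can never change them.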

Open questions

  1. Is solution A or solution B preferred? I think B, using A as a fallback
  2. Do we need to restrict the current adjoint for getindex to apply only to Array instead of AbstractArray? (I think it was like this in the past.) This would have the benefit that solution A is automatically used for special matrices, but I don't know what the repercussions would be. With the current signature of the getindex adjoint, I don't see how to circumvent it and autodiff through the original getindex code in solution B, apart from creating a duplicate, which is of course undesirable.

If there is an interest in this, I am happy to start a branch to start implementing some of this. I would appreciate guidance and help from others to get it right!
