Various discussions have been had in various places about the correct kinds of types to implement `rrule`s for, but we've not discussed this in a central location. This problem probably occurs for some `frule`s, but doesn't seem as prevalent as in the `rrule` case.
## Problem Statement
The general theme of the problem is whether or not to view certain types as being "embedded" inside others, for the purpose of computing derivatives. For example, is a `Diagonal` matrix one that just happens to be diagonal and is equally well thought of as a `Matrix`, or is it really its own thing? Similarly, is an `Integer` just a `Real`, or is it something else?
I will first attempt to demonstrate the implications of each choice with a couple of representative examples.
## `Diagonal` matrix multiplication

Consider the following `rrule` implementation:
```julia
function rrule(::typeof(*), X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real})
    function mul_AbstractMatrix_pullback(ΔΩ::AbstractMatrix{<:Real})
        return ΔΩ * Y', X' * ΔΩ
    end
    return X * Y, mul_AbstractMatrix_pullback
end
```
If `X` and `Y` are `Matrix{Float64}`s, for example, then this is a perfectly reasonable implementation -- `ΔΩ` should also be a `Matrix{Float64}` if whoever is calling this rule is calling it correctly.
Things break down if `X` is a `Diagonal{Float64}`. The forwards-pass is completely fine, as is the computation of the cotangent for `Y`, `X' * ΔΩ`. However, the complexity of the cotangent representation / computation for `X` is now very concerning -- `ΔΩ * Y'` produces a `Matrix{Float64}`. Such a matrix is specified by `O(N^2)` numbers rather than the `O(N)` required for `X`, and requires `O(N^3)` time to compute, as opposed to the forwards-pass complexity of `O(N^2)`. This breaks the promise that the forwards- and reverse-pass time- and memory-complexity of reverse-mode AD should be the same, in essence rendering the structure in a `Diagonal` matrix useless if this is the only rule for matrix multiplication in an AD system.
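To make the complexity gap concrete, here is a standalone sketch (the variable names are illustrative and nothing here is part of the rule above) contrasting the dense cotangent the generic pullback produces with a cotangent kept in the `Diagonal` subspace:

```julia
using LinearAlgebra

N = 4
X = Diagonal(randn(N))   # represented by O(N) numbers
Y = randn(N, N)
ΔΩ = randn(N, N)         # cotangent of Ω = X * Y

# What the generic pullback computes for X:
ΔX_dense = ΔΩ * Y'       # Matrix{Float64}: O(N^2) storage, O(N^3) work

# Only the diagonal of ΔΩ * Y' is meaningful if X is "really" Diagonal,
# and it can be formed directly in O(N^2) time and O(N) storage:
# diag(ΔΩ * Y')[i] == sum_j ΔΩ[i, j] * Y[i, j]
ΔX_diag = Diagonal(vec(sum(ΔΩ .* Y; dims=2)))
```

Here `ΔX_diag` agrees with `Diagonal(ΔX_dense)`, but its computation matches the forwards-pass complexity, which is the property the main text argues a structured rule should preserve.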
Moreover, what does it mean to consider a non-zero "gradient" w.r.t. the off-diagonal elements of a `Diagonal` matrix? If you take the view that it's just another `Matrix`, then there's no issue. The other point of view is that there's no meaningful way to define a non-zero gradient w.r.t. the off-diagonal elements of a `Diagonal` matrix without considering matrices outside of the space of `Diagonal` matrices -- intuitively, if you "perturb" an off-diagonal element, you no longer have a `Diagonal` matrix. Consequently, a `Matrix` isn't an appropriate type to represent the gradient of a `Diagonal`. If someone has a way to properly formalise this argument, please consider providing it.
It seems that the first view necessitates giving up on the complexity guarantees that reverse-mode AD provides, while the second view necessitates giving up on implementing `rrule`s for abstract types (roughly speaking). The former is (in my opinion) a complete show-stopper, while the latter is something we can in principle live with.
Of course, you could add a specialised implementation for `Diagonal` matrices. However, I would suggest that you ponder your favourite structured matrix type and try to figure out whether it has similar issues. Most of the structured matrix types that I have encountered suffer from precisely this issue for many of the operations defined on them -- only those that are "dense" in the same way that a `Matrix` is do not. Consequently, it is not the case that we'll eventually reach a point where we've implemented enough specialised rules -- people will keep creating new subtypes of `AbstractMatrix`, and we'll be stuck in a cycle of forever adding new rules. This seems sub-optimal given that a reasonable AD ought to be able to derive them for us. Moreover, whenever someone who isn't overly familiar with AD implements a new `AbstractMatrix`, they would need to implement a host of new `rrule`s, which also seems like a show-stopper.
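For concreteness, such a specialised rule might look as follows -- a sketch only, written as a standalone function (`rrule_diagonal_mul` is a hypothetical name, not an actual ChainRules method), which keeps the cotangent of `X` inside the `Diagonal` subspace:

```julia
using LinearAlgebra

# Hypothetical Diagonal-specialised analogue of the generic rule above.
function rrule_diagonal_mul(X::Diagonal{<:Real}, Y::AbstractMatrix{<:Real})
    function mul_Diagonal_pullback(ΔΩ::AbstractMatrix{<:Real})
        # Only the diagonal of ΔΩ * Y' is kept: O(N^2) time and O(N)
        # storage, matching the forwards-pass complexity.
        ΔX = Diagonal(vec(sum(ΔΩ .* Y; dims=2)))
        ΔY = X' * ΔΩ   # Diagonal-times-Matrix: also O(N^2)
        return ΔX, ΔY
    end
    return X * Y, mul_Diagonal_pullback
end
```

The point of the sketch is not this particular rule, but that every structured `AbstractMatrix` subtype would need its own hand-written variant along these lines.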
## `Number` multiplication
Now consider implementing an `rrule` for `*` between two numbers. Clearly
```julia
function rrule(::typeof(*), x::Float64, y::Float64)
    function mul_Float64_pullback(ΔΩ::Float64)
        return ΔΩ * y, ΔΩ * x
    end
    return x * y, mul_Float64_pullback
end
```
is a correctly implemented rule. `Float64` is concrete, so there's no chance that someone will subtype it and require a different implementation for their subtype. In this sense, we can guarantee that this rule is correct for any of the inputs that it admits (up to finite-precision arithmetic issues).
What would happen if you implemented this rule instead for `Real`s? Suppose someone provided an `Integer` argument for `y`; then its cotangent will probably be a `Float64`. While this doesn't present the same complexity issues as the `Diagonal` example above, treating the `Integer`s as being embedded in the `Real`s can cause some headaches, such as the ones that @sethaxen addressed in #224 -- where it becomes very important to distinguish between `Integer` and `Real` exponents for the sake of performance and correctness. Since it is presumably not acceptable to sometimes treat the `Integer`s as special cases of the `Real`s and sometimes not, it follows that `*` should not be implemented between `Real`s, but between `AbstractFloat`s, if we're actually trying to be consistent.
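To see that the `Integer`/`Real` distinction is a correctness question and not just a performance one, note that Julia's own `^` already treats the two differently (a standalone illustration, not tied to any particular rule):

```julia
# Integer and Real exponents are genuinely different operations:
a = (-2.0)^3       # -8.0: integer powers of negative bases are well-defined

# whereas a Real exponent on a negative base throws a DomainError,
# since the result would be complex:
b = try
    (-2.0)^3.0
catch err
    err            # DomainError
end
```

A rule written for `Real` exponents that silently promotes an `Integer` exponent would therefore change which inputs are even valid, not merely how fast they run.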
Will this cause issues for users? Seth pointed out that the only situation in which this is likely to be problematic is the case in which an `Integer` argument is provided to an AD tool. This doesn't seem like a show-stopper.
## What Gives?
The issue seems to stem from implementing rules for types that you don't know about. For example, you can't know whether the `*` implementation above is suitable for all concrete matrices that subtype `AbstractMatrix`, even if they otherwise seem like perfectly reasonable matrix types.
How does this square with the typical uses of multiple dispatch within Julia? One common line of thought is roughly "multiple dispatch seems to work just fine in the rest of Julia, so why can't we just implement things as they're needed in this case?". The answer seems to be that

- while generic fallbacks in Julia can be slow, they're at least correct if your type correctly implements whichever interface it is meant to -- e.g. the multiplication of two `AbstractMatrix`s will generally give the correct answer. This simply doesn't hold if you're inclined to take the view that a `Matrix` isn't a suitable type to represent the gradient w.r.t. a `Diagonal`, and
- general Julia code doesn't provide you with asymptotic complexity guarantees in its fallbacks -- if you encounter a fallback and it has far worse complexity than you know is achievable for a particular operation on your type, you're unlikely to be too annoyed -- how could you possibly expect Julia to magically e.g. exploit some special structure in your matrix type unless you tell it how to? This is not the case with AD, because a) reverse-mode AD provides complexity guarantees, and b) the code to exploit structure was already written to implement the forwards-pass, so a reasonable AD ought to be able to exploit it to construct an efficient pullback. It's very frustrating when this has been prevented from happening through the implementation of a rule that "over-reaches" and prevents an AD tool from doing what it's meant to be doing.
## What To Do?
The obvious thing to do is to advise that rules are only implemented when you know they'll be correct, which means you have to restrict yourself to implementing rules for concrete types, and collections thereof. Unfortunately, doing this will almost certainly break e.g. Zygote, because it relies heavily on very broadly-defined rules that in general exhibit all of the problems discussed above. Maybe some careful middle ground needs to be found, or maybe Zygote needs to press forward with its handling of mutation so that it can actually cope with such changes.
## Related Work
Writing this issue was motivated by #226, where @nickrobinson251 pointed out that we should have an issue about this, and that we should consider adding something to the docs.
@sethaxen @oxinabox @MasonProtter any thoughts?
@DhairyaLGandhi @CarloLucibello this is very relevant for Zygote, so could do with your input.