
Type Constraints #232

@willtebbutt

Various discussions have taken place in various places about which types rrules should be implemented for, but we haven't discussed this in a central location. The problem probably also occurs for some frules, but it doesn't seem as prevalent as in the rrule case.

Problem Statement

The general theme of the problem is whether or not to view certain types as being "embedded" inside others for the purpose of computing derivatives. For example, is a Diagonal matrix one that just happens to be diagonal and is equally well thought of as a Matrix, or is it really its own thing? Similarly, is an Integer just a Real, or is it something else?

I will first attempt to demonstrate the implications of each choice with a couple of representative examples.

Diagonal matrix multiplication

Consider the following rrule implementation:

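using ChainRulesCore
import ChainRulesCore: rrule

function rrule(::typeof(*), X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real})
    function mul_AbstractMatrix_pullback(ΔΩ::AbstractMatrix{<:Real})
        # NoTangent() is the cotangent for `*` itself, followed by those for X and Y.
        return NoTangent(), ΔΩ * Y', X' * ΔΩ
    end
    return X * Y, mul_AbstractMatrix_pullback
end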

If X and Y are Matrix{Float64}s for example, then this is a perfectly reasonable implementation -- ΔΩ should also be a Matrix{Float64} if whoever is calling this rule is calling it correctly.

Things break down if X is a Diagonal{Float64}. The forwards-pass is completely fine, as is the computation of the cotangent for Y, X' * ΔΩ. However, the complexity of the cotangent representation / computation for X is now very concerning: ΔΩ * Y' produces a Matrix{Float64}. Such a matrix is specified by O(N^2) numbers rather than the O(N) required for X, and requires O(N^3) time to compute, as opposed to the O(N^2) complexity of the forwards-pass. This breaks the promise that the time- and memory-complexity of the forwards- and reverse-passes of reverse-mode AD should be the same (up to a constant factor), in essence rendering the structure in a Diagonal matrix useless if it is used in an AD system where this is the only rule for multiplication of matrices.
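A minimal sketch of the complexity problem, using only the LinearAlgebra standard library. The function `generic_mul_pullback` below is a hypothetical plain-function copy of the two cotangents computed by the generic rule above (it is not a ChainRules API). It shows that the cotangent for a `Diagonal` `X` comes back as a dense `Matrix` with N^2 entries, even though `X` itself is described by only N numbers.

```julia
using LinearAlgebra

# Hypothetical plain-function copy of the cotangents computed by the
# generic AbstractMatrix rule (not a ChainRules API).
generic_mul_pullback(ΔΩ, X, Y) = (ΔΩ * Y', X' * ΔΩ)

N = 4
X = Diagonal(rand(N))   # O(N) storage
Y = rand(N, N)
ΔΩ = rand(N, N)

Ω = X * Y               # forwards pass: O(N^2) work, exploits the structure
ΔX, ΔY = generic_mul_pullback(ΔΩ, X, Y)

# The cotangent for X is a dense N×N Matrix computed in O(N^3) time,
# not a Diagonal: N^2 numbers for an object described by N numbers.
ΔX isa Matrix{Float64}
length(ΔX), length(diag(X))
```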

Moreover, what does it mean to consider a non-zero "gradient" w.r.t. the off-diagonal elements of a Diagonal matrix? If you take the view that it's just another Matrix, then there's no issue. The other point of view is that there's no meaningful way to define a non-zero gradient w.r.t. the off-diagonal elements of a Diagonal matrix without considering matrices outside of the space of Diagonal matrices: intuitively, if you "perturb" an off-diagonal element, you no longer have a Diagonal matrix. Consequently, a Matrix isn't an appropriate type to represent the gradient w.r.t. a Diagonal. If someone has a way to properly formalise this argument, please consider providing it.
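One way to make the second view concrete, offered as a sketch rather than a formal proof: assume cotangents act on tangents through the Frobenius inner product, `dot(A, B) = sum(A .* B)`. If the only admissible perturbations `Xdot` of `X` are themselves `Diagonal`, then the off-diagonal entries of a dense cotangent never contribute to `dot(ΔX, Xdot)`, so keeping only its diagonal loses nothing. The helper `project_diagonal` is hypothetical.

```julia
using LinearAlgebra

# Hypothetical helper: keep only the diagonal of a dense cotangent.
project_diagonal(ΔX::AbstractMatrix) = Diagonal(diag(ΔX))

N = 4
ΔX_dense = rand(N, N)               # e.g. what the generic pullback returns
ΔX_diag = project_diagonal(ΔX_dense)

# A perturbation that stays inside the space of Diagonal matrices.
Xdot = Diagonal(rand(N))

# Both cotangents give the same directional derivative along Xdot ...
dot(ΔX_dense, Xdot) ≈ dot(ΔX_diag, Xdot)

# ... and they only disagree along an off-diagonal perturbation E,
# which takes X outside the space of Diagonal matrices.
E = zeros(N, N)
E[1, 2] = 1.0
dot(ΔX_dense, E), dot(ΔX_diag, E)
```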

It seems that the first view necessitates giving up on the complexity guarantees that reverse-mode AD provides, while the second view necessitates giving up on implementing rrules for abstract types (roughly speaking). The former is (in my opinion) a complete show-stopper, while the latter is something we can in principle live with.

Of course you could add a specialised implementation for Diagonal matrices. However, I would suggest that you ponder your favourite structured matrix type and try to figure out whether it has similar issues. Most of the structured matrix types that I have encountered suffer from precisely this issue with many of the operations defined on them; only those that are "dense" in the same way that a Matrix is do not. Consequently, we will never reach a point where we've implemented enough specialised rules: people will keep creating new subtypes of AbstractMatrix and we'll be stuck in a cycle of forever adding new rules. This seems sub-optimal given that a reasonable AD ought to be able to derive them for us. Moreover, whenever someone who isn't overly familiar with AD implements a new AbstractMatrix, they would need to implement a host of new rrules, which also seems like a show-stopper.
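To illustrate with a type that no existing rule knows about, here is a hypothetical `ConstantMatrix` (illustrative only, not from any package) that stores a single number `c` and represents the n×n matrix whose entries all equal `c`. Julia's generic fallbacks multiply it correctly, but the generic pullback hands back an n×n cotangent for an object with one degree of freedom. Since the derivative of the matrix w.r.t. `c` is the all-ones matrix, the structural cotangent is just the sum of the dense cotangent's entries.

```julia
using LinearAlgebra

# Hypothetical structured type: an n×n matrix whose entries all equal c.
struct ConstantMatrix <: AbstractMatrix{Float64}
    c::Float64
    n::Int
end
Base.size(A::ConstantMatrix) = (A.n, A.n)
Base.getindex(A::ConstantMatrix, i::Int, j::Int) = A.c

n = 3
A = ConstantMatrix(2.0, n)
Y = rand(n, n)

# The generic fallback gives the right answer, without exploiting the structure.
Ω = A * Y

# The generic pullback's cotangent for A has n^2 entries ...
ΔΩ = rand(n, n)
ΔA_dense = ΔΩ * Y'

# ... whereas A has one degree of freedom, c. Because ∂A/∂c is the
# all-ones matrix, the cotangent for c is the sum of ΔA_dense's entries.
Δc = sum(ΔA_dense)
```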

Number multiplication

Now consider implementing an rrule for * between two numbers. Clearly

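using ChainRulesCore
import ChainRulesCore: rrule

function rrule(::typeof(*), x::Float64, y::Float64)
    function mul_Float64_pullback(ΔΩ::Float64)
        # NoTangent() is the cotangent for `*` itself, followed by those for x and y.
        return NoTangent(), ΔΩ * y, ΔΩ * x
    end
    return x * y, mul_Float64_pullback
end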

is a correctly implemented rule. Float64 is concrete, so there's no chance that someone will subtype it and require a different implementation for their subtype. In this sense, we can guarantee that this rule is correct for any of the inputs that it admits (up to finite-precision arithmetic issues).

What would happen if you implemented this rule for Reals instead? Suppose someone provided an Integer argument for y; then its cotangent will probably be a Float64. While this doesn't cause the same complexity issues as the Diagonal example above, treating the Integers as being embedded in the Reals can cause some headaches, such as the ones that @sethaxen addressed in #224, where it becomes very important to distinguish between Integer and Real exponents for the sake of performance and correctness. Since it is presumably not acceptable to treat the Integers as special cases of the Reals only some of the time, it follows that, if we're actually trying to be consistent, * should be implemented not between Reals but between AbstractFloats.
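A small illustration of both points, using hypothetical plain-function pullbacks rather than ChainRules' actual rules. First, a `*` pullback typed on `Real` returns a `Float64` cotangent for an `Int` argument. Second, a pullback for `x^p` that treats an integer exponent as a real one must evaluate `x^p * log(x)`, which throws a `DomainError` for negative `x` even though `x^p` itself is perfectly well defined. The `nothing` returned by the integer-exponent variant is an assumption standing in for "no cotangent".

```julia
# Hypothetical pullback for x * y with both arguments typed as Real.
mul_Real_pullback(ΔΩ::Real, x::Real, y::Real) = (ΔΩ * y, ΔΩ * x)

Δx, Δy = mul_Real_pullback(1.0, 2.5, 3)   # y is an Int ...
Δy isa Float64                            # ... but its cotangent is a Float64

# Hypothetical pullback for x^p that treats the exponent p as a Real,
# using d(x^p)/dp = x^p * log(x).
function pow_pullback_real_exponent(ΔΩ, x, p)
    Δx = ΔΩ * p * x^(p - 1)
    Δp = ΔΩ * x^p * log(x)   # log(x) throws a DomainError when x < 0
    return Δx, Δp
end

# Hypothetical alternative that treats an Integer exponent as discrete
# and returns `nothing` in place of a cotangent for it.
pow_pullback_integer_exponent(ΔΩ, x, p::Integer) = (ΔΩ * p * x^(p - 1), nothing)

pow_pullback_integer_exponent(1.0, -2.0, 3)   # (12.0, nothing)
```

Calling `pow_pullback_real_exponent(1.0, -2.0, 3)` fails at `log(-2.0)`, even though `(-2.0)^3` is fine.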

Will this cause issues for users? Seth pointed out that the only situation in which this is likely to be problematic is the case in which an Integer argument is provided to an AD tool. This doesn't seem like a show stopper.

What Gives?

The issue seems to stem from implementing rules for types that you don't know about. For example, you can't know whether the * implementation above is suitable for all concrete matrices that sub-type AbstractMatrix, even if they otherwise seem like perfectly reasonable matrix types.

How does this square with the typical uses of multiple dispatch within Julia? One common line-of-thought is roughly "multiple dispatch seems to work just fine in the rest of Julia, so why can't we just implement things as they're needed in this case?". The answer seems to be that

  1. while generic fallbacks in Julia can be slow, they're at least correct if your type correctly implements whichever interface it is meant to, e.g. the multiplication of two AbstractMatrix objects will generally give the correct answer. This simply doesn't hold if you take the view that a Matrix isn't a suitable type to represent the gradient w.r.t. a Diagonal, and
  2. general Julia code doesn't provide asymptotic complexity guarantees in its fallbacks: if you encounter a fallback whose complexity is far worse than you know is achievable for a particular operation on your type, you're unlikely to be too annoyed. How could you possibly expect Julia to magically exploit, e.g., some special structure in your matrix type unless you tell it how to? This is not the case with AD because a) reverse-mode AD provides complexity guarantees and b) the code to exploit structure was already written to implement the forwards-pass, so a reasonable AD ought to be able to exploit it to construct an efficient pullback. It's very frustrating when this is prevented by a rule that "over-reaches" and stops an AD tool from doing what it's meant to be doing.
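To illustrate point 2: for `X = Diagonal(d)`, the forwards pass `X * Y` amounts to scaling row `i` of `Y` by `d[i]`, i.e. `d .* Y`. Differentiating that expression directly (by hand here, standing in for what an AD tool could derive from the forwards-pass code) gives a cotangent for `d` with N entries in O(N^2) time, and it agrees with the diagonal of the dense O(N^3) cotangent `ΔΩ * Y'`. This is a sketch; it does not claim to describe how any particular AD tool handles `Diagonal`.

```julia
using LinearAlgebra

N = 4
d = rand(N)
Y = rand(N, N)
ΔΩ = rand(N, N)

# Forwards pass exploiting the structure: row i of Y is scaled by d[i].
Ω = d .* Y                        # equal to Diagonal(d) * Y, O(N^2) work

# Hand-derived pullback of Ω[i, j] = d[i] * Y[i, j]:
Δd = vec(sum(ΔΩ .* Y; dims=2))    # N numbers, O(N^2) work
ΔY = d .* ΔΩ                      # equal to Diagonal(d)' * ΔΩ

# Same diagonal as the dense cotangent produced by the generic rule.
diag(ΔΩ * Y') ≈ Δd
```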

What To Do?

The obvious thing to do is to advise that rules only be implemented when you know they'll be correct, which means restricting yourself to implementing rules for concrete types and collections thereof. Unfortunately, doing this will almost certainly break e.g. Zygote, because it relies heavily on very broadly-defined rules that in general exhibit all of the problems discussed above. Maybe some careful middle ground needs to be found, or maybe Zygote needs to press forward with its handling of mutation so that it can cope with such changes.
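A sketch of the "concrete types only" option. To keep it runnable on its own, `my_rrule` below is a hypothetical stand-in for an rrule (it is not ChainRulesCore's `rrule`). The rule admits only `Matrix{Float64}` arguments, so a `Diagonal` or a user-defined `AbstractMatrix` simply does not match it, leaving the AD tool free to differentiate through that type's own forwards-pass code.

```julia
using LinearAlgebra

# Hypothetical rule restricted to a concrete type: dense Float64 matrices.
function my_rrule(::typeof(*), X::Matrix{Float64}, Y::Matrix{Float64})
    # For dense inputs, the dense cotangents cost the same O(N^3) as the forwards pass.
    mul_Matrix_pullback(ΔΩ::Matrix{Float64}) = (ΔΩ * Y', X' * ΔΩ)
    return X * Y, mul_Matrix_pullback
end

X = rand(3, 3)
Y = rand(3, 3)
Ω, back = my_rrule(*, X, Y)
ΔX, ΔY = back(ones(3, 3))

# The rule does not admit structured or user-defined subtypes of AbstractMatrix.
applicable(my_rrule, *, X, Y), applicable(my_rrule, *, Diagonal(rand(3)), Y)
```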

Related Work

#81

Writing this issue was motivated by #226, where @nickrobinson251 pointed out that we should have an issue about this, and that we should consider adding something to the docs.

@sethaxen @oxinabox @MasonProtter any thoughts?

@DhairyaLGandhi @CarloLucibello this is very relevant for Zygote, so could do with your input.

Labels: design (Requires some design before changes are made); type constraints (Potentially raises a question about how tightly to constrain argument types for a rule. See #232)