@simeonschaub (Member) commented Sep 13, 2019

I was quite confused at first about how Zygote handles complex gradients, and this wasn't obvious from the documentation. This also came up in the development of ChainRules.jl.

@Krastanov

The history of the current phrasing is explained in #29 (for reference when considering this PR).

@simeonschaub (Member Author)

@MikeInnes Would you mind reviewing this?

The rest of this section contains more technical detail. It can be skipped if you only need an intuition for pullbacks; you generally won't need to worry about it as a user.

- If ``x`` and ``y`` are vectors, ``\frac{\partial y}{\partial x}`` becomes a Jacobian. Importantly, because we are implementing reverse mode we actually left-multiply the Jacobian, i.e. `v'J`, rather than the more usual `J*v`. Transposing `v` to a row vector and back `(v'J)'` is equivalent to `J'v` so our gradient rules actually implement the *adjoint* of the Jacobian. This is relevant even for scalar code: the adjoint for `y = sin(x)` is `x̄ = sin(x)'*ȳ`; the conjugation is usually moot but gives the correct behaviour for complex code. "Pullbacks" are therefore sometimes called "vector-Jacobian products" (VJPs), and we refer to the reverse mode rules themselves as "adjoints".
+ If ``x`` and ``y`` are vectors, ``\frac{\partial y}{\partial x}`` becomes a Jacobian. Importantly, because we are implementing reverse mode we actually left-multiply the Jacobian, i.e. `v'J`, rather than the more usual `J*v`. Transposing `v` to a row vector and back `(v'J)'` is equivalent to `J'v` so our gradient rules actually implement the *adjoint* of the Jacobian. This is relevant even for scalar code: the adjoint for `y = sin(x)` is `x̄ = ȳ*cos(x)'`; the conjugation is usually moot but gives the correct behaviour for complex code, if `y(x)` is holomorphic. "Pullbacks" are therefore sometimes called "vector-Jacobian products" (VJPs), and we refer to the reverse mode rules themselves as "adjoints".
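
The identity `(v'J)' == J'v` quoted in this paragraph can be checked directly in plain Julia. The following is only a minimal sketch: the matrix `J` and vector `v` are made-up values chosen for illustration, and complex entries are used so that the conjugation in `'` actually matters.

```julia
using LinearAlgebra

# Hypothetical complex Jacobian and sensitivity vector (illustrative values only).
J = [1+2im 3-1im; 0+1im 2+0im]
v = [1-1im, 2+3im]

# v' is the conjugate transpose of v, so v'J is a row vector.
row = v' * J

# Transposing back with ' gives a column vector, which equals J'v.
lhs = row'
rhs = J' * v

@assert lhs == rhs
```

With real entries the conjugation is a no-op and `J'` reduces to the plain transpose, which is why the distinction is easy to overlook in real-valued code.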
Member

Is there an intentional change here? If not, best to remove it from the diff.

Member Author

Yes, I think GitHub just can't handle such long lines. See here.

@MikeInnes (Member), Sep 27, 2019

The original wording was a bit unclear, but this gives the correct result for complex code in general, not just the holomorphic case (it is redundant for real code). The adjoint is what causes us to get back complex sensitivities (otherwise the output would be the conjugate of the sensitivity).

Member Author

It is still unclear to me what you would be conjugating in the non-holomorphic case, since the complex derivative doesn't exist there.

Member

Individual pullbacks don't (and can't) know whether the function as a whole is holomorphic, which is a global property, and don't ever see complex derivatives. They only work with sensitivities, as we define them, and taking the adjoint is correct for sensitivities.

Member Author

I probably should have clarified: I am only talking about the partial function being holomorphic. If we're defining the pullback for `exp`, for example, we can express the partial derivative of `exp` as a complex derivative. We can't do that for `abs2`. Is the term "sensitivity" mathematically defined anywhere? In my understanding, it is a vector that you pass to a differential form, so it basically specifies the linear combination of partial derivatives. If we limit ourselves to scalars, this would just be the partial derivative of the output with respect to the current function in the chain.
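
The contrast between `exp` and `abs2` raised here can be sketched numerically without Zygote. The snippet below is an illustration under assumptions, not Zygote's implementation: `fd_gradient`, `loss`, and the test point are hypothetical, and sensitivities follow the convention `c̄ = ∂L/∂x + i ∂L/∂y` for a real-valued `L`.

```julia
# Sensitivity of a real-valued function L at c, in the convention
# c̄ = ∂L/∂x + i ∂L/∂y, estimated with central finite differences.
function fd_gradient(L, c; h = 1e-6)
    dx = (L(c + h) - L(c - h)) / (2h)
    dy = (L(c + h * im) - L(c - h * im)) / (2h)
    return complex(dx, dy)
end

z = 0.3 + 0.4im                      # arbitrary test point

# Holomorphic step w = exp(z), followed by a hypothetical real-valued loss.
loss(w) = real(w * (2 - 3im))
w_bar = fd_gradient(loss, exp(z))    # incoming sensitivity at w
z_bar_rule = w_bar * conj(exp(z))    # pullback: w̄ times the conjugated derivative
z_bar_fd = fd_gradient(t -> loss(exp(t)), z)
@assert isapprox(z_bar_rule, z_bar_fd; atol = 1e-6)

# Non-holomorphic step w = abs2(z): there is no complex derivative to conjugate,
# but w is real, so the incoming sensitivity is real and z̄ = w̄ * 2z.
g(w) = 3w                            # hypothetical downstream function of real w
abs2_bar_rule = 3.0 * 2z
abs2_bar_fd = fd_gradient(t -> g(abs2(t)), z)
@assert isapprox(abs2_bar_rule, abs2_bar_fd; atol = 1e-6)
```

In both cases the pullback is a linear map on sensitivities; only in the holomorphic case can that map be written as multiplication by a conjugated complex derivative.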

Member

Ok, that's fair, but still an edge case (and even then, we still have the adjoint of the linear map, it's just expressed differently). These specific docs are a reference for adjoints rather than complex AD and I'd rather not overload people with the vagaries of that before they've gotten started with the real part :) So ideally this should just mention briefly that adjoints are generally relevant for complex AD, and have a link to the other docs for more detail.

Member Author

Ok, sounds like a good idea

Complex numbers add some difficulty to the idea of a "gradient". To talk about `gradient(f, x)` here we need to talk a bit more about `f`.

- If `f` returns a real number, things are fairly straightforward. For ``c = x + yi`` and ``z = f(c)``, we can define the adjoint ``\bar c = \frac{\partial z}{\partial x} + \frac{\partial z}{\partial y}i = \bar x + \bar y i`` (note that ``\bar c`` means gradient, and ``c'`` means conjugate). It's exactly as if the complex number were just a pair of reals `(re, im)`. This works out of the box.
+ If `f` returns a real number, things are fairly straightforward. For ``c = x + yi`` and ``z = f(c)``, we can define the adjoint ``\bar c = \frac{\partial z}{\partial x} + \frac{\partial z}{\partial y}i = \bar x + \bar y i`` (note that ``\bar c`` means gradient, and ``c^*`` means conjugate). It's exactly as if the complex number were just a pair of reals `(re, im)`. This works out of the box.
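
As a rough illustration of the "pair of reals" reading, the adjoint can be assembled from two ordinary partial derivatives. This sketch uses finite differences instead of Zygote; the helper `g`, the step `h`, and the choice of `abs2` as the example function are assumptions made for the example.

```julia
# Treat f(c) as a function of the two real variables (x, y) with c = x + yi.
f(c) = abs2(c)                       # example real-valued function
g(x, y) = f(complex(x, y))

x, y = 1.0, 2.0
h = 1e-6                             # finite-difference step, an arbitrary choice

dz_dx = (g(x + h, y) - g(x - h, y)) / (2h)
dz_dy = (g(x, y + h) - g(x, y - h)) / (2h)

# Assemble the adjoint c̄ = ∂z/∂x + (∂z/∂y) i.
c_bar = complex(dz_dx, dz_dy)

# For abs2 the partials are 2x and 2y, so c̄ should be close to 2c.
@assert isapprox(c_bar, 2 * complex(x, y); atol = 1e-6)
```

Under this definition, `gradient(abs2, 1 + 2im)` in Zygote would be expected to report the same value, `2 + 4im`.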
Member

We should stick with c' consistently here, because that's what Julia uses (and it's used in a bunch of other places on this page). If we're using c* elsewhere we can change that.

Member Author

It just felt a little weird to me to use `\overline` for derivatives and `'` for the complex conjugate in mathematical notation. `^*` very commonly denotes the complex conjugate in physics, so it felt more natural and less confusing here. Of course, I left `'` in snippets that are actual Julia code.

Member

Fair, but then we should mention the ' notation as well. Might be better to have a note on it in context, rather than trying to get it into this paragraph.

Member Author

Which place exactly are you referring to?

Member

As in, we can clarify the notation when it's first used, rather than trying to define everything in a parenthetical.

Member Author

Ok, will do

The gradient definition Zygote uses can also be expressed in terms of the [Wirtinger calculus](https://en.wikipedia.org/wiki/Wirtinger_derivatives) using the operators ``\frac{\partial}{\partial z}`` and ``\frac{\partial}{\partial z^*}``:

```math
f: \mathbb{C} \rightarrow \mathbb{R}, \qquad
```
Member

Don't you mean C->C here? The below references Re(f) and f*.

Member Author

It is C->R, but since f is real, I'm using f = Re(f) = f* here to express this in terms of the Wirtinger derivatives.
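
For reference, the relation being described can be written out as follows. This is a sketch using the standard definitions of the Wirtinger operators with ``z = x + iy``, not a quotation of the PR's equations; the symbol ``\bar z`` for the gradient follows the convention of the surrounding docs.

```math
\begin{align*}
\frac{\partial f}{\partial z} &= \frac{1}{2} \left( \frac{\partial f}{\partial x} - i \frac{\partial f}{\partial y} \right), &
\frac{\partial f}{\partial z^*} &= \frac{1}{2} \left( \frac{\partial f}{\partial x} + i \frac{\partial f}{\partial y} \right), \\
\bar z &= \frac{\partial f}{\partial x} + i \frac{\partial f}{\partial y}
= 2 \frac{\partial f}{\partial z^*}
= 2 \left( \frac{\partial f}{\partial z} \right)^{\!*}
\end{align*}
```

The last equality is where ``f = f^*`` is used: for real ``f`` both partial derivatives are real, so conjugating ``\frac{\partial f}{\partial z}`` gives ``\frac{\partial f}{\partial z^*}``.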

\left( \frac{\partial w}{\partial z} \right)^{\!*}
= \overline{f} \cdot \left( \frac{\partial w}{\partial z} \right)^{\!*}
\qquad \text{if $w(z)$ holomorphic} \Leftrightarrow \frac{\partial w}{\partial z^*} = 0
\end{align*}
Member

I basically see what the above section is saying (though it might be nice to have an implementation in code for the sake of precision and clarity). This section strikes me as confusing, though: there's a big equation dump, and while I'm sure it's all correct, I'm not sure what it's trying to get across.

Member Author

I can try explaining this a bit better in words. I'm trying to give the mathematical reason for why we need to take the complex conjugate of the derivative in pullbacks.
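
One way to spell out that reason is the Wirtinger chain rule. The derivation below is a sketch under stated assumptions: ``f`` is a real-valued function of ``w``, ``w(z)`` is holomorphic, and ``\bar w = 2 \frac{\partial f}{\partial w^*}`` denotes the incoming sensitivity. This notation is chosen here for illustration and may differ from the PR's.

```math
\begin{align*}
\bar z = 2 \frac{\partial f}{\partial z^*}
&= 2 \left( \frac{\partial f}{\partial w} \frac{\partial w}{\partial z^*}
+ \frac{\partial f}{\partial w^*} \frac{\partial w^*}{\partial z^*} \right) \\
&= 2 \frac{\partial f}{\partial w^*} \left( \frac{\partial w}{\partial z} \right)^{\!*}
= \bar w \cdot \left( \frac{\partial w}{\partial z} \right)^{\!*}
\end{align*}
```

The first term vanishes because ``\frac{\partial w}{\partial z^*} = 0`` for holomorphic ``w``, and the second uses the general identity ``\frac{\partial w^*}{\partial z^*} = \left( \frac{\partial w}{\partial z} \right)^{\!*}``. The conjugate on the derivative therefore comes from the chain rule itself rather than from an extra convention.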

@simeonschaub (Member Author)

Sorry it took me so long to implement these changes. I hope it's better now.
