Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/adjoints.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ julia> mygradient(sin, 0.5)

The rest of this section contains more technical detail. It can be skipped if you only need an intuition for pullbacks; you generally won't need to worry about it as a user.

If ``x`` and ``y`` are vectors, ``\frac{\partial y}{\partial x}`` becomes a Jacobian. Importantly, because we are implementing reverse mode we actually left-multiply the Jacobian, i.e. `v'J`, rather than the more usual `J*v`. Transposing `v` to a row vector and back `(v'J)'` is equivalent to `J'v` so our gradient rules actually implement the *adjoint* of the Jacobian. This is relevant even for scalar code: the adjoint for `y = sin(x)` is `x̄ = sin(x)'*ȳ`; the conjugation is usually moot but gives the correct behaviour for complex code. "Pullbacks" are therefore sometimes called "vector-Jacobian products" (VJPs), and we refer to the reverse mode rules themselves as "adjoints".
If ``x`` and ``y`` are vectors, ``\frac{\partial y}{\partial x}`` becomes a Jacobian. Importantly, because we are implementing reverse mode we actually left-multiply the Jacobian, i.e. `v'J`, rather than the more usual `J*v`. Transposing `v` to a row vector and back `(v'J)'` is equivalent to `J'v` so our gradient rules actually implement the *adjoint* of the Jacobian. This is relevant even for scalar code: the adjoint for `y = sin(x)` is `x̄ = ȳ*cos(x)'`; the conjugation is usually moot but gives the correct behaviour for complex code, for more details see the section on (complex differentiation)[../complex/]. "Pullbacks" are therefore sometimes called "vector-Jacobian products" (VJPs), and we refer to the reverse mode rules themselves as "adjoints".

Zygote has many adjoints for non-mathematical operations such as for indexing and data structures. Though these can still be seen as linear functions of vectors, it's not particularly enlightening to implement them with an actual matrix multiply. In these cases it's easiest to think of the adjoint as a kind of inverse. For example, the gradient of a function that takes a tuple to a struct (e.g. `y = Complex(a, b)`) will generally take a struct to a tuple (`(ȳ.re, ȳ.im)`). The gradient of a `getindex` `y = x[i...]` is a `setindex!` `x̄[i...] = ȳ`, etc.

Expand Down
45 changes: 43 additions & 2 deletions docs/src/complex.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

Complex numbers add some difficulty to the idea of a "gradient". To talk about `gradient(f, x)` here we need to talk a bit more about `f`.

If `f` returns a real number, things are fairly straightforward. For ``c = x + yi`` and ``z = f(c)``, we can define the adjoint ``\bar c = \frac{\partial z}{\partial x} + \frac{\partial z}{\partial y}i = \bar x + \bar y i`` (note that ``\bar c`` means gradient, and ``c'`` means conjugate). It's exactly as if the complex number were just a pair of reals `(re, im)`. This works out of the box.
*A note on notation*: We are using ``\bar c`` to mean the gradient of ``c`` here, like we did before. For the complex conjugate of ``c``, we therefore use the notation ``c^*`` and not the ``c'`` Julia code uses, since that could be confused to mean derivative instead, and we also want to distinguish between the complex conjugate and the conjugate transpose. Note however, that whenever we talk about code snippets, `c'` of course still means conjugate (transpose).

If `f` returns a real number, things are fairly straightforward. For ``c = x + yi`` and ``z = f(c)``, we can define the adjoint ``\bar c = \frac{\partial z}{\partial x} + \frac{\partial z}{\partial y}i = \bar x + \bar y i``. It's exactly as if the complex number were just a pair of reals `(re, im)`. This works out of the box.

```julia
julia> gradient(c -> abs2(c), 1+2im)
Expand Down Expand Up @@ -35,7 +37,7 @@ julia> -im*gradient(x -> imag(log(x')), 1+2im)[1] |> conj
-0.2 + 0.4im
```

In cases like these, all bets are off. The gradient can only be described with more information; either a 2x2 Jacobian (a generalisation of the Real case, where the second column is now non-zero), or by the two Wirtinger derivatives (a generalisation of the holomorphic case, where ``\frac{ f}{∂ z'}`` is now non-zero). To get these efficiently, as we would a Jacobian, we can just call the backpropagators twice.
In cases like these, all bets are off. The gradient can only be described with more information; either a 2x2 Jacobian (a generalisation of the Real case, where the second column is now non-zero), or by the two Wirtinger derivatives (a generalisation of the holomorphic case, where ``\frac{\partial f}{\partial z^*}`` is now non-zero). To get these efficiently, as we would a Jacobian, we can just call the backpropagators twice.

```julia
function jacobi(f, x)
Expand All @@ -56,3 +58,42 @@ julia> wirtinger(x -> 3x^2 + 2x + 1, 1+2im)
julia> wirtinger(x -> abs2(x), 1+2im)
(1.0 - 2.0im, 1.0 + 2.0im)
```

The gradient definition Zygote uses can also be expressed in terms of the [Wirtinger calculus](https://en.wikipedia.org/wiki/Wirtinger_derivatives) using the operators ``\frac{\partial}{\partial z}`` and ``\frac{\partial}{\partial z^*}``. Since ``f(z)`` is always real, we can use that ``f = \mathrm{Re} f`` as a trick to rewrite the gradient of ``f`` in terms of the Wirtinger derivatives.

```math
f: \mathbb{C} \rightarrow \mathbb{R}, \qquad
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't you mean C->C here? The below references Re(f) and f*.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is C->R, but since f is real, I'm using f = Re(f) = f* here, to express this in terms of the Wirtinger derivatives

f(z) \equiv f(x + iy), \qquad
z \in \mathbb{C}, \ x, y \in \mathbb{R} \\[1.2em]

\bar f \equiv \frac{\partial f}{\partial x} + i \frac{\partial f}{\partial y}
= 2 \, \frac{\partial(\mathrm{Re}(f))}{\partial z^*}
= \frac{\partial(f + f^*)}{\partial z^*}
= \frac{\partial f}{\partial z^*} + \left(\frac{\partial f}{\partial z}\right)^{\!*}
```

Further, we want to study, how these gradients chain together, since the usual chain rule doesn't apply here. We are going to use the relationship we found above, together with the chain rule for Wirtinger derivatives.
Therefore, for the composition of two functions ``f`` and ``w``, one gets the following pullback map, if the inner function ``w`` is holomorphic:

```math
f: \mathbb{C} \rightarrow \mathbb{R}, \qquad
w: \mathbb{C} \rightarrow \mathbb{C} \\[1.2em]

\begin{align*}
\overline{f \circ w}
&= \frac{\partial (f \circ w)}{\partial z^*} + \left(\frac{\partial (f \circ w)}{\partial z}\right)^{\!*}
= \frac{\partial f}{\partial w} \frac{\partial w}{\partial z^*}
+ \frac{\partial f}{\partial w^*} \frac{\partial w^*}{\partial z^*}
+ \left( \frac{\partial f}{\partial w} \frac{\partial w}{\partial z}
+ \frac{\partial f}{\partial w^*} \frac{\partial w^*}{\partial z} \right)^{\!*} \\
&= \left[ \frac{\partial f}{\partial w^*} + \left( \frac{\partial f}{\partial w} \right)^{\!*} \right]
\left( \frac{\partial w}{\partial z} \right)^{\!*}
= \overline{f} \cdot \left( \frac{\partial w}{\partial z} \right)^{\!*}
\qquad \text{if $w(z)$ holomorphic} \Leftrightarrow \frac{\partial w}{\partial z^*} = 0
\end{align*}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I basically see what the above section is saying (though it might be nice to have an implementation in code for the sake of precision/clarity). This section strikes me as confusing though; a big equation dump there and while I'm sure it's all correct, I'm not sure what it's trying to get across.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can try explaining this a bit better in words. I'm trying to give the mathematical reason here, why we need to take the complex conjugate of the derivative in pullbacks.

```

This nicely explains, why the complex conjugate appears in Zygote's pullback definitions, as pointed out in the [Pullbacks section](../adjoints/#Pullbacks).
If `w` is not holomorphic, the pullback map ``\overline{f} \mapsto \overline{f \circ w}`` is not ``\mathbb{C}``-linear and can therefore not be expressed simply as a multiple of ``\overline{f}``, like in the holomorphic case.

Attention has to be paid, when comparing Zygote to other AD-tools, since they might use different definitions for complex gradients.