diff --git a/docs/src/adjoints.md b/docs/src/adjoints.md index 8206fbdeb..e7309d2f5 100644 --- a/docs/src/adjoints.md +++ b/docs/src/adjoints.md @@ -56,7 +56,7 @@ julia> mygradient(sin, 0.5) The rest of this section contains more technical detail. It can be skipped if you only need an intuition for pullbacks; you generally won't need to worry about it as a user. -If ``x`` and ``y`` are vectors, ``\frac{\partial y}{\partial x}`` becomes a Jacobian. Importantly, because we are implementing reverse mode we actually left-multiply the Jacobian, i.e. `v'J`, rather than the more usual `J*v`. Transposing `v` to a row vector and back `(v'J)'` is equivalent to `J'v` so our gradient rules actually implement the *adjoint* of the Jacobian. This is relevant even for scalar code: the adjoint for `y = sin(x)` is `x̄ = sin(x)'*ȳ`; the conjugation is usually moot but gives the correct behaviour for complex code. "Pullbacks" are therefore sometimes called "vector-Jacobian products" (VJPs), and we refer to the reverse mode rules themselves as "adjoints". +If ``x`` and ``y`` are vectors, ``\frac{\partial y}{\partial x}`` becomes a Jacobian. Importantly, because we are implementing reverse mode we actually left-multiply the Jacobian, i.e. `v'J`, rather than the more usual `J*v`. Transposing `v` to a row vector and back `(v'J)'` is equivalent to `J'v` so our gradient rules actually implement the *adjoint* of the Jacobian. This is relevant even for scalar code: the adjoint for `y = sin(x)` is `x̄ = ȳ*cos(x)'`; the conjugation is usually moot but gives the correct behaviour for complex code, for more details see the section on (complex differentiation)[../complex/]. "Pullbacks" are therefore sometimes called "vector-Jacobian products" (VJPs), and we refer to the reverse mode rules themselves as "adjoints". Zygote has many adjoints for non-mathematical operations such as for indexing and data structures. Though these can still be seen as linear functions of vectors, it's not particularly enlightening to implement them with an actual matrix multiply. In these cases it's easiest to think of the adjoint as a kind of inverse. For example, the gradient of a function that takes a tuple to a struct (e.g. `y = Complex(a, b)`) will generally take a struct to a tuple (`(ȳ.re, ȳ.im)`). The gradient of a `getindex` `y = x[i...]` is a `setindex!` `x̄[i...] = ȳ`, etc. diff --git a/docs/src/complex.md b/docs/src/complex.md index edde3a3e3..22415afb1 100644 --- a/docs/src/complex.md +++ b/docs/src/complex.md @@ -2,7 +2,9 @@ Complex numbers add some difficulty to the idea of a "gradient". To talk about `gradient(f, x)` here we need to talk a bit more about `f`. -If `f` returns a real number, things are fairly straightforward. For ``c = x + yi`` and ``z = f(c)``, we can define the adjoint ``\bar c = \frac{\partial z}{\partial x} + \frac{\partial z}{\partial y}i = \bar x + \bar y i`` (note that ``\bar c`` means gradient, and ``c'`` means conjugate). It's exactly as if the complex number were just a pair of reals `(re, im)`. This works out of the box. +*A note on notation*: We are using ``\bar c`` to mean the gradient of ``c`` here, like we did before. For the complex conjugate of ``c``, we therefore use the notation ``c^*`` and not the ``c'`` Julia code uses, since that could be confused to mean derivative instead, and we also want to distinguish between the complex conjugate and the conjugate transpose. Note however, that whenever we talk about code snippets, `c'` of course still means conjugate (transpose). + +If `f` returns a real number, things are fairly straightforward. For ``c = x + yi`` and ``z = f(c)``, we can define the adjoint ``\bar c = \frac{\partial z}{\partial x} + \frac{\partial z}{\partial y}i = \bar x + \bar y i``. It's exactly as if the complex number were just a pair of reals `(re, im)`. This works out of the box. ```julia julia> gradient(c -> abs2(c), 1+2im) @@ -35,7 +37,7 @@ julia> -im*gradient(x -> imag(log(x')), 1+2im)[1] |> conj -0.2 + 0.4im ``` -In cases like these, all bets are off. The gradient can only be described with more information; either a 2x2 Jacobian (a generalisation of the Real case, where the second column is now non-zero), or by the two Wirtinger derivatives (a generalisation of the holomorphic case, where ``\frac{∂ f}{∂ z'}`` is now non-zero). To get these efficiently, as we would a Jacobian, we can just call the backpropagators twice. +In cases like these, all bets are off. The gradient can only be described with more information; either a 2x2 Jacobian (a generalisation of the Real case, where the second column is now non-zero), or by the two Wirtinger derivatives (a generalisation of the holomorphic case, where ``\frac{\partial f}{\partial z^*}`` is now non-zero). To get these efficiently, as we would a Jacobian, we can just call the backpropagators twice. ```julia function jacobi(f, x) @@ -56,3 +58,42 @@ julia> wirtinger(x -> 3x^2 + 2x + 1, 1+2im) julia> wirtinger(x -> abs2(x), 1+2im) (1.0 - 2.0im, 1.0 + 2.0im) ``` + +The gradient definition Zygote uses can also be expressed in terms of the [Wirtinger calculus](https://en.wikipedia.org/wiki/Wirtinger_derivatives) using the operators ``\frac{\partial}{\partial z}`` and ``\frac{\partial}{\partial z^*}``. Since ``f(z)`` is always real, we can use that ``f = \mathrm{Re} f`` as a trick to rewrite the gradient of ``f`` in terms of the Wirtinger derivatives. + +```math +f: \mathbb{C} \rightarrow \mathbb{R}, \qquad +f(z) \equiv f(x + iy), \qquad +z \in \mathbb{C}, \ x, y \in \mathbb{R} \\[1.2em] + +\bar f \equiv \frac{\partial f}{\partial x} + i \frac{\partial f}{\partial y} + = 2 \, \frac{\partial(\mathrm{Re}(f))}{\partial z^*} + = \frac{\partial(f + f^*)}{\partial z^*} + = \frac{\partial f}{\partial z^*} + \left(\frac{\partial f}{\partial z}\right)^{\!*} +``` + +Further, we want to study, how these gradients chain together, since the usual chain rule doesn't apply here. We are going to use the relationship we found above, together with the chain rule for Wirtinger derivatives. +Therefore, for the composition of two functions ``f`` and ``w``, one gets the following pullback map, if the inner function ``w`` is holomorphic: + +```math +f: \mathbb{C} \rightarrow \mathbb{R}, \qquad +w: \mathbb{C} \rightarrow \mathbb{C} \\[1.2em] + +\begin{align*} +\overline{f \circ w} + &= \frac{\partial (f \circ w)}{\partial z^*} + \left(\frac{\partial (f \circ w)}{\partial z}\right)^{\!*} + = \frac{\partial f}{\partial w} \frac{\partial w}{\partial z^*} + + \frac{\partial f}{\partial w^*} \frac{\partial w^*}{\partial z^*} + + \left( \frac{\partial f}{\partial w} \frac{\partial w}{\partial z} + + \frac{\partial f}{\partial w^*} \frac{\partial w^*}{\partial z} \right)^{\!*} \\ + &= \left[ \frac{\partial f}{\partial w^*} + \left( \frac{\partial f}{\partial w} \right)^{\!*} \right] + \left( \frac{\partial w}{\partial z} \right)^{\!*} + = \overline{f} \cdot \left( \frac{\partial w}{\partial z} \right)^{\!*} + \qquad \text{if $w(z)$ holomorphic} \Leftrightarrow \frac{\partial w}{\partial z^*} = 0 +\end{align*} +``` + +This nicely explains, why the complex conjugate appears in Zygote's pullback definitions, as pointed out in the [Pullbacks section](../adjoints/#Pullbacks). +If `w` is not holomorphic, the pullback map ``\overline{f} \mapsto \overline{f \circ w}`` is not ``\mathbb{C}``-linear and can therefore not be expressed simply as a multiple of ``\overline{f}``, like in the holomorphic case. + +Attention has to be paid, when comparing Zygote to other AD-tools, since they might use different definitions for complex gradients.