-
-
Notifications
You must be signed in to change notification settings - Fork 216
Explain relationship of Zygote's complex gradients with the Wirtinger calculus #328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,7 +2,9 @@ | |
|
|
||
| Complex numbers add some difficulty to the idea of a "gradient". To talk about `gradient(f, x)` here we need to talk a bit more about `f`. | ||
|
|
||
| If `f` returns a real number, things are fairly straightforward. For ``c = x + yi`` and ``z = f(c)``, we can define the adjoint ``\bar c = \frac{\partial z}{\partial x} + \frac{\partial z}{\partial y}i = \bar x + \bar y i`` (note that ``\bar c`` means gradient, and ``c'`` means conjugate). It's exactly as if the complex number were just a pair of reals `(re, im)`. This works out of the box. | ||
| *A note on notation*: We are using ``\bar c`` to mean the gradient of ``c`` here, like we did before. For the complex conjugate of ``c``, we therefore use the notation ``c^*`` and not the ``c'`` Julia code uses, since that could be confused to mean derivative instead, and we also want to distinguish between the complex conjugate and the conjugate transpose. Note however, that whenever we talk about code snippets, `c'` of course still means conjugate (transpose). | ||
|
|
||
| If `f` returns a real number, things are fairly straightforward. For ``c = x + yi`` and ``z = f(c)``, we can define the adjoint ``\bar c = \frac{\partial z}{\partial x} + \frac{\partial z}{\partial y}i = \bar x + \bar y i``. It's exactly as if the complex number were just a pair of reals `(re, im)`. This works out of the box. | ||
|
|
||
| ```julia | ||
| julia> gradient(c -> abs2(c), 1+2im) | ||
|
|
@@ -35,7 +37,7 @@ julia> -im*gradient(x -> imag(log(x')), 1+2im)[1] |> conj | |
| -0.2 + 0.4im | ||
| ``` | ||
|
|
||
| In cases like these, all bets are off. The gradient can only be described with more information; either a 2x2 Jacobian (a generalisation of the Real case, where the second column is now non-zero), or by the two Wirtinger derivatives (a generalisation of the holomorphic case, where ``\frac{∂ f}{∂ z'}`` is now non-zero). To get these efficiently, as we would a Jacobian, we can just call the backpropagators twice. | ||
| In cases like these, all bets are off. The gradient can only be described with more information; either a 2x2 Jacobian (a generalisation of the Real case, where the second column is now non-zero), or by the two Wirtinger derivatives (a generalisation of the holomorphic case, where ``\frac{\partial f}{\partial z^*}`` is now non-zero). To get these efficiently, as we would a Jacobian, we can just call the backpropagators twice. | ||
|
|
||
| ```julia | ||
| function jacobi(f, x) | ||
|
|
@@ -56,3 +58,42 @@ julia> wirtinger(x -> 3x^2 + 2x + 1, 1+2im) | |
| julia> wirtinger(x -> abs2(x), 1+2im) | ||
| (1.0 - 2.0im, 1.0 + 2.0im) | ||
| ``` | ||
|
|
||
| The gradient definition Zygote uses can also be expressed in terms of the [Wirtinger calculus](https://en.wikipedia.org/wiki/Wirtinger_derivatives) using the operators ``\frac{\partial}{\partial z}`` and ``\frac{\partial}{\partial z^*}``. Since ``f(z)`` is always real, we can use that ``f = \mathrm{Re} f`` as a trick to rewrite the gradient of ``f`` in terms of the Wirtinger derivatives. | ||
|
|
||
| ```math | ||
| f: \mathbb{C} \rightarrow \mathbb{R}, \qquad | ||
| f(z) \equiv f(x + iy), \qquad | ||
| z \in \mathbb{C}, \ x, y \in \mathbb{R} \\[1.2em] | ||
|
|
||
| \bar f \equiv \frac{\partial f}{\partial x} + i \frac{\partial f}{\partial y} | ||
| = 2 \, \frac{\partial(\mathrm{Re}(f))}{\partial z^*} | ||
| = \frac{\partial(f + f^*)}{\partial z^*} | ||
| = \frac{\partial f}{\partial z^*} + \left(\frac{\partial f}{\partial z}\right)^{\!*} | ||
| ``` | ||
|
|
||
| Further, we want to study, how these gradients chain together, since the usual chain rule doesn't apply here. We are going to use the relationship we found above, together with the chain rule for Wirtinger derivatives. | ||
| Therefore, for the composition of two functions ``f`` and ``w``, one gets the following pullback map, if the inner function ``w`` is holomorphic: | ||
|
|
||
| ```math | ||
| f: \mathbb{C} \rightarrow \mathbb{R}, \qquad | ||
| w: \mathbb{C} \rightarrow \mathbb{C} \\[1.2em] | ||
|
|
||
| \begin{align*} | ||
| \overline{f \circ w} | ||
| &= \frac{\partial (f \circ w)}{\partial z^*} + \left(\frac{\partial (f \circ w)}{\partial z}\right)^{\!*} | ||
| = \frac{\partial f}{\partial w} \frac{\partial w}{\partial z^*} | ||
| + \frac{\partial f}{\partial w^*} \frac{\partial w^*}{\partial z^*} | ||
| + \left( \frac{\partial f}{\partial w} \frac{\partial w}{\partial z} | ||
| + \frac{\partial f}{\partial w^*} \frac{\partial w^*}{\partial z} \right)^{\!*} \\ | ||
| &= \left[ \frac{\partial f}{\partial w^*} + \left( \frac{\partial f}{\partial w} \right)^{\!*} \right] | ||
| \left( \frac{\partial w}{\partial z} \right)^{\!*} | ||
| = \overline{f} \cdot \left( \frac{\partial w}{\partial z} \right)^{\!*} | ||
| \qquad \text{if $w(z)$ holomorphic} \Leftrightarrow \frac{\partial w}{\partial z^*} = 0 | ||
| \end{align*} | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I basically see what the above section is saying (though it might be nice to have an implementation in code for the sake of precision/clarity). This section strikes me as confusing though; a big equation dump there and while I'm sure it's all correct, I'm not sure what it's trying to get across.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can try explaining this a bit better in words. I'm trying to give the mathematical reason here, why we need to take the complex conjugate of the derivative in pullbacks. |
||
| ``` | ||
|
|
||
| This nicely explains, why the complex conjugate appears in Zygote's pullback definitions, as pointed out in the [Pullbacks section](../adjoints/#Pullbacks). | ||
| If `w` is not holomorphic, the pullback map ``\overline{f} \mapsto \overline{f \circ w}`` is not ``\mathbb{C}``-linear and can therefore not be expressed simply as a multiple of ``\overline{f}``, like in the holomorphic case. | ||
|
|
||
| Attention has to be paid, when comparing Zygote to other AD-tools, since they might use different definitions for complex gradients. | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't you mean C->C here? The below references
Re(f)andf*.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is C->R, but since
fis real, I'm usingf = Re(f) = f*here, to express this in terms of the Wirtinger derivatives