-
Notifications
You must be signed in to change notification settings - Fork 37
Support "diagonal" primitives with no/slow JVP batch rule #164
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
merge main into fork
…ntire Jacobian matrix
|
I think I prefer this implementation (latest commit) with |
|
I finally realised why Would it be helpful to re-write the PR message with updated benchmarks and prose in light of this? |
|
That reasoning was completely wrong again, the derivative is not evaluated at the tangent vectors just multiplied by them. I think the reduced FLOPs due to matrix multiplication is only more noticeable on the example above because However, it is easy to fool def myfunc(x):
halfway = len(x) // 2
return jnp.concatenate([jnp.sin(x[:halfway]), jnp.cos(x[halfway:])])Here we seen an O(n) speed exceeding 1E4 for array sizes of 4E4 which is huge. In general I see no significant adverse impact of this and some very pronounced positive impacts in realistic use cases. |
patrick-kidger
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, I really like seeing a colouring approach like this!
I will note however that lx.diagonal is documented as Extracts the diagonal from a linear operator, and returns a vector, which is meant to include extracting the diagonal from nondiagonal operators. I think the implementations you have here should check with is_diagonal to determine whether to dispatch to the new or the old implementation.
|
|
||
|
|
||
| @diagonal.register(MatrixLinearOperator) | ||
| @diagonal.register(PyTreeLinearOperator) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Side note, I think we could make the PyTreeLinearOperator case a little more efficient by calling jnp.diag on each 'diagonal leaf', and concatenate those together. (You don't have to do that here, just spotting it.)
lineax/_operator.py
Outdated
| elif operator.jac == "bwd": | ||
| fn = _NoAuxOut(_NoAuxIn(operator.fn, operator.args)) | ||
| _, vjp_fun = jax.vjp(fn, operator.x) | ||
| diag_as_pytree = vjp_fun(unravel(basis)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this will fail for operators with different input and output structures. (They might still be mathematically square, just split up into a pytree in different ways / with different dtypes.) This needs to be a basis formed from operator.out_structure().
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just added a test for mismatched input and output structure and looks like these are not allowed to be diagonal in the current version of lineax. Shall we leave relaxing that assumption and fixing this for future work? This implementation should work fine as long as the assumption holds.
FAILED tests/test_operator.py::test_is_symmetric[float64] - ValueError: Symmetric matrices must have matching input and output structures. Got input structure (ShapeDtypeStruct(shape=(2,), dtype=float64), ShapeDtypeStruct(shape=(), dtype=float64)) and output structure (ShapeDtypeStruct(shape=(), dty...
FAILED tests/test_operator.py::test_is_symmetric[complex128] - ValueError: Symmetric matrices must have matching input and output structures. Got input structure (ShapeDtypeStruct(shape=(2,), dtype=complex128), ShapeDtypeStruct(shape=(), dtype=complex128)) and output structure (ShapeDtypeStruct(shape=(...
FAILED tests/test_operator.py::test_is_diagonal[float64] - ValueError: Symmetric matrices must have matching input and output structures. Got input structure (ShapeDtypeStruct(shape=(2,), dtype=float64), ShapeDtypeStruct(shape=(), dtype=float64)) and output structure (ShapeDtypeStruct(shape=(), dty...
FAILED tests/test_operator.py::test_is_diagonal[complex128] - ValueError: Symmetric matrices must have matching input and output structures. Got input structure (ShapeDtypeStruct(shape=(2,), dtype=complex128), ShapeDtypeStruct(shape=(), dtype=complex128)) and output structure (ShapeDtypeStruct(shape=(...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Never mind, these are PyTreeLinearOperators but I need to address for JacobianLinearOperators
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I actually got exactly the same error with JacobianLinearOperator, if this is not expected I can push the test I added to test/helpers to investigate further
@_operators_append
def make_nontrivial_jac_operator(getkey, matrix, tags):
# makes a Jacobian linear operator from matrix with
# input structure {"array", (in_size -1,), "scalar": ()}
# output structure ((), (out_size - 1,))
out_size, in_size = matrix.shape
x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype)
a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype)
b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
fn_tmp = lambda x, _: a + b @ x + c @ x**2
jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None)
diff = matrix - jac
def fn(x, args):
x_flat = jnp.concatenate([x["array"], x["scalar"][jnp.newaxis]])
y_flat = a + (b + diff) @ x_flat + c @ x_flat**2
y = [y_flat[0], y_flat[1:]]
return y
return lx.JacobianLinearOperator(fn, {"array": x[:-1], "scalar": x[-1]}, None, tags)|
Happy to address these during the week, just wanted to check: I did a quick trawl through the code and I don't think |
However they're actually a standalone public API themselves, which could be used independent of the solvers. :) |
|
Sorry for abandoning this for so long as the day job took over. I returned as it was found to provide orders of magnitude impact on a problem I was working (root finding over multiple interpolations) even for small array sizes (200–2000). I have addressed the main point of retaining extraction of the diagonal when the diagonal tag is missing. However, your other two comments seem at odds with each other: either we ensure input/output structures match (which is actually the case, see above) enabling us to extract diagonal leafs for |
|
Sorry for the long delay getting back to you, some personal life things took over for a while. So, now to actually answer your question: good point. I imagine we could probably do the 'diagonal leaf' approach when the structures match, and go for the more expensive approach when they don't? |
|
No worries, hope you're managing alright. Sorry to miss you at DiffSys, but looking forward to catching up with Johanna! Just to check I understand correctly, shall we go forward with the status quo that Diagonal (as for Symmetric) operators must ALWAYS have their input and output structures and continue to raise ValueError's when this is violated? Therefore, we do not need to touch JacobianLinearOperator but just need to adopt the diagonal leaf approach for PyTreeLinearOperators and we're done? |
Yup, I think that'd be reasonable! Let me know when this PR is ready and we'll get this merged :) |



BREAKING CHANGE: Diagonals are no longer "extracted" from operators but rely on the promise of the
diagonaltag being accurate and fulfilled. If this promise is broken solutions may differ from user expectations.Preface
At present, both
JacobianLinearOperatorandFunctionLinearOperatorrequire full materialisation even if provided with adiagonaltag. This seems self-evidently expensive (in practice it certainly can be but more often is not, see below) and requires the underlying function (which could potentially be a custom primitive) to have a batching rule. As it is currently the case that tags are considered to be a "promise" and are unchecked with no guarantee of behaviour, there are some shortcuts we can take.Changes made
The proposal here is to use the observation that the diagonal of a matrix can be obtained by pre/post-multiplying it by a unit vector and thereby re-write the single-dispatch
diagonalmethod forJacobianLinearOperatorandFunctionLinearOperatorso that theas_matrix()method is not required. ForJacobianLinearOperatoreitherjax.jvporjax.vjpwill be called depending on thejackeyword (forward-mode should always be more efficient but meeting the user's expectation will avoid issues if forward-mode is not supported such as when using acustom_vjpis used). ForFunctionLinearOperator, we can just useself.mv.However, if the matrix is not actually diagonal this identity will not hold and results may be unexpected due to contributions from off-diagonals.
Alternative design options
If you are not a fan of the breaking change and the slight performance hit for some array sizes (and devices?... might be worthwhile trying out on other machines) discussed below, would it be acceptable to introduce a
diagonal_nobatchtag that implements this as an alternative? Not sure it would work out the box as sometimes the diagonal tag is added automatically for scalar functions, but we could customise the standarddiagonaltag for scalar only if that is the only problem. Of course, I understand if you don't want to support un/slow-batched primitives given the trade-offs, in which case I'm happy to just maintain my own fork oflineaxfor my use case.We could also "check" the tags by casting non-zero (or non-small) values as nans to enable error checking.
I considered using
operator.transpose().mvinstead of writing outvjpbut if the matrix is tagged assymmetricthen this would end up callingjacrevinstead ofvjp.Why is this helpful?
When using
lineaxdirectly one can of course just define aDiagonalOperatorinstead of a more generalJacobianLinearOperator, but this is not always possible. For example, when usingoptimistix, the operator is instantiated within the optimisation routine and the only way to inform the optimiser about the underlying structure of the matrix is throughtags. Therefore, if the function being optimised is a primitive (e.g. an FFI) with a JVP rule that does not support batching a user is stuck. If a slow batching rule, such asvmap_method="sequential", is used the current approach is also painfully slow for large matrix sizes.Performance impact
I had initially hoped this to have a minor positive impact on performance across the board, but as ever I have massively underestimated the power of XLA. In practice, whether this PR seems to improve performance (e.g. for a
linear_solveor anoptimistix.root_find) of a pure jax function appears to fluctuate with array size. By playing around with differentXLA_FLAGSand other environment variables, my best guess is that this is mostly due to threading; avmapapplied to ajnp.eyeis threaded much more aggressively meaning that the apparent time complexity appears to be of lower order than a more direct approach. However, when I tried to eliminate threading this PR still seems to have an 8–10% negative impact on performance for array sizes > 100 on anoptimistix.root_find.Pure `jax` comparison: using `jvp` when attempting to enfore single-threadedness is about 14µs faster.
It seems self-evident that the second function should be more efficient, however with the new

thunkruntime on my Macfrom_eyeruns faster thandirect(referred to aswrappedin the diagram below, you can ignoreunwrappedandvmapas similar performance) for array sizes > ~1.5E4:Disabling the

thunkruntime (withXLA_FLAGS=--xla_cpu_use_thunk_runtime=falsewhich is reported to run faster in some circumstances) decreases the gap between the two by slightly slowing down theeyeimplementation and accelerating thedirectapproach:Going further and following all suggestions in github.com/jax-ml/jax/discussions/22739 to limit to one thread/core and we can see the

directapproach is now consistently about 14µs faster:`linear_solve` significantly faster (often >2x) for array sizes <2E4 using thunk runtime, but runs about 6–10% slower for large array sizes when disabling and attempting to enforce a single thread
Code tested
Using standard thunk runtime and

EQX_ON_ERROR=nanwe see significant speedup for array sizes < 1E4Enforcing single-thread the performance between the old and the new approaches is very similar but tracks at about 6–10% slower for larger array sizes.

(Note that
DiagonalOperatoris actually slower somehow.)Similar behaviour is observed with `optimistix.root_find` (but with more modest gains, and some hits for larger array sizes)
I compared performance for a multi-root find of the
sinfunction (withEQX_ON_ERROR=nan):Default settings (

jax0.6.1,mainvs this branch oflineax) with and without standardthunk` runtime:In both runtimes this PR improves/maintains performance by up to a factor of 2 for arrays of size up to 1E4 at which point it becomes slightly slow than the current version (by ~8%).
However, limiting to one thread as best as I can most of the noise is eliminated and the two have very similar performance time (the change tracking about 6% slower) except for an array size of 20 where the proposed change is faster:

Much more substantial performance improvement (8x or higher) is observed for primitives that only support `sequential` batching rules
This is a very contrived example, but based on very real use cases we have over at tesseract-core and tesseract-jax. I have defined a new primitive version of
sinwith ajvprule that batches sequentially and is therefore slow and doesn't benefit from compilation/threading in the same way:Code for primitives
I then ran the same tests as before but with

sin_pinstead ofjnp.sinand we can see the time complexity of the current version is almost quadratic for array sizes greater than 100 (as one would naively expect for a dense jacobian) meaning that speedups range from a factor of 2 (array size of 20) to a factor of 8 (array size of 5000) and higher:Running(This usesbenchmarks/solver_speed.pyshows a negligible improvement in the singleDiagonalsolve but a 50% faster batch solve, this could of course be down to noise as the solve is only timed once.lx.Diagonalso not relevant and probably just a fluke.)Testing done
test_diagonalsuch that operators are actually initialised with diagonal matricesNewtonandBisection) of scalar function with no batching rule and take gradients through the root solve (not possible previously) this tests bothJacobianLinearOperatorandFunctionLinearOperatorin actiondiagonalfromJacobianLinearOperatorwithjac="bwd"Happy to perform any further requested testing you see fit/necessary. I appreciate I haven't managed to test reverse-mode especially extensively.
Next steps
In a future PR, I would like to do something similar for other structures (e.g. tridiagonal) this should address the large O(n) discrepancy observed in #149 (but not the O(0) discrepancy). I believe this will be a much more consistent and meaningful gain than observed here. This PR here should likely be a lot easier to grok and reason about the concept and discuss framework/design choices (although maybe not the performance impact :) ) before building out further.