
Conversation


@jpbrodrick89 jpbrodrick89 commented Jun 17, 2025

BREAKING CHANGE: Diagonals are no longer "extracted" from operators but rely on the promise of the diagonal tag being accurate and fulfilled. If this promise is broken, solutions may differ from user expectations.

Preface

At present, both JacobianLinearOperator and FunctionLinearOperator require full materialisation to compute their diagonal, even when provided with a diagonal tag. This seems self-evidently expensive (in practice it sometimes is, but more often is not; see below) and requires the underlying function (which could be a custom primitive) to have a batching rule. Since tags are currently treated as an unchecked "promise" with no guarantee of behaviour, there are some shortcuts we can take.

Changes made

The proposal here uses the observation that the diagonal of a (genuinely diagonal) matrix can be obtained by multiplying it by a vector of ones, and rewrites the single-dispatch diagonal method for JacobianLinearOperator and FunctionLinearOperator so that the as_matrix() method is not required. For JacobianLinearOperator, either jax.jvp or jax.vjp is called depending on the jac keyword (forward mode should always be more efficient, but following the user's choice avoids issues when forward mode is not supported, such as when a custom_vjp is used). For FunctionLinearOperator, we can just use self.mv.
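
A minimal sketch of the idea outside lineax, assuming a single 1D array argument (the actual change handles pytree inputs via unravelling):

import jax
import jax.numpy as jnp

def diag_via_jvp(f, x):
    # J @ ones gives the row sums of the Jacobian, which equal its diagonal
    # entries whenever the Jacobian really is diagonal.
    ones = jnp.ones_like(x)
    return jax.jvp(f, (x,), (ones,))[1]

def diag_via_vjp(f, x):
    # ones^T @ J gives the column sums, again the diagonal for a diagonal Jacobian.
    _, vjp_fn = jax.vjp(f, x)
    (diag,) = vjp_fn(jnp.ones_like(f(x)))
    return diag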

However, if the matrix is not actually diagonal, this identity will not hold and results may be unexpected due to contributions from the off-diagonal entries.

Alternative design options

If you are not a fan of the breaking change and of the slight performance hit for some array sizes (and devices? it might be worthwhile trying this out on other machines) discussed below, would it be acceptable to introduce a diagonal_nobatch tag that implements this as an alternative? I'm not sure it would work out of the box, as the diagonal tag is sometimes added automatically for scalar functions, but we could special-case the standard diagonal tag for scalars if that is the only problem. Of course, I understand if you don't want to support unbatched/slow-batched primitives given the trade-offs, in which case I'm happy to maintain my own fork of lineax for my use case.

We could also "check" the tags by casting non-zero (or non-small) values to NaN to enable error checking.

I considered using operator.transpose().mv instead of writing out the vjp, but if the matrix is tagged as symmetric then this would end up calling jacrev instead of vjp.

Why is this helpful?

When using lineax directly one can of course just define a DiagonalOperator instead of a more general JacobianLinearOperator, but this is not always possible. For example, when using optimistix, the operator is instantiated within the optimisation routine and the only way to inform the optimiser about the underlying structure of the matrix is through tags. Therefore, if the function being optimised is a primitive (e.g. an FFI) with a JVP rule that does not support batching, the user is stuck. If a slow batching rule (such as vmap_method="sequential") is used, the current approach is also painfully slow for large matrix sizes.

Performance impact

I had initially hoped this would have a minor positive impact on performance across the board, but as ever I have massively underestimated the power of XLA. In practice, whether this PR improves performance (e.g. for a linear_solve or an optimistix.root_find) of a pure JAX function appears to fluctuate with array size. By playing around with different XLA_FLAGS and other environment variables, my best guess is that this is mostly down to threading: a vmap applied to a jnp.eye is threaded much more aggressively, so its apparent time complexity is of lower order than that of the more direct approach. However, when I tried to eliminate threading, this PR still seems to have an 8–10% negative impact on performance for array sizes > 100 on an optimistix.root_find.

Pure `jax` comparison: using `jvp` when attempting to enforce single-threadedness is about 14µs faster.
import jax
import jax.numpy as jnp

@jax.jit
def from_eye(x):
    # materialise the Jacobian column by column via vmap over the identity matrix
    eye = jnp.eye(len(x), dtype=x.dtype)
    jac = jax.vmap(lambda t: jnp.cos(x) * t, out_axes=-1)(eye)
    return jnp.diag(jac)

@jax.jit
def direct(x):
    # the "obvious" way to compute the same diagonal
    return jnp.cos(x)
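
Timings like the ones below can be gathered with a simple harness along these lines (the exact benchmarking code isn't included in this PR, so this is just an illustrative sketch):

import timeit
import jax.numpy as jnp

def time_fn(fn, n, repeats=100):
    x = jnp.linspace(0.0, 1.0, n)
    fn(x).block_until_ready()  # trigger compilation before timing
    return min(timeit.repeat(lambda: fn(x).block_until_ready(), number=1, repeat=repeats))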

It seems self-evident that the second function should be more efficient; however, with the new thunk runtime on my Mac, from_eye runs faster than direct (referred to as wrapped in the plot below; you can ignore unwrapped and vmap, which have similar performance) for array sizes > ~1.5E4:
[plot: from_eye (wrapped) vs direct, default thunk runtime]

Disabling the thunk runtime (with XLA_FLAGS=--xla_cpu_use_thunk_runtime=false, which is reported to run faster in some circumstances) narrows the gap between the two by slightly slowing down the eye implementation and accelerating the direct approach:
[plot: same comparison with the thunk runtime disabled]

Going further and following all the suggestions in github.com/jax-ml/jax/discussions/22739 to limit execution to one thread/core, the direct approach is now consistently about 14µs faster:
[plot: same comparison restricted to a single thread]

`linear_solve` is significantly faster (often >2x) for array sizes <2E4 using the thunk runtime, but about 6–10% slower for large array sizes when disabling it and attempting to enforce a single thread
Code tested
import jax
import jax.numpy as jnp
import lineax

solver = lineax.Diagonal(well_posed=True)

def double(x, args):
    return (2.0 * jnp.ones_like(x)) * x

@jax.jit
def jac_op_to_sol(rhs):
    op = lineax.JacobianLinearOperator(double, rhs, tags=frozenset({lineax.diagonal_tag}))
    return lineax.linear_solve(op, rhs, solver, throw=False)

@jax.jit
def func_op_to_sol(rhs):
    op = lineax.FunctionLinearOperator(lambda x: double(x, None), jax.eval_shape(lambda: rhs), tags=frozenset({lineax.diagonal_tag}))
    return lineax.linear_solve(op, rhs, solver, throw=False)
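
A hypothetical invocation, just to show what is being timed (the array size here is arbitrary):

rhs = jnp.linspace(1.0, 2.0, 1000)
sol = jac_op_to_sol(rhs)  # solves 2 * x = rhs, so sol.value is approximately rhs / 2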

Using the standard thunk runtime and EQX_ON_ERROR=nan we see a significant speedup for array sizes < 1E4:
[plot: linear_solve timings, standard thunk runtime]

When enforcing a single thread, the performance of the old and new approaches is very similar, with the new approach tracking about 6–10% slower for larger array sizes.
[plot: linear_solve timings, single thread]

(Note that DiagonalOperator is actually slower somehow.)

Similar behaviour is observed with `optimistix.root_find` (but with more modest gains, and some hits for larger array sizes)

I compared performance for a multi-root find of the sin function (with EQX_ON_ERROR=nan):

x = jnp.repeat(jnp.linspace(-10., 10., num=20), n // 20)  # n is the array size being swept
newton = optx.Newton(rtol=1e-8, atol=1e-8, linear_solver=lx.Diagonal())
optx.root_find(lambda x, args: jnp.sin(x), newton, x, max_steps=8, tags=frozenset({lx.diagonal_tag}), throw=False).value

Default settings (jax 0.6.1, main vs this branch of lineax) with and without the standard thunk runtime:
[plot: root_find timings, default settings]

In both runtimes this PR improves or maintains performance (by up to a factor of 2) for arrays of size up to 1E4, at which point it becomes slightly slower than the current version (by ~8%).

However, limiting to one thread as best I can, most of the noise is eliminated and the two have very similar performance (the change tracks about 6% slower), except for an array size of 20, where the proposed change is faster:
[plot: root_find timings, single thread]

Much more substantial performance improvement (8x or higher) is observed for primitives that only support `sequential` batching rules

This is a very contrived example, but it is based on very real use cases we have over at tesseract-core and tesseract-jax. I have defined a new primitive version of sin with a jvp rule that batches sequentially and is therefore slow, and doesn't benefit from compilation/threading in the same way:

Code for primitives
import numpy as onp

import jax
import jax.numpy as jnp
from jax.interpreters import ad, batching, mlir
from jax.core import ShapedArray
from jax.extend import core

from jax._src.lib.mlir.dialects import hlo
from jax._src.ffi import ffi_batching_rule

from functools import partial

jax.config.update("jax_enable_x64", True)

sin_p = core.Primitive("mysin")
cos_mult_p = core.Primitive("cos_mult")

def sin_prim(x):
    return sin_p.bind(x)

def cos_mult_prim(*args):
    p = args[0]
    return cos_mult_p.bind(*args,
                           vmap_method="sequential",
                           result_avals=(ShapedArray(p.shape, p.dtype), ))

def cos_mult_impl(*args, vmap_method="sequential", result_avals=None):
    del vmap_method, result_avals
    return onp.cos(args[0]) * args[1]

def cos_mult_abstract_eval(*args, vmap_method="sequential", result_avals=None):
    del vmap_method, result_avals
    return ShapedArray(args[0].shape, args[0].dtype)

def sin_jvp(p, t):
    return sin_prim(p[0]), cos_mult_prim(p[0], t[0])

def sin_lowering(ctx, x):
    return [ hlo.SineOp(x).result ]

def cos_mult_lowering(ctx, p, t, *, vmap_method="sequential", result_avals=None):
    del vmap_method, result_avals
    return [ hlo.MulOp(hlo.CosineOp(p), t).result ]

def cos_mult_batch(args, axes, vmap_method="sequential", result_avals=None):
    as_tup, out_dims =  ffi_batching_rule(cos_mult_p, args, axes, 
                            vmap_method="sequential",
                            result_avals=(ShapedArray(args[0].shape, args[0].dtype),))
    return jnp.stack(list(as_tup)), out_dims[0]

sin_p.def_impl(onp.sin)
cos_mult_p.def_impl(cos_mult_impl)
sin_p.def_abstract_eval(lambda x: ShapedArray(x.shape, x.dtype))
cos_mult_p.def_abstract_eval(cos_mult_abstract_eval)
mlir.register_lowering(sin_p, sin_lowering)
mlir.register_lowering(cos_mult_p, cos_mult_lowering)
ad.primitive_jvps[sin_p] = sin_jvp
# If the `cos_mult_p` batching rule isn't present this won't run on the current version
# but will with this PR
batching.primitive_batchers[cos_mult_p] = cos_mult_batch
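
The root find is then driven through the primitive in essentially the same way as before (a sketch mirroring the earlier snippet; the array size here is arbitrary):

import jax.numpy as jnp
import lineax as lx
import optimistix as optx

x0 = jnp.repeat(jnp.linspace(-10.0, 10.0, num=20), 5)  # array size 100 here
newton = optx.Newton(rtol=1e-8, atol=1e-8, linear_solver=lx.Diagonal())
sol = optx.root_find(lambda x, args: sin_prim(x), newton, x0, max_steps=8,
                     tags=frozenset({lx.diagonal_tag}), throw=False)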

I then ran the same tests as before but with sin_p instead of jnp.sin. The time complexity of the current version is almost quadratic for array sizes greater than 100 (as one would naively expect for a dense Jacobian), meaning that speedups range from a factor of 2 (array size of 20) to a factor of 8 (array size of 5000) and higher:
[plot: root_find timings with the sequentially-batched primitive]

Running benchmarks/solver_speed.py shows a negligible improvement in the single Diagonal solve but a 50% faster batch solve, though this could of course be down to noise as the solve is only timed once. (This uses lx.Diagonal, so it is not really relevant here and probably just a fluke.)

Testing done

  • CI passes after modifying test_diagonal such that operators are actually initialised with diagonal matrices
  • Can find the root (using both Newton and Bisection) of a scalar function with no batching rule and take gradients through the root solve (not possible previously); this exercises both JacobianLinearOperator and FunctionLinearOperator in action
  • Can obtain diagonal from JacobianLinearOperator with jac="bwd"

Happy to perform any further testing you see fit or think necessary. I appreciate I haven't managed to test reverse mode especially extensively.

Next steps

In a future PR, I would like to do something similar for other structures (e.g. tridiagonal); this should address the large O(n) discrepancy observed in #149 (but not the O(0) discrepancy). I believe that will be a much more consistent and meaningful gain than the one observed here. This PR should be a lot easier to grok, and a good place to reason about the concept and discuss framework/design choices (although maybe not the performance impact :) ) before building out further.
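
For concreteness, a rough sketch of the tridiagonal idea (three "colour" probe vectors recover all three bands in three jvp calls; illustrative only, not part of this PR):

import jax
import jax.numpy as jnp

def tridiagonal_via_probes(f, x):
    # For a (promised) tridiagonal Jacobian J, J[i, j] = (J @ probe_{j % 3})[i],
    # because the columns i-1, i, i+1 all lie in different colour classes.
    n = x.shape[0]
    probes = [(jnp.arange(n) % 3 == c).astype(x.dtype) for c in range(3)]
    products = jnp.stack([jax.jvp(f, (x,), (p,))[1] for p in probes])
    idx = jnp.arange(n)
    diag = products[idx % 3, idx]             # J[i, i]
    lower = products[(idx - 1) % 3, idx][1:]  # J[i, i-1]
    upper = products[(idx + 1) % 3, idx][:-1] # J[i, i+1]
    return lower, diag, upper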

@jpbrodrick89 jpbrodrick89 changed the title from 'Support primitives with no/slow JVP batch rule' to 'Support "diagonal" primitives with no/slow JVP batch rule' on Jun 17, 2025
@jpbrodrick89
Contributor Author

I think I prefer this implementation (latest commit) with unravel and ones rather than map and ones_like: it seems more consistent with what is done elsewhere in the codebase (and in my tridiagonal PR) and is probably more efficient for complex PyTrees. Performance is essentially identical for my test with a single 1D array.
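
Roughly, the two alternatives look like this (illustrative only, not the exact lineax internals):

import jax.numpy as jnp
import jax.tree_util as jtu
from jax.flatten_util import ravel_pytree

def ones_via_tree_map(x):
    # "map and ones_like": build a ones-tangent leaf by leaf
    return jtu.tree_map(jnp.ones_like, x)

def ones_via_unravel(x):
    # "unravel and ones": flatten once, build a single ones vector, unflatten
    flat, unravel = ravel_pytree(x)
    return unravel(jnp.ones_like(flat))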


jpbrodrick89 commented Jun 24, 2025

I finally realised why jacobian didn't run consistently slower. This is because the computation was dominated by "transcendental" evaluations rather than by MACs/pure FLOPs. The cost of a transcendental evaluation depends on where it is evaluated, and for sine and cosine, evaluating at the zeros of the identity matrix is very, very cheap. Actually computing them therefore did not add much cost, and the more aggressive threading of jacobian often outweighed it. If I instead use a transcendental function that is ever so slightly less trivial to evaluate at 0, jnp.exp(-(x - 1.0)**2 / 2.0), I get much more consistent and convincing results (typically 1.5–2.5x faster for n>1E3 with default jax settings):

[two benchmark plots]

Would it be helpful to re-write the PR message with updated benchmarks and prose in light of this?


jpbrodrick89 commented Jun 24, 2025

That reasoning was completely wrong again: the derivative is not evaluated at the tangent vectors, just multiplied by them. I think the reduction in FLOPs from avoiding the matrix multiplication is only noticeable in the example above because jnp.exp is a bit cheaper than jnp.sin. In general, jax.jacobian is very efficient for unary functions.

However, it is easy to fool jit by breaking up the function, as it then needs to keep track of which function is applied to which index. This is actually not a very far-fetched scenario if one is solving two independent ODEs, for example.

import jax.numpy as jnp

def myfunc(x):
    # two different element-wise functions applied to the two halves of x
    halfway = len(x) // 2
    return jnp.concatenate([jnp.sin(x[:halfway]), jnp.cos(x[halfway:])])

[plot: root_find benchmark with myfunc]

Here we see an O(n) speedup, exceeding 1E4 for an array size of 4E4, which is huge.

In general I see no significant adverse impact of this and some very pronounced positive impacts in realistic use cases.

Owner

@patrick-kidger patrick-kidger left a comment


Nice, I really like seeing a colouring approach like this!

I will note, however, that lx.diagonal is documented as "Extracts the diagonal from a linear operator, and returns a vector", which is meant to include extracting the diagonal from non-diagonal operators. I think the implementations you have here should check with is_diagonal to determine whether to dispatch to the new or the old implementation.
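
i.e. something along these lines, as a simplified free-standing sketch rather than the actual singledispatch registration (ones here is assumed to be a ones-vector matching the operator's input structure):

import jax.numpy as jnp
import lineax as lx

def diagonal_of(operator, ones):
    if lx.is_diagonal(operator):
        return operator.mv(ones)               # new path: trust the tag, one matvec
    else:
        return jnp.diag(operator.as_matrix())  # old path: extract from the full matrix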



@diagonal.register(MatrixLinearOperator)
@diagonal.register(PyTreeLinearOperator)
Owner


Side note: I think we could make the PyTreeLinearOperator case a little more efficient by calling jnp.diag on each 'diagonal leaf' and concatenating those together. (You don't have to do that here, just spotting it.)
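
i.e. roughly the following, assuming a simple list-of-lists block layout with matching input/output structures (illustrative only):

import jax.numpy as jnp

def diag_from_blocks(blocks):
    # blocks[i][j] maps input leaf j to output leaf i; only the square (i, i)
    # blocks contribute to the overall diagonal.
    return jnp.concatenate([jnp.diag(blocks[i][i]) for i in range(len(blocks))])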

elif operator.jac == "bwd":
    fn = _NoAuxOut(_NoAuxIn(operator.fn, operator.args))
    _, vjp_fun = jax.vjp(fn, operator.x)
    diag_as_pytree = vjp_fun(unravel(basis))
Owner


I think this will fail for operators with different input and output structures. (They might still be mathematically square, just split up into a pytree in different ways / with different dtypes.) This needs to be a basis formed from operator.out_structure().

Contributor Author


I just added a test for mismatched input and output structure, and it looks like these are not allowed to be diagonal in the current version of lineax. Shall we leave relaxing that assumption (and fixing this) for future work? This implementation should work fine as long as the assumption holds.

FAILED tests/test_operator.py::test_is_symmetric[float64] - ValueError: Symmetric matrices must have matching input and output structures. Got input structure (ShapeDtypeStruct(shape=(2,), dtype=float64), ShapeDtypeStruct(shape=(), dtype=float64)) and output structure (ShapeDtypeStruct(shape=(), dty...
FAILED tests/test_operator.py::test_is_symmetric[complex128] - ValueError: Symmetric matrices must have matching input and output structures. Got input structure (ShapeDtypeStruct(shape=(2,), dtype=complex128), ShapeDtypeStruct(shape=(), dtype=complex128)) and output structure (ShapeDtypeStruct(shape=(...
FAILED tests/test_operator.py::test_is_diagonal[float64] - ValueError: Symmetric matrices must have matching input and output structures. Got input structure (ShapeDtypeStruct(shape=(2,), dtype=float64), ShapeDtypeStruct(shape=(), dtype=float64)) and output structure (ShapeDtypeStruct(shape=(), dty...
FAILED tests/test_operator.py::test_is_diagonal[complex128] - ValueError: Symmetric matrices must have matching input and output structures. Got input structure (ShapeDtypeStruct(shape=(2,), dtype=complex128), ShapeDtypeStruct(shape=(), dtype=complex128)) and output structure (ShapeDtypeStruct(shape=(...

Contributor Author


Never mind, these are PyTreeLinearOperators, but I need to address this for JacobianLinearOperators

Contributor Author


I actually got exactly the same error with JacobianLinearOperator; if this is not expected, I can push the test I added to test/helpers to investigate further:

@_operators_append
def make_nontrivial_jac_operator(getkey, matrix, tags):
    # makes a Jacobian linear operator from `matrix` with
    # input structure {"array": (in_size - 1,), "scalar": ()} and
    # output structure ((), (out_size - 1,))
    out_size, in_size = matrix.shape
    x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype)
    a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype)
    b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
    c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
    fn_tmp = lambda x, _: a + b @ x + c @ x**2
    jac = jax.jacfwd(fn_tmp, holomorphic=jnp.iscomplexobj(x))(x, None)
    diff = matrix - jac

    def fn(x, args):
        x_flat = jnp.concatenate([x["array"], x["scalar"][jnp.newaxis]])
        y_flat = a + (b + diff) @ x_flat + c @ x_flat**2
        y = [y_flat[0], y_flat[1:]]
        return y

    return lx.JacobianLinearOperator(fn, {"array": x[:-1], "scalar": x[-1]}, None, tags)

@jpbrodrick89
Contributor Author

Happy to address these during the week, just wanted to check: I did a quick trawl through the code and I don't think diagonal or tridiagonal is ever used by operators that don't have the corresponding tag, so we have an option to just change the docstring instead, if you like? I can certainly imagine situations where the current documented usage could be convenient, e.g. when testing out the stability of various differencing schemes (extracting a tridiagonal part to treat implicitly while treating the rest explicitly), but in some cases there should be a more efficient way to express the operator in two parts.

@patrick-kidger
Owner

Happy to address these during the week, just wanted to check: I did a quick trawl through the code and I don't think diagonal or tridiagonal is ever used by operators that don't have the corresponding tag, so we have an option to just change the docstring instead, if you like?

However, they're actually a standalone public API themselves, which could be used independently of the solvers. :)

@jpbrodrick89
Contributor Author

Sorry for abandoning this for so long as the day job took over. I returned to it because it was found to provide orders-of-magnitude impact on a problem I was working on (root finding over multiple interpolations), even for small array sizes (200–2000). I have addressed the main point of retaining extraction of the diagonal when the diagonal tag is missing. However, your other two comments seem at odds with each other: either we ensure input/output structures match (which is actually already the case, see above), enabling us to extract diagonal leaves for PyTreeLinearOperator and leaving the JacobianLinearOperator vjp implementation as it is, or we relax this assumption and are no longer able to extract diagonal leaves, as there is not guaranteed to be such an object.

@patrick-kidger
Owner

Sorry for the long delay getting back to you, some personal life things took over for a while.

So, now to actually answer your question: good point.

I imagine we could probably do the 'diagonal leaf' approach when the structures match, and go for the more expensive approach when they don't?

@jpbrodrick89
Contributor Author

No worries, hope you're managing alright. Sorry to miss you at DiffSys, but looking forward to catching up with Johanna!

Just to check I understand correctly: shall we go forward with the status quo that Diagonal (as for Symmetric) operators must ALWAYS have matching input and output structures, and continue to raise ValueErrors when this is violated?

Therefore, we do not need to touch JacobianLinearOperator but just need to adopt the diagonal leaf approach for PyTreeLinearOperators and we're done?


patrick-kidger commented Dec 22, 2025

Just to check I understand correctly: shall we go forward with the status quo that Diagonal (as for Symmetric) operators must ALWAYS have matching input and output structures, and continue to raise ValueErrors when this is violated?

Therefore, we do not need to touch JacobianLinearOperator but just need to adopt the diagonal leaf approach for PyTreeLinearOperators and we're done?

Yup, I think that'd be reasonable! Let me know when this PR is ready and we'll get this merged :)
