diff --git a/doc/extending/creating_a_numba_jax_op.rst b/doc/extending/creating_a_numba_jax_op.rst index abf3f528bf..4779a0ac38 100644 --- a/doc/extending/creating_a_numba_jax_op.rst +++ b/doc/extending/creating_a_numba_jax_op.rst @@ -105,7 +105,7 @@ Here’s an example for the `Eye`\ `Op`: @jax_funcify.register(Eye) - def jax_funcify_Eye(op): + def jax_funcify_Eye(op, **kwargs): # Obtain necessary "static" attributes from the Op being converted dtype = op.dtype