Closed
Description
After #4039, it's no longer possible to pass an integer zero tangent to jax.jvp; the tangent must instead have dtype float0. However, jnp.zeros doesn't support creating arrays of dtype float0, and instead raises:
TypeError: JAX only supports number and bool dtypes, got dtype [('float0', 'V')]
Using np.zeros works fine, but it seems like jnp.zeros should work here too, especially since float0 is a JAX-specific dtype.
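The np.zeros workaround mentioned above can be sketched as follows. This is a minimal example assuming a JAX version where jax.jvp requires float0 tangents for integer primals; the function f and its inputs are illustrative and not from the original report.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Illustrative function with one float input and one integer input.
def f(x, n):
    return x * n

x = jnp.float32(3.0)
n = jnp.int32(2)

dx = jnp.float32(1.0)
# Build the integer input's tangent with NumPy, since jnp.zeros rejects float0.
dn = np.zeros(np.shape(n), dtype=jax.dtypes.float0)

y, y_dot = jax.jvp(f, (x, n), (dx, dn))
```

The float0 tangent carries no data; JAX treats it as a zero tangent, so y_dot depends only on dx (here y is 6.0 and y_dot is 2.0).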
(For context, I have some code that looks like

```python
import jax

def some_function_custom_jvp_rule(primals, tangents):
    x, y = primals
    dx, _ = tangents
    # Push dx through foo with respect to x only, holding y fixed.
    _, f_jvp_x = jax.jvp(lambda x: foo(x, y), (x,), (dx,))
    # ... do something with f_jvp_x
```

and sometimes x is a pytree containing integers. Before, this just worked; now I have to convert the tangent dtypes by hand, because custom_jvp gives me integer tangents but jax.jvp expects float0 tangents. Perhaps this is a bug too? I'm not sure.)
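One way to do that dtype conversion by hand for a whole pytree is sketched below. to_jvp_tangents is a hypothetical helper (not a JAX API), and the foo defined here is only a stand-in for the function in the snippet above; the sketch assumes the primals and tangents share the same pytree structure.

```python
import numpy as np
import jax
import jax.numpy as jnp

def to_jvp_tangents(primals, tangents):
    """Hypothetical helper: swap tangents of integer/bool primals for float0 zeros."""
    def fix(p, t):
        dt = jnp.result_type(p)
        if jnp.issubdtype(dt, jnp.integer) or jnp.issubdtype(dt, jnp.bool_):
            # Integer and bool primals need a float0 tangent of matching shape.
            return np.zeros(np.shape(p), dtype=jax.dtypes.float0)
        return t  # Floating-point tangents are passed through unchanged.
    return jax.tree_util.tree_map(fix, primals, tangents)

# Stand-in for foo: x is a pytree mixing a float leaf and an integer leaf.
def foo(x, y):
    return x["a"] * x["n"] + y

x = {"a": jnp.float32(3.0), "n": jnp.int32(2)}
y = jnp.float32(1.0)
dx = {"a": jnp.float32(1.0), "n": jnp.int32(0)}  # integer zero tangent

dx_fixed = to_jvp_tangents(x, dx)
out, out_dot = jax.jvp(lambda x: foo(x, y), (x,), (dx_fixed,))
```

Keying the decision on the primal's dtype rather than the tangent's means the helper works whether the incoming integer tangent is an int array or already float0.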