
jnp.zeros does not support float0 as a dtype #4433

@danieldjohnson


After #4039, it is no longer possible to pass an integer zero tangent to jax.jvp; the tangent must instead have dtype float0. However, jnp.zeros does not support creating arrays of dtype float0 and raises:

TypeError: JAX only supports number and bool dtypes, got dtype [('float0', 'V')]

Using np.zeros with dtype=jax.dtypes.float0 works fine, but it seems like jnp.zeros should work here too, especially since float0 is a JAX-specific dtype.
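For reference, the np.zeros workaround looks roughly like the sketch below. The function `f` and the input values are illustrative only. The float0 tangent for the integer argument is built with `np.zeros(shape, dtype=jax.dtypes.float0)`, while the float argument keeps an ordinary float tangent.

```python
import numpy as np
import jax
import jax.numpy as jnp

def f(x, n):
    # Illustrative function: a float input scaled by an integer input.
    return x * n

x = jnp.float32(2.0)
n = jnp.int32(3)

# Integer primals need a float0 tangent; np.zeros accepts this dtype.
tangent_n = np.zeros(jnp.shape(n), dtype=jax.dtypes.float0)

primal_out, tangent_out = jax.jvp(f, (x, n), (jnp.float32(1.0), tangent_n))
```

Here the tangent output is `1.0 * n`, because the integer argument contributes no tangent.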

(For context, I have some code that looks like

import jax

def some_function_custom_jvp_rule(primals, tangents):
    x, y = primals
    dx, _ = tangents
    # Forward-mode derivative of foo with respect to x, holding y fixed.
    _, f_jvp_x = jax.jvp(lambda x: foo(x, y), (x,), (dx,))
    # ... do something with f_jvp_x

and sometimes x is a pytree containing integers. Previously this just worked; now I have to convert the dtypes by hand, because custom_jvp passes me integer tangents but jax.jvp expects float0 tangents. Perhaps this is also a bug? I'm not sure.)
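The manual dtype conversion can be sketched as a small helper that walks the primal and tangent pytrees together and swaps in float0 zeros wherever the primal leaf is not floating-point or complex. The name `to_jvp_tangents`, the stand-in `foo`, and the example inputs are all hypothetical, not part of JAX or the original code. The helper ignores the incoming tangent for integer/bool leaves, so it works whether that tangent arrives as an integer array or already as float0.

```python
import numpy as np
import jax
import jax.numpy as jnp

def to_jvp_tangents(primals, tangents):
    """Hypothetical helper: replace tangents of integer/bool primal leaves
    with float0 zeros so that jax.jvp accepts them."""
    def fix(p, t):
        if jnp.issubdtype(jnp.result_type(p), jnp.inexact):
            return t  # float/complex primals keep their tangent unchanged
        # Integer or bool primal: jax.jvp requires a float0 zero tangent.
        return np.zeros(np.shape(p), dtype=jax.dtypes.float0)
    return jax.tree_util.tree_map(fix, primals, tangents)

def foo(x, y):
    # Hypothetical stand-in for foo: scale the float leaf by the int leaf.
    return x["w"] * x["n"] * y

x = {"w": jnp.array([1.0, 2.0]), "n": jnp.array(3)}
dx = {"w": jnp.array([1.0, 1.0]), "n": jnp.array(0)}  # integer tangent for "n"
y = 2.0

_, f_jvp_x = jax.jvp(lambda x: foo(x, y), (x,), (to_jvp_tangents(x, dx),))
```

Only the float leaf `"w"` contributes to `f_jvp_x`, which evaluates to `dw * n * y`.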
