diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 8d549d767da3..d6c30be8946f 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -2779,7 +2779,7 @@ def _rem_lowering_rule(ctx: LoweringRuleContext, x, y): raise NotImplementedError(aval_out.dtype) -@register_lowering_rule(lax.abs_p) +@register_lowering_rule(lax.abs_p, kernel_types=[*tpu_core.KernelType]) def _abs_lowering_rule(ctx: LoweringRuleContext, x): (aval_out,) = ctx.avals_out if jnp.issubdtype(aval_out.dtype, jnp.integer): @@ -2789,7 +2789,9 @@ def _abs_lowering_rule(ctx: LoweringRuleContext, x): raise NotImplementedError(aval_out.dtype) -@register_lowering_rule(lax.neg_p, ensure_mlir_values=False) +@register_lowering_rule( + lax.neg_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False +) def _neg_lowering_rule(ctx: LoweringRuleContext, x): (x_aval,) = ctx.avals_in new_ctx = ctx.replace( diff --git a/tests/pallas/tpu_sparsecore_pallas_test.py b/tests/pallas/tpu_sparsecore_pallas_test.py index 2878e6ed1e6a..e43706f85899 100644 --- a/tests/pallas/tpu_sparsecore_pallas_test.py +++ b/tests/pallas/tpu_sparsecore_pallas_test.py @@ -1614,14 +1614,16 @@ def kernel(in_ref, o_ref, scratch_ref): np.testing.assert_array_equal(f(x), x) - def test_exp(self): + @parameterized.named_parameters( + ("exp", jnp.exp), ("neg", lambda x: -x), ("abs", jnp.abs)) + def test_unary_ops(self, op): if not jtu.if_cloud_tpu_at_least(2025, 11, 30): self.skipTest("Test requires a newer libtpu") x = jnp.arange(8, dtype=jnp.float32) def sc_exp_kernel(x_hbm_ref, out_ref): - out_ref[...] = jnp.exp(x_hbm_ref[...]) + out_ref[...] = op(x_hbm_ref[...]) result = pl.pallas_call( sc_exp_kernel, @@ -1630,7 +1632,7 @@ def sc_exp_kernel(x_hbm_ref, out_ref): ), out_shape=x, )(x) - np.testing.assert_array_equal(result, jnp.exp(x)) + np.testing.assert_array_equal(result, op(x)) @parameterized.product(dtype=[np.int32, np.float32]) def test_vector_gather(self, dtype):