Skip to content

Commit 606bb08

Browse files
brianwa84Google-ML-Automation
authored andcommitted
[Pallas:SC] Add support for neg_p, abs_p.
PiperOrigin-RevId: 835183262
1 parent 0c32991 commit 606bb08

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

jax/_src/pallas/mosaic/lowering.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2779,7 +2779,7 @@ def _rem_lowering_rule(ctx: LoweringRuleContext, x, y):
27792779
raise NotImplementedError(aval_out.dtype)
27802780

27812781

2782-
@register_lowering_rule(lax.abs_p)
2782+
@register_lowering_rule(lax.abs_p, kernel_types=[*tpu_core.KernelType])
27832783
def _abs_lowering_rule(ctx: LoweringRuleContext, x):
27842784
(aval_out,) = ctx.avals_out
27852785
if jnp.issubdtype(aval_out.dtype, jnp.integer):
@@ -2789,7 +2789,9 @@ def _abs_lowering_rule(ctx: LoweringRuleContext, x):
27892789
raise NotImplementedError(aval_out.dtype)
27902790

27912791

2792-
@register_lowering_rule(lax.neg_p, ensure_mlir_values=False)
2792+
@register_lowering_rule(
2793+
lax.neg_p, kernel_types=[*tpu_core.KernelType], ensure_mlir_values=False
2794+
)
27932795
def _neg_lowering_rule(ctx: LoweringRuleContext, x):
27942796
(x_aval,) = ctx.avals_in
27952797
new_ctx = ctx.replace(

tests/pallas/tpu_sparsecore_pallas_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1614,14 +1614,16 @@ def kernel(in_ref, o_ref, scratch_ref):
16141614

16151615
np.testing.assert_array_equal(f(x), x)
16161616

1617-
def test_exp(self):
1617+
@parameterized.named_parameters(
1618+
("exp", jnp.exp), ("neg", lambda x: -x), ("abs", jnp.abs))
1619+
def test_unary_ops(self, op):
16181620
if not jtu.if_cloud_tpu_at_least(2025, 11, 30):
16191621
self.skipTest("Test requires a newer libtpu")
16201622

16211623
x = jnp.arange(8, dtype=jnp.float32)
16221624

16231625
def sc_exp_kernel(x_hbm_ref, out_ref):
1624-
out_ref[...] = jnp.exp(x_hbm_ref[...])
1626+
out_ref[...] = op(x_hbm_ref[...])
16251627

16261628
result = pl.pallas_call(
16271629
sc_exp_kernel,
@@ -1630,7 +1632,7 @@ def sc_exp_kernel(x_hbm_ref, out_ref):
16301632
),
16311633
out_shape=x,
16321634
)(x)
1633-
np.testing.assert_array_equal(result, jnp.exp(x))
1635+
np.testing.assert_array_equal(result, op(x))
16341636

16351637
@parameterized.product(dtype=[np.int32, np.float32])
16361638
def test_vector_gather(self, dtype):

0 commit comments

Comments
 (0)