Skip to content

Commit 4459199

Browse files
committed
Don't include local_uint_constant_indices rewrite in JAX mode due to XLA bug
1 parent 14d2454 commit 4459199

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

pytensor/compile/mode.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,14 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
451451
JAXLinker(),
452452
RewriteDatabaseQuery(
453453
include=["fast_run", "jax"],
454-
exclude=["cxx_only", "BlasOpt", "fusion", "inplace"],
454+
# TODO: "local_uint_constant_indices" can be reintroduced once https://github.com/google/jax/issues/16836 is fixed.
455+
exclude=[
456+
"cxx_only",
457+
"BlasOpt",
458+
"fusion",
459+
"inplace",
460+
"local_uint_constant_indices",
461+
],
455462
),
456463
)
457464
NUMBA = Mode(

0 commit comments

Comments
 (0)