Skip to content

Commit a384d94

Browse files
committed
Provide JAX Ops from Optional tfp dependency
1 parent 8ac8342 commit a384d94

File tree

3 files changed

+64
-2
lines changed

3 files changed

+64
-2
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ jobs:
145145
# PyTensor next, pip installs a lower version of numpy via the PyPI.
146146
if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION == "3.9" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numpy<1.23" "numba>=0.57" numba-scipy; fi
147147
if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION != "3.9" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57" numba-scipy; fi
148-
if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro; fi
148+
if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro tensorflow-probability; fi
149149
pip install -e ./
150150
mamba list && pip freeze
151151
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'

pytensor/link/jax/dispatch/scalar.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import functools
2+
from typing import Optional, Callable
23

34
import jax
45
import jax.numpy as jnp
@@ -18,7 +19,21 @@
1819
Second,
1920
Sub,
2021
)
21-
from pytensor.scalar.math import Erf, Erfc, Erfinv, Log1mexp, Psi
22+
from pytensor.scalar.math import Erf, Erfc, Erfcinv, Erfcx, Erfinv, Iv, Log1mexp, Psi
23+
24+
25+
def try_import_tfp_jax_op(op: ScalarOp, jax_op_name: Optional[str] = None) -> Callable:
26+
try:
27+
import tensorflow_probability.substrates.jax.math as tfp_jax_math
28+
except ModuleNotFoundError:
29+
raise NotImplementedError(
30+
f"No JAX implementation for Op {op.name}. "
31+
"Implementation is available if TensorFlow Probability is installed"
32+
)
33+
34+
if jax_op_name is None:
35+
jax_op_name = op.name
36+
return getattr(tfp_jax_math, jax_op_name)
2237

2338

2439
def check_if_inputs_scalars(node):
@@ -211,6 +226,24 @@ def erfinv(x):
211226
return erfinv
212227

213228

229+
@jax_funcify.register(Erfcx)
230+
@jax_funcify.register(Erfcinv)
231+
def jax_funcify_from_tfp(op, **kwargs):
232+
tfp_jax_op = try_import_tfp_jax_op(op)
233+
234+
return tfp_jax_op
235+
236+
237+
@jax_funcify.register(Iv)
238+
def jax_funcify_Iv(op, **kwargs):
239+
ive = try_import_tfp_jax_op(op, jax_op_name="bessel_ive")
240+
241+
def iv(v, x):
242+
return ive(v, x) / jnp.exp(-jnp.abs(jnp.real(x)))
243+
244+
return iv
245+
246+
214247
@jax_funcify.register(Log1mexp)
215248
def jax_funcify_Log1mexp(op, node, **kwargs):
216249
def log1mexp(x):

tests/link/jax/test_scalar.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,17 @@
77
from pytensor.graph.fg import FunctionGraph
88
from pytensor.graph.op import get_test_value
99
from pytensor.scalar.basic import Composite
10+
from pytensor.tensor import as_tensor
1011
from pytensor.tensor.elemwise import Elemwise
1112
from pytensor.tensor.math import all as at_all
1213
from pytensor.tensor.math import (
1314
cosh,
1415
erf,
1516
erfc,
17+
erfcinv,
18+
erfcx,
1619
erfinv,
20+
iv,
1721
log,
1822
log1mexp,
1923
psi,
@@ -28,6 +32,14 @@
2832
from pytensor.link.jax.dispatch import jax_funcify
2933

3034

35+
try:
36+
pass
37+
38+
TFP_INSTALLED = True
39+
except ModuleNotFoundError:
40+
TFP_INSTALLED = False
41+
42+
3143
def test_second():
3244
a0 = scalar("a0")
3345
b = scalar("b")
@@ -134,6 +146,23 @@ def test_erfinv():
134146
compare_jax_and_py(fg, [0.95])
135147

136148

149+
@pytest.mark.parametrize(
150+
"op, test_values",
151+
[
152+
(erfcx, (0.7,)),
153+
(erfcinv, (0.7,)),
154+
(iv, (0.3, 0.7)),
155+
],
156+
)
157+
@pytest.mark.skipif(not TFP_INSTALLED, reason="Test requires tensorflow-probability")
158+
def test_tfp_ops(op, test_values):
159+
inputs = [as_tensor(test_value).type() for test_value in test_values]
160+
output = op(*inputs)
161+
162+
fg = FunctionGraph(inputs, [output])
163+
compare_jax_and_py(fg, test_values)
164+
165+
137166
def test_psi():
138167
x = scalar("x")
139168
out = psi(x)

0 commit comments

Comments
 (0)