
Fix bug in JAX cloning of RNG shared variables #315


Merged · 3 commits · May 24, 2023
pytensor/link/jax/dispatch/random.py (4 additions, 2 deletions)
@@ -45,8 +45,10 @@ def assert_size_argument_jax_compatible(node):
 
     """
     size = node.inputs[1]
-    size_op = size.owner.op
-    if not isinstance(size_op, (Shape, Shape_i, JAXShapeTuple)):
+    size_node = size.owner
+    if (size_node is not None) and (
+        not isinstance(size_node.op, (Shape, Shape_i, JAXShapeTuple))
+    ):
         raise NotImplementedError(SIZE_NOT_COMPATIBLE)
 
 
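A note on the guard above: the `size` input of a `RandomVariable` node does not always have an owner. When no explicit `size` is given and the output shape is implied by the distribution parameters, `size` is a constant, `size.owner` is `None`, and the old `size.owner.op` lookup raised `AttributeError` before the compatibility check could run. A minimal sketch of that failure mode, assuming pytensor's input layout in which `node.inputs[1]` holds the size (as the code above does):

    import pytensor.tensor as at

    # With no explicit `size`, the size input of the RandomVariable node
    # is a constant with no owner.
    out = at.random.normal(at.dmatrix("x"), 1)
    size_input = out.owner.inputs[1]
    print(size_input.owner)  # None
    # size_input.owner.op    # AttributeError before this fix

This is the situation exercised by the new `test_random_concrete_shape_from_param` test below.
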
pytensor/link/jax/linker.py (8 additions, 1 deletion)
@@ -44,7 +44,14 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
             new_inp_storage = [new_inp.get_value(borrow=True)]
             storage_map[new_inp] = new_inp_storage
             old_inp_storage = storage_map.pop(old_inp)
-            input_storage[input_storage.index(old_inp_storage)] = new_inp_storage
+            # Find index of old_inp_storage in input_storage
+            for input_storage_idx, input_storage_item in enumerate(input_storage):
+                # We have to establish equality based on identity because input_storage may contain numpy arrays
+                if input_storage_item is old_inp_storage:
+                    break
+            else:  # no break
+                raise ValueError()
+            input_storage[input_storage_idx] = new_inp_storage
             fgraph.remove_input(
                 fgraph.inputs.index(old_inp), reason="JAXLinker.fgraph_convert"
             )
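
The identity-based search replaces `list.index`, which compares items with `==`: when `input_storage` holds one-element storage cells wrapping numpy arrays ahead of the RNG cell, that comparison yields a boolean array whose truth value is ambiguous, so the lookup raises before it ever reaches the RNG entry. A minimal sketch of the failure and of the identity-based fix, using stand-in storage values (not the linker's real data):

    import numpy as np

    # Stand-in storage cells: an array cell ordered before the RNG cell,
    # mirroring the situation from issue #314.
    array_cell = [np.zeros((2, 2))]
    rng_cell = [np.random.default_rng()]
    input_storage = [array_cell, rng_cell]

    try:
        input_storage.index(rng_cell)
    except ValueError as e:
        print(e)  # truth value of an array with more than one element is ambiguous

    # Searching by identity sidesteps `==` entirely, like the for/else loop above:
    idx = next(i for i, cell in enumerate(input_storage) if cell is rng_cell)
    assert idx == 1
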
tests/link/jax/test_random.py (78 additions, 35 deletions)
@@ -1,5 +1,3 @@
-import re
-
 import numpy as np
 import pytest
 import scipy.stats as stats
@@ -22,6 +20,13 @@
 from pytensor.link.jax.dispatch.random import numpyro_available  # noqa: E402
 
 
+def random_function(*args, **kwargs):
+    with pytest.warns(
+        UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
+    ):
+        return function(*args, **kwargs)
+
+
 def test_random_RandomStream():
     """Two successive calls of a compiled graph using `RandomStream` should
     return different values.
@@ -30,11 +35,7 @@ def test_random_RandomStream():
     srng = RandomStream(seed=123)
     out = srng.normal() - srng.normal()
 
-    with pytest.warns(
-        UserWarning,
-        match=r"The RandomType SharedVariables \[.+\] will not be used",
-    ):
-        fn = function([], out, mode=jax_mode)
+    fn = random_function([], out, mode=jax_mode)
     jax_res_1 = fn()
     jax_res_2 = fn()

@@ -47,13 +48,7 @@ def test_random_updates(rng_ctor):
     rng = shared(original_value, name="original_rng", borrow=False)
     next_rng, x = at.random.normal(name="x", rng=rng).owner.outputs
 
-    with pytest.warns(
-        UserWarning,
-        match=re.escape(
-            "The RandomType SharedVariables [original_rng] will not be used"
-        ),
-    ):
-        f = pytensor.function([], [x], updates={rng: next_rng}, mode=jax_mode)
+    f = random_function([], [x], updates={rng: next_rng}, mode=jax_mode)
     assert f() != f()
 
     # Check that original rng variable content was not overwritten when calling jax_typify
@@ -63,6 +58,40 @@
     )
 
 
+def test_random_updates_input_storage_order():
+    """Test case described in issue #314.
+
+    This happened when we tried to update the input storage after cloning the
+    shared RNG. We used to call `input_storage.index(old_input_storage)`, which
+    breaks when input_storage contains numpy arrays before the RNG value,
+    because equality checks against numpy arrays are ambiguous.
+
+    """
+    pt_rng = RandomStream(1)
+
+    batchshape = (3, 1, 4, 4)
+    inp_shared = pytensor.shared(
+        np.zeros(batchshape, dtype="float64"), name="inp_shared"
+    )
+
+    inp = at.tensor4(dtype="float64", name="inp")
+    inp_update = inp + pt_rng.normal(size=inp.shape, loc=5, scale=1e-5)
+
+    # This function replaces `inp` by `inp_shared` in the update expression.
+    # This is what caused the RNG to appear later than `inp_shared` in the input_storage.
+
+    fn = random_function(
+        inputs=[],
+        outputs=[],
+        updates={inp_shared: inp_update},
+        givens={inp: inp_shared},
+        mode="JAX",
+    )
+    fn()
+    np.testing.assert_allclose(inp_shared.get_value(), 5, rtol=1e-3)
+    fn()
+    np.testing.assert_allclose(inp_shared.get_value(), 10, rtol=1e-3)
+
+
 @pytest.mark.parametrize(
     "rv_op, dist_params, base_size, cdf_name, params_conv",
     [
@@ -420,7 +449,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_conv):
     else:
         rng = shared(np.random.RandomState(29402))
     g = rv_op(*dist_params, size=(10_000,) + base_size, rng=rng)
-    g_fn = function(dist_params, g, mode=jax_mode)
+    g_fn = random_function(dist_params, g, mode=jax_mode)
     samples = g_fn(
         *[
             i.tag.test_value
@@ -444,7 +473,7 @@
 def test_random_bernoulli(size):
     rng = shared(np.random.RandomState(123))
     g = at.random.bernoulli(0.5, size=(1000,) + size, rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)

@@ -455,7 +484,7 @@ def test_random_mvnormal():
     mu = np.ones(4)
     cov = np.eye(4)
     g = at.random.multivariate_normal(mu, cov, size=(10000,), rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1)

@@ -470,7 +499,7 @@
 def test_random_dirichlet(parameter, size):
     rng = shared(np.random.RandomState(123))
     g = at.random.dirichlet(parameter, size=(1000,) + size, rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)

@@ -480,29 +509,29 @@ def test_random_choice():
     num_samples = 10000
     rng = shared(np.random.RandomState(123))
     g = at.random.choice(np.arange(4), size=num_samples, rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(np.sum(samples == 3) / num_samples, 0.25, 2)
 
     # `replace=False` produces unique results
     rng = shared(np.random.RandomState(123))
     g = at.random.choice(np.arange(100), replace=False, size=99, rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     assert len(np.unique(samples)) == 99
 
     # We can pass an array with probabilities
     rng = shared(np.random.RandomState(123))
     g = at.random.choice(np.arange(3), p=np.array([1.0, 0.0, 0.0]), size=10, rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples, np.zeros(10))
 
 
 def test_random_categorical():
     rng = shared(np.random.RandomState(123))
     g = at.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), 6 / 4, 1)

Expand All @@ -511,7 +540,7 @@ def test_random_permutation():
array = np.arange(4)
rng = shared(np.random.RandomState(123))
g = at.random.permutation(array, rng=rng)
g_fn = function([], g, mode=jax_mode)
g_fn = random_function([], g, mode=jax_mode)
permuted = g_fn()
with pytest.raises(AssertionError):
np.testing.assert_allclose(array, permuted)
@@ -521,7 +550,7 @@ def test_random_geometric():
     rng = shared(np.random.RandomState(123))
     p = np.array([0.3, 0.7])
     g = at.random.geometric(p, size=(10_000, 2), rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), 1 / p, rtol=0.1)
     np.testing.assert_allclose(samples.std(axis=0), np.sqrt((1 - p) / p**2), rtol=0.1)
@@ -532,7 +561,7 @@ def test_negative_binomial():
     n = np.array([10, 40])
     p = np.array([0.3, 0.7])
     g = at.random.negative_binomial(n, p, size=(10_000, 2), rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), n * (1 - p) / p, rtol=0.1)
     np.testing.assert_allclose(
@@ -546,7 +575,7 @@ def test_binomial():
     n = np.array([10, 40])
     p = np.array([0.3, 0.7])
     g = at.random.binomial(n, p, size=(10_000, 2), rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
     np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1)
@@ -561,7 +590,7 @@ def test_beta_binomial():
     a = np.array([1.5, 13])
     b = np.array([0.5, 9])
     g = at.random.betabinom(n, a, b, size=(10_000, 2), rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), n * a / (a + b), rtol=0.1)
     np.testing.assert_allclose(
@@ -579,7 +608,7 @@ def test_multinomial():
     n = np.array([10, 40])
     p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
     g = at.random.multinomial(n, p, size=(10_000, 2), rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1)
     np.testing.assert_allclose(
@@ -595,7 +624,7 @@ def test_vonmises_mu_outside_circle():
     mu = np.array([-30, 40])
     kappa = np.array([100, 10])
     g = at.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(
         samples.mean(axis=0), (mu + np.pi) % (2.0 * np.pi) - np.pi, rtol=0.1
@@ -641,7 +670,10 @@ def rng_fn(cls, rng, size):
     fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
 
     with pytest.raises(NotImplementedError):
-        compare_jax_and_py(fgraph, [])
+        with pytest.warns(
+            UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
+        ):
+            compare_jax_and_py(fgraph, [])
 
 
 def test_random_custom_implementation():
@@ -672,7 +704,10 @@ def sample_fn(rng, size, dtype, *parameters):
     rng = shared(np.random.RandomState(123))
     out = nonexistentrv(rng=rng)
     fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
-    compare_jax_and_py(fgraph, [])
+    with pytest.warns(
+        UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
+    ):
+        compare_jax_and_py(fgraph, [])
 
 
 def test_random_concrete_shape():
@@ -689,7 +724,15 @@ def test_random_concrete_shape():
     rng = shared(np.random.RandomState(123))
     x_at = at.dmatrix()
     out = at.random.normal(0, 1, size=x_at.shape, rng=rng)
-    jax_fn = function([x_at], out, mode=jax_mode)
+    jax_fn = random_function([x_at], out, mode=jax_mode)
     assert jax_fn(np.ones((2, 3))).shape == (2, 3)
+
+
+def test_random_concrete_shape_from_param():
+    rng = shared(np.random.RandomState(123))
+    x_at = at.dmatrix()
+    out = at.random.normal(x_at, 1, rng=rng)
+    jax_fn = random_function([x_at], out, mode=jax_mode)
+    assert jax_fn(np.ones((2, 3))).shape == (2, 3)


@@ -708,7 +751,7 @@ def test_random_concrete_shape_subtensor():
     rng = shared(np.random.RandomState(123))
     x_at = at.dmatrix()
     out = at.random.normal(0, 1, size=x_at.shape[1], rng=rng)
-    jax_fn = function([x_at], out, mode=jax_mode)
+    jax_fn = random_function([x_at], out, mode=jax_mode)
     assert jax_fn(np.ones((2, 3))).shape == (3,)


@@ -724,7 +767,7 @@ def test_random_concrete_shape_subtensor_tuple():
     rng = shared(np.random.RandomState(123))
     x_at = at.dmatrix()
     out = at.random.normal(0, 1, size=(x_at.shape[0],), rng=rng)
-    jax_fn = function([x_at], out, mode=jax_mode)
+    jax_fn = random_function([x_at], out, mode=jax_mode)
     assert jax_fn(np.ones((2, 3))).shape == (2,)

@@ -735,5 +778,5 @@ def test_random_concrete_shape_graph_input():
    rng = shared(np.random.RandomState(123))
    size_at = at.scalar()
    out = at.random.normal(0, 1, size=size_at, rng=rng)
-    jax_fn = function([size_at], out, mode=jax_mode)
+    jax_fn = random_function([size_at], out, mode=jax_mode)
    assert jax_fn(10).shape == (10,)
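
The `test_random_concrete_shape*` tests above all revolve around the same JAX constraint: under `jax.jit`, array shapes must be static, which is why the dispatch change raises `NotImplementedError` unless `size` is a constant or comes from a `Shape`, `Shape_i`, or `JAXShapeTuple` node that JAX can resolve to concrete values at trace time. A minimal illustration of that constraint in plain JAX, independent of pytensor:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def sample(key, x):
        # x.shape is a concrete (static) tuple at trace time, so it can be
        # used as the size of a random draw inside a jitted function.
        return jax.random.normal(key, x.shape)

    key = jax.random.PRNGKey(0)
    print(sample(key, jnp.ones((2, 3))).shape)  # (2, 3)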