diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py
index 745dc4753c..0981234db0 100644
--- a/pytensor/link/jax/dispatch/random.py
+++ b/pytensor/link/jax/dispatch/random.py
@@ -45,8 +45,10 @@ def assert_size_argument_jax_compatible(node):
     """

     size = node.inputs[1]
-    size_op = size.owner.op
-    if not isinstance(size_op, (Shape, Shape_i, JAXShapeTuple)):
+    size_node = size.owner
+    if (size_node is not None) and (
+        not isinstance(size_node.op, (Shape, Shape_i, JAXShapeTuple))
+    ):
         raise NotImplementedError(SIZE_NOT_COMPATIBLE)

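The guard now lets a `size` with no owner pass through (an implicit or constant size, as in the new `test_random_concrete_shape_from_param` below), where the old code raised an `AttributeError` on `size.owner.op`; it only rejects a `size` produced by an op other than `Shape`, `Shape_i`, or `JAXShapeTuple`. The underlying constraint comes from JAX itself: under `jax.jit`, anything used in a shape position must be concrete. A standalone sketch of that constraint, independent of pytensor and assuming only that `jax` is installed:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

def draw_like(x):
    # Inside jit, `x.shape` is a static (concrete) tuple, so it is a valid
    # shape argument for jax.random functions
    return jax.random.normal(key, x.shape)

assert jax.jit(draw_like)(jnp.ones((2, 3))).shape == (2, 3)

def draw_n(n):
    # A traced scalar is not concrete; jax.random.normal rejects it under jit
    return jax.random.normal(key, (n,))

try:
    jax.jit(draw_n)(10)
except Exception as exc:  # JAX raises a shape/concretization TypeError here
    print(type(exc).__name__)
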
diff --git a/pytensor/link/jax/linker.py b/pytensor/link/jax/linker.py
index a4e6072587..2d75e76d5c 100644
--- a/pytensor/link/jax/linker.py
+++ b/pytensor/link/jax/linker.py
@@ -44,7 +44,14 @@     def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
                 new_inp_storage = [new_inp.get_value(borrow=True)]
                 storage_map[new_inp] = new_inp_storage
                 old_inp_storage = storage_map.pop(old_inp)
-                input_storage[input_storage.index(old_inp_storage)] = new_inp_storage
+                # Find the index of `old_inp_storage` in `input_storage`
+                for input_storage_idx, input_storage_item in enumerate(input_storage):
+                    # Compare by identity: `input_storage` may hold numpy arrays, whose `==` is elementwise
+                    if input_storage_item is old_inp_storage:
+                        break
+                else:  # no break
+                    raise ValueError("Could not find the storage of the old shared RNG input")
+                input_storage[input_storage_idx] = new_inp_storage
                 fgraph.remove_input(
                     fgraph.inputs.index(old_inp), reason="JAXLinker.fgraph_convert"
                 )
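For context on the loop above: `list.index` compares elements with `==`, and once a storage cell holds a numpy array, that lookup breaks. A minimal standalone repro (values are illustrative, mirroring the regression test below):

import numpy as np

# Storage cells as the linker keeps them: one single-element list per input,
# with an array-valued cell sitting before the RNG cell, as in issue #314
array_storage = [np.zeros((3, 1, 4, 4))]
rng_storage = [np.random.RandomState(123)]
input_storage = [array_storage, rng_storage]

try:
    # The `==`-based lookup fails once it reaches the numpy array: the
    # elementwise comparison has no single truth value
    input_storage.index(rng_storage)
except ValueError as exc:
    print(exc)  # "The truth value of an array with more than one element is ambiguous..."

# The identity-based scan used in the patch finds the right cell
idx = next(i for i, item in enumerate(input_storage) if item is rng_storage)
assert idx == 1
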
diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py
index 21d150504a..54e4e09307 100644
--- a/tests/link/jax/test_random.py
+++ b/tests/link/jax/test_random.py
@@ -1,5 +1,3 @@
-import re
-
 import numpy as np
 import pytest
 import scipy.stats as stats
@@ -22,6 +20,13 @@
 from pytensor.link.jax.dispatch.random import numpyro_available  # noqa: E402


+def random_function(*args, **kwargs):
+    with pytest.warns(
+        UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
+    ):
+        return function(*args, **kwargs)
+
+
 def test_random_RandomStream():
     """Two successive calls of a compiled graph using `RandomStream` should
     return different values.
@@ -30,11 +35,7 @@ def test_random_RandomStream():
     """
     srng = RandomStream(seed=123)
     out = srng.normal() - srng.normal()
-    with pytest.warns(
-        UserWarning,
-        match=r"The RandomType SharedVariables \[.+\] will not be used",
-    ):
-        fn = function([], out, mode=jax_mode)
+    fn = random_function([], out, mode=jax_mode)
     jax_res_1 = fn()
     jax_res_2 = fn()

@@ -47,13 +48,7 @@ def test_random_updates(rng_ctor):
     rng = shared(original_value, name="original_rng", borrow=False)
     next_rng, x = at.random.normal(name="x", rng=rng).owner.outputs

-    with pytest.warns(
-        UserWarning,
-        match=re.escape(
-            "The RandomType SharedVariables [original_rng] will not be used"
-        ),
-    ):
-        f = pytensor.function([], [x], updates={rng: next_rng}, mode=jax_mode)
+    f = random_function([], [x], updates={rng: next_rng}, mode=jax_mode)
     assert f() != f()

     # Check that original rng variable content was not overwritten when calling jax_typify
@@ -63,6 +58,40 @@ def test_random_updates(rng_ctor):
     )


+def test_random_updates_input_storage_order():
+    """Regression test for the case described in issue #314.
+
+    The bug appeared when updating the input storage after cloning the shared RNG.
+    We used to call `input_storage.index(old_inp_storage)`, which raises a ValueError
+    when `input_storage` holds a numpy array before the RNG value, as `==` is elementwise.
+
+    """
+    pt_rng = RandomStream(1)
+
+    batchshape = (3, 1, 4, 4)
+    inp_shared = pytensor.shared(
+        np.zeros(batchshape, dtype="float64"), name="inp_shared"
+    )
+
+    inp = at.tensor4(dtype="float64", name="inp")
+    inp_update = inp + pt_rng.normal(size=inp.shape, loc=5, scale=1e-5)
+
+    # The `givens` below replace `inp` by `inp_shared` in the update expression.
+    # This is what caused the RNG to appear later than `inp_shared` in the input_storage
+
+    fn = random_function(
+        inputs=[],
+        outputs=[],
+        updates={inp_shared: inp_update},
+        givens={inp: inp_shared},
+        mode="JAX",
+    )
+    fn()
+    np.testing.assert_allclose(inp_shared.get_value(), 5, rtol=1e-3)
+    fn()
+    np.testing.assert_allclose(inp_shared.get_value(), 10, rtol=1e-3)
+
+
 @pytest.mark.parametrize(
     "rv_op, dist_params, base_size, cdf_name, params_conv",
     [
@@ -420,7 +449,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
     else:
         rng = shared(np.random.RandomState(29402))
     g = rv_op(*dist_params, size=(10_000,) + base_size, rng=rng)
-    g_fn = function(dist_params, g, mode=jax_mode)
+    g_fn = random_function(dist_params, g, mode=jax_mode)
     samples = g_fn(
         *[
             i.tag.test_value
@@ -444,7 +473,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
 def test_random_bernoulli(size):
     rng = shared(np.random.RandomState(123))
     g = at.random.bernoulli(0.5, size=(1000,) + size, rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)

@@ -455,7 +484,7 @@ def test_random_mvnormal():
     mu = np.ones(4)
     cov = np.eye(4)
     g = at.random.multivariate_normal(mu, cov, size=(10000,), rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1)

@@ -470,7 +499,7 @@ def test_random_dirichlet(parameter, size):
     rng = shared(np.random.RandomState(123))
     g = at.random.dirichlet(parameter, size=(1000,) + size, rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), 0.5, 1)

@@ -480,21 +509,21 @@ def test_random_choice():
     num_samples = 10000
     rng = shared(np.random.RandomState(123))
     g = at.random.choice(np.arange(4), size=num_samples, rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(np.sum(samples == 3) / num_samples, 0.25, 2)

     # `replace=False` produces unique results
     rng = shared(np.random.RandomState(123))
     g = at.random.choice(np.arange(100), replace=False, size=99, rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     assert len(np.unique(samples)) == 99

     # We can pass an array with probabilities
     rng = shared(np.random.RandomState(123))
     g = at.random.choice(np.arange(3), p=np.array([1.0, 0.0, 0.0]), size=10, rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples, np.zeros(10))

@@ -502,7 +531,7 @@ def test_random_categorical():
     rng = shared(np.random.RandomState(123))
     g = at.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), 6 / 4, 1)

@@ -511,7 +540,7 @@ def test_random_permutation():
     array = np.arange(4)
     rng = shared(np.random.RandomState(123))
     g = at.random.permutation(array, rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     permuted = g_fn()
     with pytest.raises(AssertionError):
         np.testing.assert_allclose(array, permuted)
@@ -521,7 +550,7 @@ def test_random_geometric():
     rng = shared(np.random.RandomState(123))
     p = np.array([0.3, 0.7])
     g = at.random.geometric(p, size=(10_000, 2), rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), 1 / p, rtol=0.1)
     np.testing.assert_allclose(samples.std(axis=0), np.sqrt((1 - p) / p**2), rtol=0.1)
@@ -532,7 +561,7 @@ def test_negative_binomial():
     n = np.array([10, 40])
     p = np.array([0.3, 0.7])
     g = at.random.negative_binomial(n, p, size=(10_000, 2), rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), n * (1 - p) / p, rtol=0.1)
     np.testing.assert_allclose(
@@ -546,7 +575,7 @@ def test_binomial():
     n = np.array([10, 40])
     p = np.array([0.3, 0.7])
     g = at.random.binomial(n, p, size=(10_000, 2), rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
     np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1)
@@ -561,7 +590,7 @@ def test_beta_binomial():
     a = np.array([1.5, 13])
     b = np.array([0.5, 9])
     g = at.random.betabinom(n, a, b, size=(10_000, 2), rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), n * a / (a + b), rtol=0.1)
     np.testing.assert_allclose(
@@ -579,7 +608,7 @@ def test_multinomial():
     n = np.array([10, 40])
     p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
     g = at.random.multinomial(n, p, size=(10_000, 2), rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1)
     np.testing.assert_allclose(
@@ -595,7 +624,7 @@ def test_vonmises_mu_outside_circle():
     mu = np.array([-30, 40])
     kappa = np.array([100, 10])
     g = at.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng)
-    g_fn = function([], g, mode=jax_mode)
+    g_fn = random_function([], g, mode=jax_mode)
     samples = g_fn()
     np.testing.assert_allclose(
         samples.mean(axis=0), (mu + np.pi) % (2.0 * np.pi) - np.pi, rtol=0.1
     )
@@ -641,7 +670,10 @@ def rng_fn(cls, rng, size):
     fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)

     with pytest.raises(NotImplementedError):
-        compare_jax_and_py(fgraph, [])
+        with pytest.warns(
+            UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
+        ):
+            compare_jax_and_py(fgraph, [])


 def test_random_custom_implementation():
@@ -672,7 +704,10 @@ def sample_fn(rng, size, dtype, *parameters):
     rng = shared(np.random.RandomState(123))
     out = nonexistentrv(rng=rng)
     fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
-    compare_jax_and_py(fgraph, [])
+    with pytest.warns(
+        UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
+    ):
+        compare_jax_and_py(fgraph, [])


 def test_random_concrete_shape():
@@ -689,7 +724,15 @@ def test_random_concrete_shape():
     rng = shared(np.random.RandomState(123))
     x_at = at.dmatrix()
     out = at.random.normal(0, 1, size=x_at.shape, rng=rng)
-    jax_fn = function([x_at], out, mode=jax_mode)
+    jax_fn = random_function([x_at], out, mode=jax_mode)
+    assert jax_fn(np.ones((2, 3))).shape == (2, 3)
+
+
+def test_random_concrete_shape_from_param():
+    rng = shared(np.random.RandomState(123))
+    x_at = at.dmatrix()
+    out = at.random.normal(x_at, 1, rng=rng)
+    jax_fn = random_function([x_at], out, mode=jax_mode)
     assert jax_fn(np.ones((2, 3))).shape == (2, 3)


@@ -708,7 +751,7 @@ def test_random_concrete_shape_subtensor():
     rng = shared(np.random.RandomState(123))
     x_at = at.dmatrix()
     out = at.random.normal(0, 1, size=x_at.shape[1], rng=rng)
-    jax_fn = function([x_at], out, mode=jax_mode)
+    jax_fn = random_function([x_at], out, mode=jax_mode)
     assert jax_fn(np.ones((2, 3))).shape == (3,)


@@ -724,7 +767,7 @@ def test_random_concrete_shape_subtensor_tuple():
     rng = shared(np.random.RandomState(123))
     x_at = at.dmatrix()
     out = at.random.normal(0, 1, size=(x_at.shape[0],), rng=rng)
-    jax_fn = function([x_at], out, mode=jax_mode)
+    jax_fn = random_function([x_at], out, mode=jax_mode)
     assert jax_fn(np.ones((2, 3))).shape == (2,)


@@ -735,5 +778,5 @@ def test_random_concrete_shape_graph_input():
     rng = shared(np.random.RandomState(123))
     size_at = at.scalar()
     out = at.random.normal(0, 1, size=size_at, rng=rng)
-    jax_fn = function([size_at], out, mode=jax_mode)
+    jax_fn = random_function([size_at], out, mode=jax_mode)
     assert jax_fn(10).shape == (10,)