From a75689f2c81269b024e272888fe4ecb750fca65e Mon Sep 17 00:00:00 2001
From: Ricardo
Date: Fri, 19 Nov 2021 10:32:10 +0100
Subject: [PATCH] Add moment for Simulator distribution

---
Reviewer notes (text between the "---" marker and the diffstat is ignored
when the patch is applied):

simulator.py: a `_get_moment` dispatch for `SimulatorRV` is registered right
after the nested `logp` helper and forwards to the new `get_moment`
classmethod. That classmethod asks the simulator op for 10 stacked draws
(size = [10] + rv.shape) and returns their mean over axis 0.

test_distributions_moments.py: `test_simulator_moment` checks that the moment
has the same shape as a random draw and z-tests it against the mean expected
from a sample of 10 normal draws.

 pymc/distributions/simulator.py          | 12 ++++++-
 pymc/tests/test_distributions_moments.py | 40 ++++++++++++++++++++++++
 2 files changed, 51 insertions(+), 1 deletion(-)

diff --git a/pymc/distributions/simulator.py b/pymc/distributions/simulator.py
index 4c0d764231..424cf42df6 100644
--- a/pymc/distributions/simulator.py
+++ b/pymc/distributions/simulator.py
@@ -25,7 +25,7 @@
 from scipy.spatial import cKDTree
 
 from pymc.aesaraf import floatX
-from pymc.distributions.distribution import NoDistribution
+from pymc.distributions.distribution import NoDistribution, _get_moment
 
 __all__ = ["Simulator"]
 
@@ -223,6 +223,10 @@ def logp(op, value_var_list, *dist_params, **kwargs):
             value_var = value_var_list[0]
             return cls.logp(value_var, op, dist_params)
 
+        @_get_moment.register(SimulatorRV)
+        def get_moment(op, rv, size, *rv_inputs):
+            return cls.get_moment(rv, size, *rv_inputs)
+
         cls.rv_op = sim_op
         return super().__new__(cls, name, *params, **kwargs)
 
@@ -230,6 +234,12 @@ def logp(op, value_var_list, *dist_params, **kwargs):
     def dist(cls, *params, **kwargs):
         return super().dist(params, **kwargs)
 
+    @classmethod
+    def get_moment(cls, rv, size, *sim_inputs):
+        # Take the mean of 10 draws
+        multiple_sim = rv.owner.op(*sim_inputs, size=at.concatenate([[10], rv.shape]))
+        return at.mean(multiple_sim, axis=0)
+
     @classmethod
     def logp(cls, value, sim_op, sim_inputs):
         # Use a new rng to avoid non-randomness in parallel sampling
diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py
index 050755788b..e1baa8840a 100644
--- a/pymc/tests/test_distributions_moments.py
+++ b/pymc/tests/test_distributions_moments.py
@@ -1,12 +1,14 @@
 import aesara
 import numpy as np
 import pytest
+import scipy.stats as st
 
 from aesara import tensor as at
 from scipy import special
 
 import pymc as pm
 
+from pymc import Simulator
 from pymc.distributions import (
     AsymmetricLaplace,
     Bernoulli,
@@ -1074,3 +1076,41 @@ def test_zero_inflated_negative_binomial_moment(psi, mu, alpha, size, expected):
     with Model() as model:
         ZeroInflatedNegativeBinomial("x", psi=psi, mu=mu, alpha=alpha, size=size)
     assert_moment_is_expected(model, expected)
+
+
+@pytest.mark.parametrize("mu", [0, np.arange(3)], ids=str)
+@pytest.mark.parametrize("sigma", [1, np.array([1, 2, 5])], ids=str)
+@pytest.mark.parametrize("size", [None, 3, (5, 3)], ids=str)
+def test_simulator_moment(mu, sigma, size):
+    def normal_sim(rng, mu, sigma, size):
+        return rng.normal(mu, sigma, size=size)
+
+    with Model() as model:
+        x = Simulator("x", normal_sim, mu, sigma, size=size)
+
+    fn = make_initial_point_fn(
+        model=model,
+        return_transformed=False,
+        default_strategy="moment",
+    )
+
+    random_draw = model["x"].eval()
+    result = fn(0)["x"]
+    assert result.shape == random_draw.shape
+
+    # We perform a z-test between the moment and expected mean from a sample of 10 draws
+    # This test fails if the number of samples averaged in get_moment(Simulator)
+    # is much smaller than 10, but would not catch the case where the number of samples
+    # is higher than the expected 10
+
+    n = 10  # samples
+    expected_sample_mean = mu
+    expected_sample_mean_std = np.sqrt(sigma ** 2 / n)
+
+    # Multiple test adjustment for z-test to maintain alpha=0.01
+    alpha = 0.01
+    alpha /= 2 * 2 * 3  # Correct for number of test permutations
+    alpha /= random_draw.size  # Correct for distribution size
+    cutoff = st.norm().ppf(1 - (alpha / 2))
+
+    assert np.all(np.abs((result - expected_sample_mean) / expected_sample_mean_std) < cutoff)