Skip to content

Commit 64d8396

Browse files
ricardoV94twiecki
authored andcommitted
Add moment for Simulator distribution
1 parent ac5126b commit 64d8396

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

pymc/distributions/simulator.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from scipy.spatial import cKDTree
2626

2727
from pymc.aesaraf import floatX
28-
from pymc.distributions.distribution import NoDistribution
28+
from pymc.distributions.distribution import NoDistribution, _get_moment
2929

3030
__all__ = ["Simulator"]
3131

@@ -223,13 +223,23 @@ def logp(op, value_var_list, *dist_params, **kwargs):
223223
value_var = value_var_list[0]
224224
return cls.logp(value_var, op, dist_params)
225225

226+
@_get_moment.register(SimulatorRV)
227+
def get_moment(op, rv, size, *rv_inputs):
228+
return cls.get_moment(rv, size, *rv_inputs)
229+
226230
cls.rv_op = sim_op
227231
return super().__new__(cls, name, *params, **kwargs)
228232

229233
@classmethod
230234
def dist(cls, *params, **kwargs):
231235
return super().dist(params, **kwargs)
232236

237+
@classmethod
238+
def get_moment(cls, rv, size, *sim_inputs):
239+
# Take the mean of 10 draws
240+
multiple_sim = rv.owner.op(*sim_inputs, size=at.concatenate([[10], rv.shape]))
241+
return at.mean(multiple_sim, axis=0)
242+
233243
@classmethod
234244
def logp(cls, value, sim_op, sim_inputs):
235245
# Use a new rng to avoid non-randomness in parallel sampling

pymc/tests/test_distributions_moments.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import aesara
22
import numpy as np
33
import pytest
4+
import scipy.stats as st
45

56
from aesara import tensor as at
67
from scipy import special
78

89
import pymc as pm
910

11+
from pymc import Simulator
1012
from pymc.distributions import (
1113
AsymmetricLaplace,
1214
Bernoulli,
@@ -1074,3 +1076,41 @@ def test_zero_inflated_negative_binomial_moment(psi, mu, alpha, size, expected):
10741076
with Model() as model:
10751077
ZeroInflatedNegativeBinomial("x", psi=psi, mu=mu, alpha=alpha, size=size)
10761078
assert_moment_is_expected(model, expected)
1079+
1080+
1081+
@pytest.mark.parametrize("mu", [0, np.arange(3)], ids=str)
1082+
@pytest.mark.parametrize("sigma", [1, np.array([1, 2, 5])], ids=str)
1083+
@pytest.mark.parametrize("size", [None, 3, (5, 3)], ids=str)
1084+
def test_simulator_moment(mu, sigma, size):
1085+
def normal_sim(rng, mu, sigma, size):
1086+
return rng.normal(mu, sigma, size=size)
1087+
1088+
with Model() as model:
1089+
x = Simulator("x", normal_sim, mu, sigma, size=size)
1090+
1091+
fn = make_initial_point_fn(
1092+
model=model,
1093+
return_transformed=False,
1094+
default_strategy="moment",
1095+
)
1096+
1097+
random_draw = model["x"].eval()
1098+
result = fn(0)["x"]
1099+
assert result.shape == random_draw.shape
1100+
1101+
# We perform a z-test between the moment and expected mean from a sample of 10 draws
1102+
# This test fails if the number of samples averaged in get_moment(Simulator)
1103+
# is much smaller than 10, but would not catch the case where the number of samples
1104+
# is higher than the expected 10
1105+
1106+
n = 10 # samples
1107+
expected_sample_mean = mu
1108+
expected_sample_mean_std = np.sqrt(sigma ** 2 / n)
1109+
1110+
# Multiple test adjustment for z-test to maintain alpha=0.01
1111+
alpha = 0.01
1112+
alpha /= 2 * 2 * 3 # Correct for number of test permutations
1113+
alpha /= random_draw.size # Correct for distribution size
1114+
cutoff = st.norm().ppf(1 - (alpha / 2))
1115+
1116+
assert np.all(np.abs((result - expected_sample_mean) / expected_sample_mean_std) < cutoff)

0 commit comments

Comments
 (0)