Skip to content

Add moment for Simulator distribution #5208

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion pymc/distributions/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from scipy.spatial import cKDTree

from pymc.aesaraf import floatX
from pymc.distributions.distribution import NoDistribution
from pymc.distributions.distribution import NoDistribution, _get_moment

__all__ = ["Simulator"]

Expand Down Expand Up @@ -223,13 +223,23 @@ def logp(op, value_var_list, *dist_params, **kwargs):
value_var = value_var_list[0]
return cls.logp(value_var, op, dist_params)

@_get_moment.register(SimulatorRV)
def get_moment(op, rv, size, *rv_inputs):
    """Moment implementation registered on ``_get_moment`` for ``SimulatorRV``.

    Delegates to ``cls.get_moment``, forwarding ``rv``, ``size`` and the
    remaining inputs unchanged. ``op`` is received but not used. ``cls`` is
    not a parameter of this function; it is resolved from the enclosing scope.
    """
    return cls.get_moment(rv, size, *rv_inputs)

cls.rv_op = sim_op
return super().__new__(cls, name, *params, **kwargs)

@classmethod
def dist(cls, *params, **kwargs):
    """Delegate to the parent class's ``dist``.

    The positional ``params`` are handed over packed as one tuple (a single
    positional argument, not unpacked); ``kwargs`` are forwarded as keywords.
    """
    parent = super()
    return parent.dist(params, **kwargs)

@classmethod
def get_moment(cls, rv, size, *sim_inputs):
    """Return the moment of a simulator variable as the mean of 10 draws.

    The op that produced ``rv`` is applied again to ``sim_inputs`` with a
    requested size of ``[10, *rv.shape]``, and the result is averaged over the
    leading axis. The ``size`` argument is not used; the shape is taken from
    ``rv`` itself.
    """
    # Prepend a draw axis of length 10 to the variable's own shape.
    stacked_size = at.concatenate([[10], rv.shape])
    draws = rv.owner.op(*sim_inputs, size=stacked_size)
    # Average away the draw axis.
    return at.mean(draws, axis=0)

@classmethod
def logp(cls, value, sim_op, sim_inputs):
# Use a new rng to avoid non-randomness in parallel sampling
Expand Down
40 changes: 40 additions & 0 deletions pymc/tests/test_distributions_moments.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import aesara
import numpy as np
import pytest
import scipy.stats as st

from aesara import tensor as at
from scipy import special

import pymc as pm

from pymc import Simulator
from pymc.distributions import (
AsymmetricLaplace,
Bernoulli,
Expand Down Expand Up @@ -1074,3 +1076,41 @@ def test_zero_inflated_negative_binomial_moment(psi, mu, alpha, size, expected):
with Model() as model:
ZeroInflatedNegativeBinomial("x", psi=psi, mu=mu, alpha=alpha, size=size)
assert_moment_is_expected(model, expected)


@pytest.mark.parametrize("mu", [0, np.arange(3)], ids=str)
@pytest.mark.parametrize("sigma", [1, np.array([1, 2, 5])], ids=str)
@pytest.mark.parametrize("size", [None, 3, (5, 3)], ids=str)
def test_simulator_moment(mu, sigma, size):
    """Check the moment-based initial point of a normal ``Simulator``.

    Two properties are asserted: the initial point has the same shape as a
    random draw of the variable, and every entry lies within a z-test cutoff
    of ``mu`` when the standard error is that of a mean over 10 draws.
    """

    def normal_sim(rng, mu, sigma, size):
        return rng.normal(mu, sigma, size=size)

    with Model() as model:
        Simulator("x", normal_sim, mu, sigma, size=size)

    initial_point_fn = make_initial_point_fn(
        model=model,
        return_transformed=False,
        default_strategy="moment",
    )

    # Draw first, then evaluate the initial point, and compare shapes.
    draw = model["x"].eval()
    moment = initial_point_fn(0)["x"]
    assert moment.shape == draw.shape

    # z-test of the moment against the mean expected from a sample of 10 draws.
    # The test fails when get_moment(Simulator) averages far fewer than 10
    # samples; it cannot detect averaging over more than 10.
    n_draws = 10
    sample_mean_std = np.sqrt(sigma ** 2 / n_draws)

    # Multiple-testing adjustment keeping an overall alpha of 0.01: divide by
    # the number of parametrized permutations (2 * 2 * 3), then by the number
    # of entries in the variable.
    alpha = 0.01 / (2 * 2 * 3) / draw.size
    cutoff = st.norm().ppf(1 - (alpha / 2))

    z_scores = np.abs((moment - mu) / sample_mean_std)
    assert np.all(z_scores < cutoff)