diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index 05796c9589..e0484a01fc 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -1462,7 +1462,7 @@ class ZeroInflatedBinomial(Discrete): ======== ========================== Support :math:`x \in \mathbb{N}_0` - Mean :math:`(1 - \psi) n p` + Mean :math:`\psi n p` Variance :math:`(1-\psi) n p [1 - p(1 - \psi n)].` ======== ========================== @@ -1487,7 +1487,7 @@ def dist(cls, psi, n, p, *args, **kwargs): return super().dist([psi, n, p], *args, **kwargs) def get_moment(rv, size, psi, n, p): - mean = at.round((1 - psi) * n * p) + mean = at.round(psi * n * p) if not rv_size_is_none(size): mean = at.full(size, mean) return mean @@ -1650,6 +1650,12 @@ def dist(cls, psi, mu, alpha, *args, **kwargs): p = at.as_tensor_variable(floatX(p)) return super().dist([psi, n, p], *args, **kwargs) + def get_moment(rv, size, psi, n, p): + mean = at.floor(psi * n * (1 - p) / p) + if not rv_size_is_none(size): + mean = at.full(size, mean) + return mean + def logp(value, psi, n, p): r""" Calculate log-probability of ZeroInflatedNegativeBinomial distribution at specified value. diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index e700985a31..050755788b 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -52,6 +52,7 @@ Wald, Weibull, ZeroInflatedBinomial, + ZeroInflatedNegativeBinomial, ZeroInflatedPoisson, ) from pymc.distributions.distribution import get_moment @@ -553,11 +554,11 @@ def test_zero_inflated_poisson_moment(psi, theta, size, expected): @pytest.mark.parametrize( "psi, n, p, size, expected", [ - (0.2, 7, 0.7, None, 4), - (0.2, 7, 0.3, 5, np.full(5, 2)), - (0.6, 25, np.arange(1, 6) / 10, None, np.arange(1, 6)), + (0.8, 7, 0.7, None, 4), + (0.8, 7, 0.3, 5, np.full(5, 2)), + (0.4, 25, np.arange(1, 6) / 10, None, np.arange(1, 6)), ( - 0.6, + 0.4, 25, np.arange(1, 6) / 10, (2, 5), @@ -1052,3 +1053,24 @@ def test_polyagamma_moment(h, z, size, expected): with Model() as model: PolyaGamma("x", h=h, z=z, size=size) assert_moment_is_expected(model, expected) + + +@pytest.mark.parametrize( + "psi, mu, alpha, size, expected", + [ + (0.2, 10, 3, None, 2), + (0.2, 10, 4, 5, np.full(5, 2)), + (0.4, np.arange(1, 5), np.arange(2, 6), None, np.array([0, 0, 1, 1])), + ( + np.linspace(0.2, 0.6, 3), + np.arange(1, 10, 4), + np.arange(1, 4), + (2, 3), + np.full((2, 3), np.array([0, 2, 5])), + ), + ], +) +def test_zero_inflated_negative_binomial_moment(psi, mu, alpha, size, expected): + with Model() as model: + ZeroInflatedNegativeBinomial("x", psi=psi, mu=mu, alpha=alpha, size=size) + assert_moment_is_expected(model, expected)