Skip to content

Commit a5b13d4

Browse files
sagartomarricardoV94
authored andcommitted
Added moment for ZeroInflatedNegativeBinomial distribution
1 parent 917e95a commit a5b13d4

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

pymc/distributions/discrete.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1650,6 +1650,12 @@ def dist(cls, psi, mu, alpha, *args, **kwargs):
16501650
p = at.as_tensor_variable(floatX(p))
16511651
return super().dist([psi, n, p], *args, **kwargs)
16521652

1653+
def get_moment(rv, size, psi, n, p):
1654+
mean = at.floor(psi * n * (1 - p) / p)
1655+
if not rv_size_is_none(size):
1656+
mean = at.full(size, mean)
1657+
return mean
1658+
16531659
def logp(value, psi, n, p):
16541660
r"""
16551661
Calculate log-probability of ZeroInflatedNegativeBinomial distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
Wald,
5353
Weibull,
5454
ZeroInflatedBinomial,
55+
ZeroInflatedNegativeBinomial,
5556
ZeroInflatedPoisson,
5657
)
5758
from pymc.distributions.distribution import get_moment
@@ -1052,3 +1053,24 @@ def test_polyagamma_moment(h, z, size, expected):
10521053
with Model() as model:
10531054
PolyaGamma("x", h=h, z=z, size=size)
10541055
assert_moment_is_expected(model, expected)
1056+
1057+
1058+
@pytest.mark.parametrize(
1059+
"psi, mu, alpha, size, expected",
1060+
[
1061+
(0.2, 10, 3, None, 2),
1062+
(0.2, 10, 4, 5, np.full(5, 2)),
1063+
(0.4, np.arange(1, 5), np.arange(2, 6), None, np.array([0, 0, 1, 1])),
1064+
(
1065+
np.linspace(0.2, 0.6, 3),
1066+
np.arange(1, 10, 4),
1067+
np.arange(1, 4),
1068+
(2, 3),
1069+
np.full((2, 3), np.array([0, 2, 5])),
1070+
),
1071+
],
1072+
)
1073+
def test_zero_inflated_negative_binomial_moment(psi, mu, alpha, size, expected):
1074+
with Model() as model:
1075+
ZeroInflatedNegativeBinomial("x", psi=psi, mu=mu, alpha=alpha, size=size)
1076+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)