Skip to content

Commit 35f9966

Browse files
authored
NegativeBinomial, ZeroInflatedPoisson, ZeroInflatedBinomial moments (#5163)
* zero inflated binomial moment * negative binomial moment * ZIP moment
1 parent bdd4d19 commit 35f9966

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

pymc/distributions/discrete.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,12 @@ def get_n_p(cls, mu=None, alpha=None, p=None, n=None):
716716

717717
return n, p
718718

719+
def get_moment(rv, size, n, p):
720+
mu = at.floor(n * (1 - p) / p)
721+
if not rv_size_is_none(size):
722+
mu = at.full(size, mu)
723+
return mu
724+
719725
def logp(value, n, p):
720726
r"""
721727
Calculate log-probability of NegativeBinomial distribution at specified value.
@@ -1316,6 +1322,12 @@ def dist(cls, psi, theta, *args, **kwargs):
13161322
theta = at.as_tensor_variable(floatX(theta))
13171323
return super().dist([psi, theta], *args, **kwargs)
13181324

1325+
def get_moment(rv, size, psi, theta):
1326+
mean = at.floor(psi * theta)
1327+
if not rv_size_is_none(size):
1328+
mean = at.full(size, mean)
1329+
return mean
1330+
13191331
def logp(value, psi, theta):
13201332
r"""
13211333
Calculate log-probability of ZeroInflatedPoisson distribution at specified value.
@@ -1449,6 +1461,12 @@ def dist(cls, psi, n, p, *args, **kwargs):
14491461
p = at.as_tensor_variable(floatX(p))
14501462
return super().dist([psi, n, p], *args, **kwargs)
14511463

1464+
def get_moment(rv, size, psi, n, p):
1465+
mean = at.round((1 - psi) * n * p)
1466+
if not rv_size_is_none(size):
1467+
mean = at.full(size, mean)
1468+
return mean
1469+
14521470
def logp(value, psi, n, p):
14531471
r"""
14541472
Calculate log-probability of ZeroInflatedBinomial distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,12 @@
1919
Laplace,
2020
Logistic,
2121
LogNormal,
22+
NegativeBinomial,
2223
Poisson,
2324
StudentT,
2425
Weibull,
26+
ZeroInflatedBinomial,
27+
ZeroInflatedPoisson,
2528
)
2629
from pymc.distributions.shape_utils import rv_size_is_none
2730
from pymc.initial_point import make_initial_point_fn
@@ -402,6 +405,21 @@ def test_poisson_moment(mu, size, expected):
402405
assert_moment_is_expected(model, expected)
403406

404407

408+
@pytest.mark.parametrize(
409+
"n, p, size, expected",
410+
[
411+
(10, 0.7, None, 4),
412+
(10, 0.7, 5, np.full(5, 4)),
413+
(np.full(3, 10), np.arange(1, 4) / 10, None, np.array([90, 40, 23])),
414+
(10, np.arange(1, 4) / 10, (2, 3), np.full((2, 3), np.array([90, 40, 23]))),
415+
],
416+
)
417+
def test_negative_binomial_moment(n, p, size, expected):
418+
with Model() as model:
419+
NegativeBinomial("x", n=n, p=p, size=size)
420+
assert_moment_is_expected(model, expected)
421+
422+
405423
@pytest.mark.parametrize(
406424
"c, size, expected",
407425
[
@@ -416,6 +434,36 @@ def test_constant_moment(c, size, expected):
416434
assert_moment_is_expected(model, expected)
417435

418436

437+
@pytest.mark.parametrize(
438+
"psi, theta, size, expected",
439+
[
440+
(0.9, 3.0, None, 2),
441+
(0.8, 2.9, 5, np.full(5, 2)),
442+
(0.2, np.arange(1, 5) * 5, None, np.arange(1, 5)),
443+
(0.2, np.arange(1, 5) * 5, (2, 4), np.full((2, 4), np.arange(1, 5))),
444+
],
445+
)
446+
def test_zero_inflated_poisson_moment(psi, theta, size, expected):
447+
with Model() as model:
448+
ZeroInflatedPoisson("x", psi=psi, theta=theta, size=size)
449+
assert_moment_is_expected(model, expected)
450+
451+
452+
@pytest.mark.parametrize(
453+
"psi, n, p, size, expected",
454+
[
455+
(0.2, 7, 0.7, None, 4),
456+
(0.2, 7, 0.3, 5, np.full(5, 2)),
457+
(0.6, 25, np.arange(1, 6) / 10, None, np.arange(1, 6)),
458+
(0.6, 25, np.arange(1, 6) / 10, (2, 5), np.full((2, 5), np.arange(1, 6))),
459+
],
460+
)
461+
def test_zero_inflated_binomial_moment(psi, n, p, size, expected):
462+
with Model() as model:
463+
ZeroInflatedBinomial("x", psi=psi, n=n, p=p, size=size)
464+
assert_moment_is_expected(model, expected)
465+
466+
419467
@pytest.mark.parametrize(
420468
"mu, s, size, expected",
421469
[

0 commit comments

Comments
 (0)