Skip to content

Commit 45d48a6

Browse files
Binomial and Poisson Moment (#5150)
Co-authored-by: Farhan Reynaldo <[email protected]>
1 parent 4da614d commit 45d48a6

File tree

2 files changed

+55
-4
lines changed

2 files changed

+55
-4
lines changed

pymc/distributions/discrete.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,14 @@ class Binomial(Discrete):
114114
def dist(cls, n, p, *args, **kwargs):
115115
n = at.as_tensor_variable(intX(n))
116116
p = at.as_tensor_variable(floatX(p))
117-
# mode = at.cast(tround(n * p), self.dtype)
118117
return super().dist([n, p], **kwargs)
119118

119+
def get_moment(rv, size, n, p):
120+
mean = at.round(n * p)
121+
if not rv_size_is_none(size):
122+
mean = at.full(size, mean)
123+
return mean
124+
120125
def logp(value, n, p):
121126
r"""
122127
Calculate log-probability of Binomial distribution at specified value.
@@ -567,9 +572,14 @@ class Poisson(Discrete):
567572
@classmethod
568573
def dist(cls, mu, *args, **kwargs):
569574
mu = at.as_tensor_variable(floatX(mu))
570-
# mode = intX(at.floor(mu))
571575
return super().dist([mu], *args, **kwargs)
572576

577+
def get_moment(rv, size, mu):
578+
mu = at.floor(mu)
579+
if not rv_size_is_none(size):
580+
mu = at.full(size, mu)
581+
return mu
582+
573583
def logp(value, mu):
574584
r"""
575585
Calculate log-probability of Poisson distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pymc import Bernoulli, Flat, HalfFlat, Normal, TruncatedNormal, Uniform
77
from pymc.distributions import (
88
Beta,
9+
Binomial,
910
Cauchy,
1011
ChiSquared,
1112
Exponential,
@@ -15,6 +16,7 @@
1516
Kumaraswamy,
1617
Laplace,
1718
LogNormal,
19+
Poisson,
1820
StudentT,
1921
Weibull,
2022
)
@@ -224,7 +226,13 @@ def test_laplace_moment(mu, b, size, expected):
224226
(0, 1, 1, None, 0),
225227
(0, np.ones(5), 1, None, np.zeros(5)),
226228
(np.arange(5), 10, np.arange(1, 6), None, np.arange(5)),
227-
(np.arange(5), 10, np.arange(1, 6), (2, 5), np.full((2, 5), np.arange(5))),
229+
(
230+
np.arange(5),
231+
10,
232+
np.arange(1, 6),
233+
(2, 5),
234+
np.full((2, 5), np.arange(5)),
235+
),
228236
],
229237
)
230238
def test_studentt_moment(mu, nu, sigma, size, expected):
@@ -333,11 +341,44 @@ def test_gamma_moment(alpha, beta, size, expected):
333341
np.arange(1, 6),
334342
np.arange(2, 7),
335343
(2, 5),
336-
np.full((2, 5), np.arange(2, 7) * special.gamma(1 + 1 / np.arange(1, 6))),
344+
np.full(
345+
(2, 5),
346+
np.arange(2, 7) * special.gamma(1 + 1 / np.arange(1, 6)),
347+
),
337348
),
338349
],
339350
)
340351
def test_weibull_moment(alpha, beta, size, expected):
341352
with Model() as model:
342353
Weibull("x", alpha=alpha, beta=beta, size=size)
343354
assert_moment_is_expected(model, expected)
355+
356+
357+
@pytest.mark.parametrize(
358+
"n, p, size, expected",
359+
[
360+
(7, 0.7, None, 5),
361+
(7, 0.3, 5, np.full(5, 2)),
362+
(10, np.arange(1, 6) / 10, None, np.arange(1, 6)),
363+
(10, np.arange(1, 6) / 10, (2, 5), np.full((2, 5), np.arange(1, 6))),
364+
],
365+
)
366+
def test_binomial_moment(n, p, size, expected):
367+
with Model() as model:
368+
Binomial("x", n=n, p=p, size=size)
369+
assert_moment_is_expected(model, expected)
370+
371+
372+
@pytest.mark.parametrize(
373+
"mu, size, expected",
374+
[
375+
(2.7, None, 2),
376+
(2.3, 5, np.full(5, 2)),
377+
(np.arange(1, 5), None, np.arange(1, 5)),
378+
(np.arange(1, 5), (2, 4), np.full((2, 4), np.arange(1, 5))),
379+
],
380+
)
381+
def test_poisson_moment(mu, size, expected):
382+
with Model() as model:
383+
Poisson("x", mu=mu, size=size)
384+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)