Skip to content

Commit f24a1df

Browse files
Discrete uniform and hyper geometric moment (#5167)
Co-authored-by: Farhan Reynaldo <[email protected]>
1 parent 7745f55 commit f24a1df

File tree

2 files changed

+69
-2
lines changed

2 files changed

+69
-2
lines changed

pymc/distributions/discrete.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -926,6 +926,13 @@ def dist(cls, N, k, n, *args, **kwargs):
926926
n = at.as_tensor_variable(intX(n))
927927
return super().dist([good, bad, n], *args, **kwargs)
928928

929+
def get_moment(rv, size, good, bad, n):
930+
N, k = good + bad, good
931+
mode = at.floor((n + 1) * (k + 1) / (N + 2))
932+
if not rv_size_is_none(size):
933+
mode = at.full(size, mode)
934+
return mode
935+
929936
def logp(value, good, bad, n):
930937
r"""
931938
Calculate log-probability of HyperGeometric distribution at specified value.
@@ -1060,6 +1067,12 @@ def dist(cls, lower, upper, *args, **kwargs):
10601067
upper = intX(at.floor(upper))
10611068
return super().dist([lower, upper], **kwargs)
10621069

1070+
def get_moment(rv, size, lower, upper):
1071+
mode = at.maximum(at.floor((upper + lower) / 2.0), lower)
1072+
if not rv_size_is_none(size):
1073+
mode = at.full(size, mode)
1074+
return mode
1075+
10631076
def logp(value, lower, upper):
10641077
r"""
10651078
Calculate log-probability of DiscreteUniform distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Cauchy,
1111
ChiSquared,
1212
Constant,
13+
DiscreteUniform,
1314
Exponential,
1415
Flat,
1516
Gamma,
@@ -18,6 +19,7 @@
1819
HalfFlat,
1920
HalfNormal,
2021
HalfStudentT,
22+
HyperGeometric,
2123
Kumaraswamy,
2224
Laplace,
2325
Logistic,
@@ -417,7 +419,12 @@ def test_poisson_moment(mu, size, expected):
417419
(10, 0.7, None, 4),
418420
(10, 0.7, 5, np.full(5, 4)),
419421
(np.full(3, 10), np.arange(1, 4) / 10, None, np.array([90, 40, 23])),
420-
(10, np.arange(1, 4) / 10, (2, 3), np.full((2, 3), np.array([90, 40, 23]))),
422+
(
423+
10,
424+
np.arange(1, 4) / 10,
425+
(2, 3),
426+
np.full((2, 3), np.array([90, 40, 23])),
427+
),
421428
],
422429
)
423430
def test_negative_binomial_moment(n, p, size, expected):
@@ -461,7 +468,13 @@ def test_zero_inflated_poisson_moment(psi, theta, size, expected):
461468
(0.2, 7, 0.7, None, 4),
462469
(0.2, 7, 0.3, 5, np.full(5, 2)),
463470
(0.6, 25, np.arange(1, 6) / 10, None, np.arange(1, 6)),
464-
(0.6, 25, np.arange(1, 6) / 10, (2, 5), np.full((2, 5), np.arange(1, 6))),
471+
(
472+
0.6,
473+
25,
474+
np.arange(1, 6) / 10,
475+
(2, 5),
476+
np.full((2, 5), np.arange(1, 6)),
477+
),
465478
],
466479
)
467480
def test_zero_inflated_binomial_moment(psi, n, p, size, expected):
@@ -503,3 +516,44 @@ def test_geometric_moment(p, size, expected):
503516
with Model() as model:
504517
Geometric("x", p=p, size=size)
505518
assert_moment_is_expected(model, expected)
519+
520+
521+
@pytest.mark.parametrize(
522+
"N, k, n, size, expected",
523+
[
524+
(50, 10, 20, None, 4),
525+
(50, 10, 23, 5, np.full(5, 5)),
526+
(50, 10, np.arange(23, 28), None, np.full(5, 5)),
527+
(
528+
50,
529+
10,
530+
np.arange(18, 23),
531+
(2, 5),
532+
np.full((2, 5), 4),
533+
),
534+
],
535+
)
536+
def test_hyper_geometric_moment(N, k, n, size, expected):
537+
with Model() as model:
538+
HyperGeometric("x", N=N, k=k, n=n, size=size)
539+
assert_moment_is_expected(model, expected)
540+
541+
542+
@pytest.mark.parametrize(
543+
"lower, upper, size, expected",
544+
[
545+
(1, 5, None, 3),
546+
(1, 5, 5, np.full(5, 3)),
547+
(1, np.arange(5, 22, 4), None, np.arange(3, 13, 2)),
548+
(
549+
1,
550+
np.arange(5, 22, 4),
551+
(2, 5),
552+
np.full((2, 5), np.arange(3, 13, 2)),
553+
),
554+
],
555+
)
556+
def test_discrete_uniform_moment(lower, upper, size, expected):
557+
with Model() as model:
558+
DiscreteUniform("x", lower=lower, upper=upper, size=size)
559+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)