Skip to content

Commit 75c52fc

Browse files
add categorical moment
1 parent 275c145 commit 75c52fc

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

pymc/distributions/discrete.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,13 +1161,14 @@ class Categorical(Discrete):
11611161
def dist(cls, p, **kwargs):
11621162

11631163
p = at.as_tensor_variable(floatX(p))
1164-
1165-
# mode = at.argmax(p, axis=-1)
1166-
# if mode.ndim == 1:
1167-
# mode = at.squeeze(mode)
1168-
11691164
return super().dist([p], **kwargs)
11701165

1166+
def get_moment(rv, size, p):
1167+
mode = at.argmax(p, axis=-1)
1168+
if not rv_size_is_none(size):
1169+
mode = at.full(size, mode)
1170+
return mode
1171+
11711172
def logp(value, p):
11721173
r"""
11731174
Calculate log-probability of Categorical distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Bernoulli,
88
Beta,
99
Binomial,
10+
Categorical,
1011
Cauchy,
1112
ChiSquared,
1213
Constant,
@@ -612,3 +613,22 @@ def test_discrete_uniform_moment(lower, upper, size, expected):
612613
with Model() as model:
613614
DiscreteUniform("x", lower=lower, upper=upper, size=size)
614615
assert_moment_is_expected(model, expected)
616+
617+
618+
@pytest.mark.parametrize(
619+
"p, size, expected",
620+
[
621+
(np.arange(0.1, 0.4, 0.1), None, 3),
622+
(np.arange(0.1, 0.4, 0.1), 5, np.full(5, 3)),
623+
(np.full((2, 4), np.arange(0.1, 0.4, 0.1)), None, [3, 3]),
624+
(
625+
np.full((2, 4), np.arange(0.1, 0.4, 0.1)),
626+
(3, 2),
627+
np.full((3, 2), [3, 3]),
628+
),
629+
],
630+
)
631+
def test_categorical_moment(p, size, expected):
632+
with Model() as model:
633+
Categorical("x", p=p, size=size)
634+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)