Skip to content

Commit 93c2293

Browse files
Add Constant Moment (#5156)
Co-authored-by: shivam15s <[email protected]>
1 parent 741f207 commit 93c2293

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

pymc/distributions/discrete.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,6 +1217,11 @@ def dist(cls, c, *args, **kwargs):
12171217
c = at.as_tensor_variable(floatX(c))
12181218
return super().dist([c], **kwargs)
12191219

1220+
def get_moment(rv, size, c):
1221+
if not rv_size_is_none(size):
1222+
c = at.full(size, c)
1223+
return c
1224+
12201225
def logp(value, c):
12211226
r"""
12221227
Calculate log-probability of Constant distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
Binomial,
1010
Cauchy,
1111
ChiSquared,
12+
Constant,
1213
Exponential,
1314
Gamma,
1415
HalfCauchy,
@@ -398,3 +399,17 @@ def test_poisson_moment(mu, size, expected):
398399
with Model() as model:
399400
Poisson("x", mu=mu, size=size)
400401
assert_moment_is_expected(model, expected)
402+
403+
404+
@pytest.mark.parametrize(
405+
"c, size, expected",
406+
[
407+
(1, None, 1),
408+
(1, 5, np.full(5, 1)),
409+
(np.arange(1, 6), None, np.arange(1, 6)),
410+
],
411+
)
412+
def test_constant_moment(c, size, expected):
413+
with Model() as model:
414+
Constant("x", c=c, size=size)
415+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)