diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index d8a8c2f8fd..bb9ff075aa 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -1217,6 +1217,11 @@ def dist(cls, c, *args, **kwargs): c = at.as_tensor_variable(floatX(c)) return super().dist([c], **kwargs) + def get_moment(rv, size, c): + if not rv_size_is_none(size): + c = at.full(size, c) + return c + def logp(value, c): r""" Calculate log-probability of Constant distribution at specified value. diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index f03e1b6f9b..ae40b3bbbb 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -9,6 +9,7 @@ Binomial, Cauchy, ChiSquared, + Constant, Exponential, Gamma, HalfCauchy, @@ -382,3 +383,17 @@ def test_poisson_moment(mu, size, expected): with Model() as model: Poisson("x", mu=mu, size=size) assert_moment_is_expected(model, expected) + + +@pytest.mark.parametrize( + "c, size, expected", + [ + (1, None, 1), + (1, 5, np.full(5, 1)), + (np.arange(1, 6), None, np.arange(1, 6)), + ], +) +def test_constant_moment(c, size, expected): + with Model() as model: + Constant("x", c=c, size=size) + assert_moment_is_expected(model, expected)