diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index ac1598d759..88cdfc8690 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -1190,6 +1190,12 @@ def dist(cls, alpha=None, beta=None, mu=None, sigma=None, sd=None, *args, **kwar return super().dist([alpha, beta], **kwargs) + def get_moment(rv, size, alpha, beta): + mean = alpha / (alpha + beta) + if not rv_size_is_none(size): + mean = at.full(size, mean) + return mean + @classmethod def get_alpha_beta(self, alpha=None, beta=None, mu=None, sigma=None): if (alpha is not None) and (beta is not None): diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index 0648f4b8fe..26a76640ef 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -2,7 +2,7 @@ import pytest from pymc import Bernoulli, Flat, HalfFlat, Normal, TruncatedNormal, Uniform -from pymc.distributions import HalfNormal +from pymc.distributions import Beta, HalfNormal from pymc.distributions.shape_utils import rv_size_is_none from pymc.initial_point import make_initial_point_fn from pymc.model import Model @@ -142,3 +142,18 @@ def test_bernoulli_moment(p, size, expected): with Model() as model: Bernoulli("x", p=p, size=size) assert_moment_is_expected(model, expected) + + +@pytest.mark.parametrize( + "alpha, beta, size, expected", + [ + (1, 1, None, 0.5), + (1, 1, 5, np.full(5, 0.5)), + (1, np.arange(1, 6), None, 1 / np.arange(2, 7)), + (1, np.arange(1, 6), (2, 5), np.full((2, 5), 1 / np.arange(2, 7))), + ], +) +def test_beta_moment(alpha, beta, size, expected): + with Model() as model: + Beta("x", alpha=alpha, beta=beta, size=size) + assert_moment_is_expected(model, expected)