From ed36754497e6c7a03d66f061ba9fd003762ba5c7 Mon Sep 17 00:00:00 2001 From: michaeloriordan Date: Fri, 5 Nov 2021 21:10:51 +0000 Subject: [PATCH 1/4] Add Beta mean and tests --- pymc/distributions/continuous.py | 7 +++++++ pymc/tests/test_distributions_moments.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index ac1598d759..66755aeb32 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -1190,6 +1190,13 @@ def dist(cls, alpha=None, beta=None, mu=None, sigma=None, sd=None, *args, **kwar return super().dist([alpha, beta], **kwargs) + def get_moment(rv, size, alpha, beta): + alpha, beta = at.broadcast_arrays(alpha, beta) + mean = alpha / (alpha + beta) + if not rv_size_is_none(size): + mean = at.full(size, mean) + return mean + @classmethod def get_alpha_beta(self, alpha=None, beta=None, mu=None, sigma=None): if (alpha is not None) and (beta is not None): diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index 0648f4b8fe..6b5566e3f9 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -2,7 +2,7 @@ import pytest from pymc import Bernoulli, Flat, HalfFlat, Normal, TruncatedNormal, Uniform -from pymc.distributions import HalfNormal +from pymc.distributions import Beta, HalfNormal from pymc.distributions.shape_utils import rv_size_is_none from pymc.initial_point import make_initial_point_fn from pymc.model import Model From 09b03fd7585d32a14dfbdb18edf7fa8307b6f309 Mon Sep 17 00:00:00 2001 From: michaeloriordan Date: Fri, 5 Nov 2021 21:17:31 +0000 Subject: [PATCH 2/4] Test beta moment --- pymc/tests/test_distributions_moments.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index 6b5566e3f9..37a69aebad 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -142,3 +142,19 @@ def test_bernoulli_moment(p, size, expected): with Model() as model: Bernoulli("x", p=p, size=size) assert_moment_is_expected(model, expected) + + +@pytest.mark.parametrize( + "alpha, beta, size, expected", + [ + (1, 1, None, 0.5), + (1, 1, 5, np.full(5, 0.5)), + (1, np.arange(1, 6), None, 1 / np.arange(2, 7)), + (1, np.arange(1, 6), (2, 5), np.full((2, 5), 1 / np.arange(2, 7))), + ], +) +def test_beta_moment(alpha, beta, size, expected): + with Model() as model: + Beta("x", alpha=alpha, beta=beta, size=size) + assert_moment_is_expected(model, expected) + From c5f2fe70c2a11851ab066e2ce4887330a25bb6cd Mon Sep 17 00:00:00 2001 From: Michael O' Riordan Date: Fri, 5 Nov 2021 21:55:19 +0000 Subject: [PATCH 3/4] Remove unnecessary line Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc/distributions/continuous.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 66755aeb32..88cdfc8690 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -1191,7 +1191,6 @@ def dist(cls, alpha=None, beta=None, mu=None, sigma=None, sd=None, *args, **kwar return super().dist([alpha, beta], **kwargs) def get_moment(rv, size, alpha, beta): - alpha, beta = at.broadcast_arrays(alpha, beta) mean = alpha / (alpha + beta) if not rv_size_is_none(size): mean = at.full(size, mean) From d227d76530ad8996eadbcf49d47bb3d0294f81e6 Mon Sep 17 00:00:00 2001 From: michaeloriordan Date: Fri, 5 Nov 2021 22:18:12 +0000 Subject: [PATCH 4/4] Pre-commit fix end of file --- pymc/tests/test_distributions_moments.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index 37a69aebad..26a76640ef 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -157,4 +157,3 @@ def test_beta_moment(alpha, beta, size, expected): with Model() as model: Beta("x", alpha=alpha, beta=beta, size=size) assert_moment_is_expected(model, expected) -