diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 70c50502f7..18bf9dcb1c 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -2107,10 +2107,9 @@ def dist(cls, beta, *args, **kwargs): return super().dist([0.0, beta], **kwargs) def get_moment(rv, size, loc, beta): - mean = beta if not rv_size_is_none(size): - mean = at.full(size, mean) - return mean + beta = at.full(size, beta) + return beta def logcdf(value, loc, beta): """ diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index 31f2dd7d49..1ea75ad4ec 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -819,6 +819,12 @@ def dist(cls, p, *args, **kwargs): p = at.as_tensor_variable(floatX(p)) return super().dist([p], *args, **kwargs) + def get_moment(rv, size, p): + mean = at.round(1.0 / p) + if not rv_size_is_none(size): + mean = at.full(size, mean) + return mean + def logp(value, p): r""" Calculate log-probability of Geometric distribution at specified value. diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index ca5337c372..f091e12c7f 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -3,16 +3,19 @@ from scipy import special -from pymc import Bernoulli, Flat, HalfFlat, Normal, TruncatedNormal, Uniform from pymc.distributions import ( + Bernoulli, Beta, Binomial, Cauchy, ChiSquared, Constant, Exponential, + Flat, Gamma, + Geometric, HalfCauchy, + HalfFlat, HalfNormal, HalfStudentT, Kumaraswamy, @@ -20,8 +23,11 @@ Logistic, LogNormal, NegativeBinomial, + Normal, Poisson, StudentT, + TruncatedNormal, + Uniform, Weibull, ZeroInflatedBinomial, ZeroInflatedPoisson, @@ -482,3 +488,18 @@ def test_logistic_moment(mu, s, size, expected): with Model() as model: Logistic("x", mu=mu, s=s, size=size) assert_moment_is_expected(model, expected) + + +@pytest.mark.parametrize( + "p, size, expected", + [ + (0.5, None, 2), + (0.2, 5, 5 * np.ones(5)), + (np.linspace(0.25, 1, 4), None, [4, 2, 1, 1]), + (np.linspace(0.25, 1, 4), (2, 4), np.full((2, 4), [4, 2, 1, 1])), + ], +) +def test_geometric_moment(p, size, expected): + with Model() as model: + Geometric("x", p=p, size=size) + assert_moment_is_expected(model, expected)