diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 8b42de7578..0980356432 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -2779,15 +2779,18 @@ def dist(cls, mu=0.0, sigma=None, nu=None, sd=None, *args, **kwargs): sigma = at.as_tensor_variable(floatX(sigma)) nu = at.as_tensor_variable(floatX(nu)) - # sd = sigma - # mean = mu + nu - # variance = (sigma ** 2) + (nu ** 2) - assert_negative_support(sigma, "sigma", "ExGaussian") assert_negative_support(nu, "nu", "ExGaussian") return super().dist([mu, sigma, nu], *args, **kwargs) + def get_moment(rv, size, mu, sigma, nu): + mu, nu, _ = at.broadcast_arrays(mu, nu, sigma) + moment = mu + nu + if not rv_size_is_none(size): + moment = at.full(size, moment) + return moment + def logp(value, mu, sigma, nu): """ Calculate log-probability of ExGaussian distribution at specified value. diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index ce9fec627f..f52ca8baa0 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -11,6 +11,7 @@ ChiSquared, Constant, DiscreteUniform, + ExGaussian, Exponential, Flat, Gamma, @@ -541,6 +542,22 @@ def test_logistic_moment(mu, s, size, expected): assert_moment_is_expected(model, expected) +@pytest.mark.parametrize( + "mu, nu, sigma, size, expected", + [ + (1, 1, None, None, 2), + (1, 1, np.ones((2, 5)), None, np.full([2, 5], 2)), + (1, 1, None, 5, np.full(5, 2)), + (1, np.arange(1, 6), None, None, np.arange(2, 7)), + (1, np.arange(1, 6), None, (2, 5), np.full((2, 5), np.arange(2, 7))), + ], +) +def test_exgaussian_moment(mu, nu, sigma, size, expected): + with Model() as model: + ExGaussian("x", mu=mu, sigma=sigma, nu=nu, size=size) + assert_moment_is_expected(model, expected) + + @pytest.mark.parametrize( "p, size, expected", [