diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index ef81c68281..70c50502f7 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -3388,6 +3388,12 @@ def dist(cls, mu=0.0, s=1.0, *args, **kwargs): s = at.as_tensor_variable(floatX(s)) return super().dist([mu, s], *args, **kwargs) + def get_moment(rv, size, mu, s): + mu, _ = at.broadcast_arrays(mu, s) + if not rv_size_is_none(size): + mu = at.full(size, mu) + return mu + def logcdf(value, mu, s): r""" Compute the log of the cumulative distribution function for Logistic distribution diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index 591b8decdf..9e2d56afdb 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -17,6 +17,7 @@ HalfStudentT, Kumaraswamy, Laplace, + Logistic, LogNormal, Poisson, StudentT, @@ -413,3 +414,23 @@ def test_constant_moment(c, size, expected): with Model() as model: Constant("x", c=c, size=size) assert_moment_is_expected(model, expected) + + +@pytest.mark.parametrize( + "mu, s, size, expected", + [ + (1, 1, None, 1), + (1, 1, 5, np.full(5, 1)), + (2, np.arange(1, 6), None, np.full(5, 2)), + ( + np.arange(1, 6), + np.arange(1, 6), + (2, 5), + np.full((2, 5), np.arange(1, 6)), + ), + ], +) +def test_logistic_moment(mu, s, size, expected): + with Model() as model: + Logistic("x", mu=mu, s=s, size=size) + assert_moment_is_expected(model, expected)