diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 08a848de6f..19d130905d 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -228,6 +228,13 @@ def dist(cls, mu, cov=None, tau=None, chol=None, lower=True, **kwargs): cov = quaddist_matrix(cov, chol, tau, lower) return super().dist([mu, cov], **kwargs) + def get_moment(rv, size, mu, cov): + moment = mu + if not rv_size_is_none(size): + moment_size = at.concatenate([size, mu.shape]) + moment = at.full(moment_size, mu) + return moment + def logp(value, mu, cov): """ Calculate log-probability of Multivariate Normal distribution diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index 13f3f78e9e..c2ec034af0 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -44,6 +44,7 @@ ZeroInflatedBinomial, ZeroInflatedPoisson, ) +from pymc.distributions.multivariate import MvNormal from pymc.distributions.shape_utils import rv_size_is_none from pymc.initial_point import make_initial_point_fn from pymc.model import Model @@ -751,6 +752,40 @@ def test_categorical_moment(p, size, expected): assert_moment_is_expected(model, expected) +@pytest.mark.parametrize( + "mu, cov, size, expected", + [ + (np.ones(1), np.identity(1), None, np.ones(1)), + (np.ones(3), np.identity(3), None, np.ones(3)), + (np.ones((2, 2)), np.identity(2), None, np.ones((2, 2))), + (np.array([1, 0, 3.0]), np.identity(3), None, np.array([1, 0, 3.0])), + (np.array([1, 0, 3.0]), np.identity(3), (4, 2), np.full((4, 2, 3), [1, 0, 3.0])), + ( + np.array([1, 3.0]), + np.identity(2), + 5, + np.full((5, 2), [1, 3.0]), + ), + ( + np.array([1, 3.0]), + np.array([[1.0, 0.5], [0.5, 2]]), + (4, 5), + np.full((4, 5, 2), [1, 3.0]), + ), + ( + np.array([[3.0, 5], [1, 4]]), + np.identity(2), + (4, 5), + np.full((4, 5, 2, 2), [[3.0, 5], [1, 4]]), + ), + ], +) +def test_mv_normal_moment(mu, cov, size, expected): + with Model() as model: + MvNormal("x", mu=mu, cov=cov, size=size) + assert_moment_is_expected(model, expected) + + @pytest.mark.parametrize( "mu, sigma, size, expected", [