diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index f0f031ff02..96a91a336a 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2115,6 +2115,13 @@ class CAR(Continuous): def dist(cls, mu, W, alpha, tau, *args, **kwargs): return super().dist([mu, W, alpha, tau], **kwargs) + def get_moment(rv, size, mu, W, alpha, tau): + moment = mu + if not rv_size_is_none(size): + moment_size = at.concatenate([size, moment.shape]) + moment = at.full(moment_size, mu) + return moment + def logp(value, mu, W, alpha, tau): """ Calculate log-probability of a CAR-distributed vector diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index 4a6361504e..4471b5c5eb 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -9,6 +9,7 @@ import pymc as pm from pymc.distributions import ( + CAR, AsymmetricLaplace, Bernoulli, Beta, @@ -109,7 +110,6 @@ def test_all_distributions_have_moments(): # Distributions that have been refactored but don't yet have moments not_implemented |= { dist_module.discrete.DiscreteWeibull, - dist_module.multivariate.CAR, dist_module.multivariate.DirichletMultinomial, dist_module.multivariate.Wishart, } @@ -932,6 +932,34 @@ def test_mv_normal_moment(mu, cov, size, expected): assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3) +@pytest.mark.parametrize( + "mu, size, expected", + [ + ( + np.array([1, 0, 3.0, 4]), + None, + np.array([1, 0, 3.0, 4]), + ), + (np.array([1, 0, 3.0, 4]), 6, np.full((6, 4), [1, 0, 3.0, 4])), + (np.array([1, 0, 3.0, 4]), (5, 3), np.full((5, 3, 4), [1, 0, 3.0, 4])), + ( + np.array([[3.0, 5, 2, 1], [1, 4, 0.5, 9]]), + (4, 5), + np.full((4, 5, 2, 4), [[3.0, 5, 2, 1], [1, 4, 0.5, 9]]), + ), + ], +) +def test_car_moment(mu, size, expected): + W = np.array( + [[0.0, 1.0, 1.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]] + ) + tau = 2 + alpha = 0.5 + with Model() as model: + CAR("x", mu=mu, W=W, alpha=alpha, tau=tau, size=size) + assert_moment_is_expected(model, expected) + + @pytest.mark.parametrize( "mu, sigma, size, expected", [