Skip to content

Commit 260586b

Browse files
committed
Add moments for CAR distribution
1 parent b6f76e5 commit 260586b

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

pymc/distributions/multivariate.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2115,6 +2115,13 @@ class CAR(Continuous):
21152115
def dist(cls, mu, W, alpha, tau, *args, **kwargs):
21162116
return super().dist([mu, W, alpha, tau], **kwargs)
21172117

2118+
def get_moment(rv, size, mu, W, alpha, tau):
2119+
moment = mu
2120+
if not rv_size_is_none(size):
2121+
moment_size = at.concatenate([size, moment.shape])
2122+
moment = at.full(moment_size, mu)
2123+
return moment
2124+
21182125
def logp(value, mu, W, alpha, tau):
21192126
"""
21202127
Calculate log-probability of a CAR-distributed vector

pymc/tests/test_distributions_moments.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pymc as pm
1010

1111
from pymc.distributions import (
12+
CAR,
1213
AsymmetricLaplace,
1314
Bernoulli,
1415
Beta,
@@ -109,7 +110,6 @@ def test_all_distributions_have_moments():
109110
# Distributions that have been refactored but don't yet have moments
110111
not_implemented |= {
111112
dist_module.discrete.DiscreteWeibull,
112-
dist_module.multivariate.CAR,
113113
dist_module.multivariate.DirichletMultinomial,
114114
dist_module.multivariate.Wishart,
115115
}
@@ -932,6 +932,34 @@ def test_mv_normal_moment(mu, cov, size, expected):
932932
assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3)
933933

934934

935+
@pytest.mark.parametrize(
936+
"mu, size, expected",
937+
[
938+
(
939+
np.array([1, 0, 3.0, 4]),
940+
None,
941+
np.array([1, 0, 3.0, 4]),
942+
),
943+
(np.array([1, 0, 3.0, 4]), 6, np.full((6, 4), [1, 0, 3.0, 4])),
944+
(np.array([1, 0, 3.0, 4]), (5, 3), np.full((5, 3, 4), [1, 0, 3.0, 4])),
945+
(
946+
np.array([[3.0, 5, 2, 1], [1, 4, 0.5, 9]]),
947+
(4, 5),
948+
np.full((4, 5, 2, 4), [[3.0, 5, 2, 1], [1, 4, 0.5, 9]]),
949+
),
950+
],
951+
)
952+
def test_car_moment(mu, size, expected):
953+
W = np.array(
954+
[[0.0, 1.0, 1.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]]
955+
)
956+
tau = 2
957+
alpha = 0.5
958+
with Model() as model:
959+
CAR("x", mu=mu, W=W, alpha=alpha, tau=tau, size=size)
960+
assert_moment_is_expected(model, expected)
961+
962+
935963
@pytest.mark.parametrize(
936964
"mu, sigma, size, expected",
937965
[

0 commit comments

Comments
 (0)