Skip to content

Commit e0592ec

Browse files
zoj613ricardoV94
authored andcommitted
Add moments for CAR distribution
1 parent d295f3b commit e0592ec

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

pymc/distributions/multivariate.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2130,6 +2130,13 @@ class CAR(Continuous):
21302130
def dist(cls, mu, W, alpha, tau, *args, **kwargs):
21312131
return super().dist([mu, W, alpha, tau], **kwargs)
21322132

2133+
def get_moment(rv, size, mu, W, alpha, tau):
2134+
moment = mu
2135+
if not rv_size_is_none(size):
2136+
moment_size = at.concatenate([size, moment.shape])
2137+
moment = at.full(moment_size, mu)
2138+
return moment
2139+
21332140
def logp(value, mu, W, alpha, tau):
21342141
"""
21352142
Calculate log-probability of a CAR-distributed vector

pymc/tests/test_distributions_moments.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pymc as pm
1010

1111
from pymc.distributions import (
12+
CAR,
1213
AsymmetricLaplace,
1314
Bernoulli,
1415
Beta,
@@ -110,7 +111,6 @@ def test_all_distributions_have_moments():
110111
# Distributions that have been refactored but don't yet have moments
111112
not_implemented |= {
112113
dist_module.discrete.DiscreteWeibull,
113-
dist_module.multivariate.CAR,
114114
dist_module.multivariate.DirichletMultinomial,
115115
dist_module.multivariate.Wishart,
116116
}
@@ -933,6 +933,34 @@ def test_mv_normal_moment(mu, cov, size, expected):
933933
assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3)
934934

935935

936+
@pytest.mark.parametrize(
937+
"mu, size, expected",
938+
[
939+
(
940+
np.array([1, 0, 3.0, 4]),
941+
None,
942+
np.array([1, 0, 3.0, 4]),
943+
),
944+
(np.array([1, 0, 3.0, 4]), 6, np.full((6, 4), [1, 0, 3.0, 4])),
945+
(np.array([1, 0, 3.0, 4]), (5, 3), np.full((5, 3, 4), [1, 0, 3.0, 4])),
946+
(
947+
np.array([[3.0, 5, 2, 1], [1, 4, 0.5, 9]]),
948+
(4, 5),
949+
np.full((4, 5, 2, 4), [[3.0, 5, 2, 1], [1, 4, 0.5, 9]]),
950+
),
951+
],
952+
)
953+
def test_car_moment(mu, size, expected):
954+
W = np.array(
955+
[[0.0, 1.0, 1.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]]
956+
)
957+
tau = 2
958+
alpha = 0.5
959+
with Model() as model:
960+
CAR("x", mu=mu, W=W, alpha=alpha, tau=tau, size=size)
961+
assert_moment_is_expected(model, expected)
962+
963+
936964
@pytest.mark.parametrize(
937965
"mu, sigma, size, expected",
938966
[

0 commit comments

Comments
 (0)