|
9 | 9 | import pymc as pm
|
10 | 10 |
|
11 | 11 | from pymc.distributions import (
|
| 12 | + CAR, |
12 | 13 | AsymmetricLaplace,
|
13 | 14 | Bernoulli,
|
14 | 15 | Beta,
|
@@ -109,7 +110,6 @@ def test_all_distributions_have_moments():
|
109 | 110 | # Distributions that have been refactored but don't yet have moments
|
110 | 111 | not_implemented |= {
|
111 | 112 | dist_module.discrete.DiscreteWeibull,
|
112 |
| - dist_module.multivariate.CAR, |
113 | 113 | dist_module.multivariate.DirichletMultinomial,
|
114 | 114 | dist_module.multivariate.Wishart,
|
115 | 115 | }
|
@@ -932,6 +932,34 @@ def test_mv_normal_moment(mu, cov, size, expected):
|
932 | 932 | assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3)
|
933 | 933 |
|
934 | 934 |
|
| 935 | +@pytest.mark.parametrize( |
| 936 | + "mu, size, expected", |
| 937 | + [ |
| 938 | + ( |
| 939 | + np.array([1, 0, 3.0, 4]), |
| 940 | + None, |
| 941 | + np.array([1, 0, 3.0, 4]), |
| 942 | + ), |
| 943 | + (np.array([1, 0, 3.0, 4]), 6, np.full((6, 4), [1, 0, 3.0, 4])), |
| 944 | + (np.array([1, 0, 3.0, 4]), (5, 3), np.full((5, 3, 4), [1, 0, 3.0, 4])), |
| 945 | + ( |
| 946 | + np.array([[3.0, 5, 2, 1], [1, 4, 0.5, 9]]), |
| 947 | + (4, 5), |
| 948 | + np.full((4, 5, 2, 4), [[3.0, 5, 2, 1], [1, 4, 0.5, 9]]), |
| 949 | + ), |
| 950 | + ], |
| 951 | +) |
| 952 | +def test_car_moment(mu, size, expected): |
| 953 | + W = np.array( |
| 954 | + [[0.0, 1.0, 1.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]] |
| 955 | + ) |
| 956 | + tau = 2 |
| 957 | + alpha = 0.5 |
| 958 | + with Model() as model: |
| 959 | + CAR("x", mu=mu, W=W, alpha=alpha, tau=tau, size=size) |
| 960 | + assert_moment_is_expected(model, expected) |
| 961 | + |
| 962 | + |
935 | 963 | @pytest.mark.parametrize(
|
936 | 964 | "mu, sigma, size, expected",
|
937 | 965 | [
|
|
0 commit comments