Skip to content

Commit 0d86bdd

Browse files
committed
Update MatrixNormal shape_inference and remove shape and size restrictions
1 parent 0dfe7f9 commit 0d86bdd

File tree

3 files changed

+13
-54
lines changed

3 files changed

+13
-54
lines changed

pymc/distributions/multivariate.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1616,7 +1616,7 @@ class MatrixNormalRV(RandomVariable):
16161616
_print_name = ("MatrixNormal", "\\operatorname{MatrixNormal}")
16171617

16181618
def _infer_shape(self, size, dist_params, param_shapes=None):
1619-
shape = tuple(size) + tuple(dist_params[0].shape)
1619+
shape = tuple(size) + tuple(dist_params[0].shape[-2:])
16201620
return shape
16211621

16221622
@classmethod
@@ -1746,18 +1746,6 @@ def dist(
17461746

17471747
cholesky = Cholesky(lower=True, on_error="raise")
17481748

1749-
if kwargs.get("size", None) is not None:
1750-
raise NotImplementedError("MatrixNormal doesn't support size argument")
1751-
1752-
if "shape" in kwargs:
1753-
kwargs.pop("shape")
1754-
warnings.warn(
1755-
"The shape argument in MatrixNormal is deprecated and will be ignored."
1756-
"MatrixNormal automatically derives the shape"
1757-
"from row and column matrix dimensions.",
1758-
FutureWarning,
1759-
)
1760-
17611749
# Among-row matrices
17621750
if len([i for i in [rowcov, rowchol] if i is not None]) != 1:
17631751
raise ValueError(
@@ -1787,22 +1775,16 @@ def dist(
17871775
raise ValueError("colchol must be two dimensional.")
17881776
colchol_cov = at.as_tensor_variable(colchol)
17891777

1790-
dist_shape = (rowchol_cov.shape[0], colchol_cov.shape[0])
1778+
dist_shape = (rowchol_cov.shape[-1], colchol_cov.shape[-1])
17911779

17921780
# Broadcasting mu
17931781
mu = at.extra_ops.broadcast_to(mu, shape=dist_shape)
1794-
17951782
mu = at.as_tensor_variable(floatX(mu))
1796-
# mean = median = mode = mu
17971783

17981784
return super().dist([mu, rowchol_cov, colchol_cov], **kwargs)
17991785

18001786
def get_moment(rv, size, mu, rowchol, colchol):
1801-
output_shape = (rowchol.shape[0], colchol.shape[0])
1802-
if not rv_size_is_none(size):
1803-
output_shape = at.concatenate([size, output_shape])
1804-
moment = at.full(output_shape, mu)
1805-
return moment
1787+
return at.full_like(rv, mu)
18061788

18071789
def logp(value, mu, rowchol, colchol):
18081790
"""

pymc/tests/test_distributions_moments.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,11 +1020,6 @@ def test_mvstudentt_moment(nu, mu, cov, size, expected):
10201020
assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3)
10211021

10221022

1023-
def check_matrixnormal_moment(mu, rowchol, colchol, size, expected):
1024-
with Model() as model:
1025-
MatrixNormal("x", mu=mu, rowchol=rowchol, colchol=colchol, size=size)
1026-
1027-
10281023
@pytest.mark.parametrize(
10291024
"alpha, mu, sigma, size, expected",
10301025
[
@@ -1092,11 +1087,12 @@ def test_asymmetriclaplace_moment(b, kappa, mu, size, expected):
10921087
],
10931088
)
10941089
def test_matrixnormal_moment(mu, rowchol, colchol, size, expected):
1095-
if size is None:
1096-
check_matrixnormal_moment(mu, rowchol, colchol, size, expected)
1097-
else:
1098-
with pytest.raises(NotImplementedError):
1099-
check_matrixnormal_moment(mu, rowchol, colchol, size, expected)
1090+
with Model() as model:
1091+
x = MatrixNormal("x", mu=mu, rowchol=rowchol, colchol=colchol, size=size)
1092+
1093+
# MatrixNormal logp is only implemented for 2d values
1094+
check_logp = x.ndim == 2
1095+
assert_moment_is_expected(model, expected, check_finite_logp=check_logp)
11001096

11011097

11021098
@pytest.mark.parametrize(

pymc/tests/test_distributions_random.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1811,13 +1811,15 @@ class TestMatrixNormal(BaseTestDistributionRandom):
18111811
mu = np.random.random((3, 3))
18121812
row_cov = np.eye(3)
18131813
col_cov = np.eye(3)
1814-
shape = None
1815-
size = None
18161814
pymc_dist_params = {"mu": mu, "rowcov": row_cov, "colcov": col_cov}
18171815
expected_rv_op_params = {"mu": mu, "rowcov": row_cov, "colcov": col_cov}
18181816

1817+
sizes_to_check = (None, (1,), (2, 4))
1818+
sizes_expected = [(3, 3), (1, 3, 3), (2, 4, 3, 3)]
1819+
18191820
checks_to_run = [
18201821
"check_pymc_params_match_rv_op",
1822+
"check_rv_size",
18211823
"check_draws",
18221824
"check_errors",
18231825
"check_random_variable_prior",
@@ -1858,17 +1860,6 @@ def ref_rand(mu, rowcov, colcov):
18581860
assert p > delta
18591861

18601862
def check_errors(self):
1861-
msg = "MatrixNormal doesn't support size argument"
1862-
with pm.Model():
1863-
with pytest.raises(NotImplementedError, match=msg):
1864-
matrixnormal = pm.MatrixNormal(
1865-
"matnormal",
1866-
mu=np.random.random((3, 3)),
1867-
rowcov=np.eye(3),
1868-
colcov=np.eye(3),
1869-
size=15,
1870-
)
1871-
18721863
with pm.Model():
18731864
matrixnormal = pm.MatrixNormal(
18741865
"matnormal",
@@ -1879,16 +1870,6 @@ def check_errors(self):
18791870
with pytest.raises(ValueError):
18801871
logp(matrixnormal, aesara.tensor.ones((3, 3, 3)))
18811872

1882-
with pm.Model():
1883-
with pytest.warns(FutureWarning):
1884-
matrixnormal = pm.MatrixNormal(
1885-
"matnormal",
1886-
mu=np.random.random((3, 3)),
1887-
rowcov=np.eye(3),
1888-
colcov=np.eye(3),
1889-
shape=15,
1890-
)
1891-
18921873
def check_random_variable_prior(self):
18931874
"""
18941875
This test checks for shape correctness when using MatrixNormal distribution

0 commit comments

Comments
 (0)