Skip to content

Commit 0dfe7f9

Browse files
committed
Fix Wishart shape inference when there are batched parameters
1 parent 7bb32cf commit 0dfe7f9

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

pymc/distributions/multivariate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -905,7 +905,7 @@ class WishartRV(RandomVariable):
905905

906906
def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
907907
# The shape of second parameter `V` defines the shape of the output.
908-
return dist_params[1].shape
908+
return dist_params[1].shape[-2:]
909909

910910
@classmethod
911911
def rng_fn(cls, rng, nu, V, size):

pymc/tests/test_distributions_random.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1784,8 +1784,25 @@ def wishart_rng_fn(self, size, nu, V, rng):
17841784
"check_rv_size",
17851785
"check_pymc_params_match_rv_op",
17861786
"check_pymc_draws_match_reference",
1787+
"check_rv_size_batched_params",
17871788
]
17881789

1790+
def check_rv_size_batched_params(self):
1791+
for size in (None, (2,), (1, 2), (4, 3, 2)):
1792+
x = pm.Wishart.dist(nu=4, V=np.stack([np.eye(3), np.eye(3)]), size=size)
1793+
1794+
if size is None:
1795+
expected_shape = (2, 3, 3)
1796+
else:
1797+
expected_shape = size + (3, 3)
1798+
1799+
assert tuple(x.shape.eval()) == expected_shape
1800+
1801+
# RNG does not currently support batched parameters, whet it does this test
1802+
# should be updated to check that draws also have the expected shape
1803+
with pytest.raises(ValueError):
1804+
x.eval()
1805+
17891806

17901807
class TestMatrixNormal(BaseTestDistributionRandom):
17911808

0 commit comments

Comments
 (0)