diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 746400c9a2..c972755df8 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -262,6 +262,13 @@ class MvStudentTRV(RandomVariable): dtype = "floatX" _print_name = ("MvStudentT", "\\operatorname{MvStudentT}") + def make_node(self, rng, size, dtype, nu, mu, cov): + nu = at.as_tensor_variable(nu) + if not nu.ndim == 0: + raise ValueError("nu must be a scalar (ndim=0).") + + return super().make_node(rng, size, dtype, nu, mu, cov) + def __call__(self, nu, mu=None, cov=None, size=None, **kwargs): dtype = aesara.config.floatX if self.dtype == "floatX" else self.dtype diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py index 0216de75fe..1e9788dd4e 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -13,6 +13,7 @@ # limitations under the License. import functools import itertools +import re from typing import Callable, List, Optional @@ -1115,8 +1116,20 @@ def mvstudentt_rng_fn(self, size, nu, mu, cov, rng): "check_pymc_params_match_rv_op", "check_pymc_draws_match_reference", "check_rv_size", + "test_errors", ] + def test_errors(self): + msg = "nu must be a scalar (ndim=0)." + with pm.Model(): + with pytest.raises(ValueError, match=re.escape(msg)): + mvstudentt = pm.MvStudentT( + "mvstudentt", + nu=np.array([1, 2]), + mu=np.ones(2), + cov=np.full((2, 2), np.ones(2)), + ) + class TestMvStudentTChol(BaseTestDistribution): pymc_dist = pm.MvStudentT