From 38ff47161c5479a5c409d7fdf7fa5735b873254a Mon Sep 17 00:00:00 2001 From: Sagar Tomar Date: Sun, 5 Dec 2021 14:23:49 +0530 Subject: [PATCH 1/3] Added check that nu must be a scalar in MvStudentTRV --- pymc/distributions/multivariate.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 746400c9a2..01d76feaa2 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -262,6 +262,13 @@ class MvStudentTRV(RandomVariable): dtype = "floatX" _print_name = ("MvStudentT", "\\operatorname{MvStudentT}") + def make_node(self, rng, size, dtype, nu, mu, cov): + nu = at.as_tensor_variable(floatX(nu)) + if not nu.ndim == 0: + raise ValueError("nu must be a scalar (ndim=0).") + + return super().make_node(rng, size, dtype, nu, mu, cov) + def __call__(self, nu, mu=None, cov=None, size=None, **kwargs): dtype = aesara.config.floatX if self.dtype == "floatX" else self.dtype From f724d45d2c6dcaf971431960e67bbbce59a72d05 Mon Sep 17 00:00:00 2001 From: Sagar Tomar Date: Sat, 11 Dec 2021 12:28:29 +0530 Subject: [PATCH 2/3] Corrected make_node in MvStudentTRV --- pymc/distributions/multivariate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 01d76feaa2..c972755df8 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -263,7 +263,7 @@ class MvStudentTRV(RandomVariable): _print_name = ("MvStudentT", "\\operatorname{MvStudentT}") def make_node(self, rng, size, dtype, nu, mu, cov): - nu = at.as_tensor_variable(floatX(nu)) + nu = at.as_tensor_variable(nu) if not nu.ndim == 0: raise ValueError("nu must be a scalar (ndim=0).") From 24316b2824b7b4e9c2c84c75fe9e42f2f05209e0 Mon Sep 17 00:00:00 2001 From: Sagar Tomar Date: Sat, 11 Dec 2021 19:54:13 +0530 Subject: [PATCH 3/3] Added test for checking exception is raised in case nu is not scalar in MvStudentT --- pymc/tests/test_distributions_random.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py index 0216de75fe..1e9788dd4e 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -13,6 +13,7 @@ # limitations under the License. import functools import itertools +import re from typing import Callable, List, Optional @@ -1115,8 +1116,20 @@ def mvstudentt_rng_fn(self, size, nu, mu, cov, rng): "check_pymc_params_match_rv_op", "check_pymc_draws_match_reference", "check_rv_size", + "test_errors", ] + def test_errors(self): + msg = "nu must be a scalar (ndim=0)." + with pm.Model(): + with pytest.raises(ValueError, match=re.escape(msg)): + mvstudentt = pm.MvStudentT( + "mvstudentt", + nu=np.array([1, 2]), + mu=np.ones(2), + cov=np.full((2, 2), np.ones(2)), + ) + class TestMvStudentTChol(BaseTestDistribution): pymc_dist = pm.MvStudentT