diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index 4dbc1ea9b8..9c3bd37a91 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -525,6 +525,21 @@ def dist(cls, n, p, *args, **kwargs):
         return super().dist([n, p], *args, **kwargs)
 
+    def get_moment(rv, size, n, p):
+        if p.ndim > 1:
+            n = at.shape_padright(n)
+        if (p.ndim == 1) & (n.ndim > 0):
+            n = at.shape_padright(n)
+            p = at.shape_padleft(p)
+        mode = at.round(n * p)
+        diff = n - at.sum(mode, axis=-1, keepdims=True)
+        inc_bool_arr = at.abs_(diff) > 0
+        mode = at.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
+        if not rv_size_is_none(size):
+            output_size = at.concatenate([size, p.shape])
+            mode = at.full(output_size, mode)
+        return mode
+
     def logp(value, n, p):
         """
         Calculate log-probability of Multinomial distribution
diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py
index 9cd1ef5424..b19f24543b 100644
--- a/pymc/tests/test_distributions.py
+++ b/pymc/tests/test_distributions.py
@@ -2160,25 +2160,6 @@ def test_multinomial(self, n):
             Multinomial, Vector(Nat, n), {"p": Simplex(n), "n": Nat}, multinomial_logpdf
         )
 
-    @pytest.mark.skip(reason="Moment calculations have not been refactored yet")
-    @pytest.mark.parametrize(
-        "p,n",
-        [
-            [[0.25, 0.25, 0.25, 0.25], 1],
-            [[0.3, 0.6, 0.05, 0.05], 2],
-            [[0.3, 0.6, 0.05, 0.05], 10],
-        ],
-    )
-    def test_multinomial_mode(self, p, n):
-        _p = np.array(p)
-        with Model() as model:
-            m = Multinomial("m", n, _p, _p.shape)
-        assert_allclose(m.distribution.mode.eval().sum(), n)
-        _p = np.array([p, p])
-        with Model() as model:
-            m = Multinomial("m", n, _p, _p.shape)
-        assert_allclose(m.distribution.mode.eval().sum(axis=-1), n)
-
     @pytest.mark.parametrize(
         "p, size, n",
         [
@@ -2206,14 +2187,6 @@ def test_multinomial_random(self, p, size, n):
 
         assert m.eval().shape == size + p.shape
 
-    @pytest.mark.skip(reason="Moment calculations have not been refactored yet")
-    def test_multinomial_mode_with_shape(self):
-        n = [1, 10]
-        p = np.asarray([[0.25, 0.25, 0.25, 0.25], [0.26, 0.26, 0.26, 0.22]])
-        with Model() as model:
-            m = Multinomial("m", n=n, p=p, size=(2, 4))
-        assert_allclose(m.distribution.mode.eval().sum(axis=-1), n)
-
     def test_multinomial_vec(self):
         vals = np.array([[2, 4, 4], [3, 3, 4]])
         p = np.array([0.2, 0.3, 0.5])
diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py
index 8b851c0044..ad6e234305 100644
--- a/pymc/tests/test_distributions_moments.py
+++ b/pymc/tests/test_distributions_moments.py
@@ -41,6 +41,7 @@
     LogNormal,
     MatrixNormal,
     Moyal,
+    Multinomial,
     MvStudentT,
     NegativeBinomial,
     Normal,
@@ -1104,6 +1105,44 @@ def test_polyagamma_moment(h, z, size, expected):
     assert_moment_is_expected(model, expected)
 
 
+@pytest.mark.parametrize(
+    "p, n, size, expected",
+    [
+        (np.array([0.25, 0.25, 0.25, 0.25]), 1, None, np.array([1, 0, 0, 0])),
+        (np.array([0.3, 0.6, 0.05, 0.05]), 2, None, np.array([1, 1, 0, 0])),
+        (np.array([0.3, 0.6, 0.05, 0.05]), 10, None, np.array([4, 6, 0, 0])),
+        (
+            np.array([[0.3, 0.6, 0.05, 0.05], [0.25, 0.25, 0.25, 0.25]]),
+            10,
+            None,
+            np.array([[4, 6, 0, 0], [4, 2, 2, 2]]),
+        ),
+        (
+            np.array([[0.25, 0.25, 0.25, 0.25], [0.26, 0.26, 0.26, 0.22]]),
+            np.array([1, 10]),
+            None,
+            np.array([[1, 0, 0, 0], [2, 3, 3, 2]]),
+        ),
+        (
+            np.array([0.26, 0.26, 0.26, 0.22]),
+            np.array([1, 10]),
+            None,
+            np.array([[1, 0, 0, 0], [2, 3, 3, 2]]),
+        ),
+        (
+            np.array([[0.25, 0.25, 0.25, 0.25], [0.26, 0.26, 0.26, 0.22]]),
+            np.array([1, 10]),
+            2,
+            np.full((2, 2, 4), [[1, 0, 0, 0], [2, 3, 3, 2]]),
+        ),
+    ],
+)
+def test_multinomial_moment(p, n, size, expected):
+    with Model() as model:
+        Multinomial("x", n=n, p=p, size=size)
+    assert_moment_is_expected(model, expected)
+
+
 @pytest.mark.parametrize(
     "psi, mu, alpha, size, expected",
     [
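
For reference, below is a rough NumPy sketch (not part of the diff) of the mode computation that the new get_moment performs: round the expected counts n * p, then fold the rounding residual back into the first affected entry of each row so the counts still sum to n. The diff additionally handles a vector n paired with a 1-D p and tiles the result when an explicit size is given; those branches are omitted here, and the helper name multinomial_moment_sketch is made up for illustration.

import numpy as np

def multinomial_moment_sketch(n, p):
    # Hypothetical NumPy mirror of the Aesara logic in get_moment above.
    n = np.asarray(n)
    p = np.asarray(p, dtype=float)
    if p.ndim > 1:
        n = n[..., None]                             # broadcast n row-wise against a 2-D p
    mode = np.round(n * p)                           # candidate mode: rounded expected counts
    diff = n - mode.sum(axis=-1, keepdims=True)      # rounding residual per row
    inc = np.abs(diff) > 0
    mode[inc.nonzero()] += diff[inc.nonzero()]       # fold residual back so each row sums to n
    return mode

print(multinomial_moment_sketch(10, [0.3, 0.6, 0.05, 0.05]))
# [4. 6. 0. 0.]  -- matches the expected moment in the new test case
print(multinomial_moment_sketch([1, 10], [[0.25, 0.25, 0.25, 0.25], [0.26, 0.26, 0.26, 0.22]]))
# [[1. 0. 0. 0.]
#  [2. 3. 3. 2.]]  -- matches the vector-n test case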