diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 97ad4035e1..50f3f1f368 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -2924,6 +2924,12 @@ def dist(cls, mu=0.0, kappa=None, *args, **kwargs): assert_negative_support(kappa, "kappa", "VonMises") return super().dist([mu, kappa], *args, **kwargs) + def get_moment(rv, size, mu, kappa): + mu, _ = at.broadcast_arrays(mu, kappa) + if not rv_size_is_none(size): + mu = at.full(size, mu) + return mu + class SkewNormalRV(RandomVariable): name = "skewnormal" diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index ad6e234305..b5d1b0cfd9 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -54,6 +54,7 @@ Triangular, TruncatedNormal, Uniform, + VonMises, Wald, Weibull, ZeroInflatedBinomial, @@ -438,6 +439,21 @@ def test_pareto_moment(alpha, m, size, expected): assert_moment_is_expected(model, expected) +@pytest.mark.parametrize( + "mu, kappa, size, expected", + [ + (0, 1, None, 0), + (0, np.ones(4), None, np.zeros(4)), + (np.arange(4), 0.5, None, np.arange(4)), + (np.arange(4), np.arange(1, 5), (2, 4), np.full((2, 4), np.arange(4))), + ], +) +def test_vonmises_moment(mu, kappa, size, expected): + with Model() as model: + VonMises("x", mu=mu, kappa=kappa, size=size) + assert_moment_is_expected(model, expected) + + @pytest.mark.parametrize( "mu, lam, phi, size, expected", [