diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ee05e2f582..d25c91ad85 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -146,7 +146,7 @@ docker exec -it pymc jupyter notebook list ## Style guide We have configured a pre-commit hook that checks for `black`-compliant code style. -We encourage you to configure the pre-commit hook as described in the [PyMC Python Code Style Wiki Page](https://github.com/pymc-devs/pymc/wiki/PyMC-Python-Code-Style), because it will automatically enforce the code style on your commits. +We encourage you to configure the pre-commit hook as described in the [PyMC Python Code Style Wiki Page](https://github.com/pymc-devs/pymc/wiki/Python-Code-Style), because it will automatically enforce the code style on your commits. Similarly, consult the [PyMC's Jupyter Notebook Style](https://github.com/pymc-devs/pymc/wiki/PyMC-Jupyter-Notebook-Style-Guide) guides for notebooks. diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index cfa372e037..08a848de6f 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -46,7 +46,11 @@ from pymc.distributions.continuous import ChiSquared, Normal, assert_negative_support from pymc.distributions.dist_math import bound, factln, logpow, multigammaln from pymc.distributions.distribution import Continuous, Discrete -from pymc.distributions.shape_utils import broadcast_dist_samples_to, to_tuple +from pymc.distributions.shape_utils import ( + broadcast_dist_samples_to, + rv_size_is_none, + to_tuple, +) from pymc.math import kron_diag, kron_dot __all__ = [ @@ -405,6 +409,15 @@ def dist(cls, a, **kwargs): return super().dist([a], **kwargs) + def get_moment(rv, size, a): + norm_constant = at.sum(a, axis=-1)[..., None] + moment = a / norm_constant + if not rv_size_is_none(size): + if isinstance(size, int): + size = (size,) + moment = at.full((*size, *a.shape), moment) + return moment + def logp(value, a): """ Calculate log-probability of Dirichlet distribution diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index f52ca8baa0..bbe478f8ac 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -10,6 +10,7 @@ Cauchy, ChiSquared, Constant, + Dirichlet, DiscreteUniform, ExGaussian, Exponential, @@ -611,4 +612,41 @@ def test_hyper_geometric_moment(N, k, n, size, expected): def test_discrete_uniform_moment(lower, upper, size, expected): with Model() as model: DiscreteUniform("x", lower=lower, upper=upper, size=size) + + +@pytest.mark.parametrize( + "a, size, expected", + [ + ( + np.array([2, 3, 5, 7, 11]), + None, + np.array([2, 3, 5, 7, 11]) / 28, + ), + ( + np.array([[1, 2, 3], [5, 6, 7]]), + None, + np.array([[1, 2, 3], [5, 6, 7]]) / np.array([6, 18])[..., np.newaxis], + ), + ( + np.array([[1, 2, 3], [5, 6, 7]]), + 7, + np.apply_along_axis( + lambda x: np.divide(x, np.array([6, 18])), + 1, + np.broadcast_to([[1, 2, 3], [5, 6, 7]], shape=[7, 2, 3]), + ), + ), + ( + np.full(shape=np.array([7, 3]), fill_value=np.array([13, 17, 19])), + ( + 11, + 5, + ), + np.broadcast_to([13, 17, 19], shape=[11, 5, 7, 3]) / 49, + ), + ], +) +def test_dirichlet_moment(a, size, expected): + with Model() as model: + Dirichlet("x", a=a, size=size) assert_moment_is_expected(model, expected)