Skip to content

Commit 8d1708a

Browse files
Adding Dirichlet moment and tests (#5174)
1 parent 7ec5ca1 commit 8d1708a

File tree

3 files changed

+53
-2
lines changed

3 files changed

+53
-2
lines changed

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ docker exec -it pymc jupyter notebook list
146146
## Style guide
147147

148148
We have configured a pre-commit hook that checks for `black`-compliant code style.
149-
We encourage you to configure the pre-commit hook as described in the [PyMC Python Code Style Wiki Page](https://github.com/pymc-devs/pymc/wiki/PyMC-Python-Code-Style), because it will automatically enforce the code style on your commits.
149+
We encourage you to configure the pre-commit hook as described in the [PyMC Python Code Style Wiki Page](https://github.com/pymc-devs/pymc/wiki/Python-Code-Style), because it will automatically enforce the code style on your commits.
150150

151151
Similarly, consult the [PyMC's Jupyter Notebook Style](https://github.com/pymc-devs/pymc/wiki/PyMC-Jupyter-Notebook-Style-Guide) guides for notebooks.
152152

pymc/distributions/multivariate.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,11 @@
4646
from pymc.distributions.continuous import ChiSquared, Normal, assert_negative_support
4747
from pymc.distributions.dist_math import bound, factln, logpow, multigammaln
4848
from pymc.distributions.distribution import Continuous, Discrete
49-
from pymc.distributions.shape_utils import broadcast_dist_samples_to, to_tuple
49+
from pymc.distributions.shape_utils import (
50+
broadcast_dist_samples_to,
51+
rv_size_is_none,
52+
to_tuple,
53+
)
5054
from pymc.math import kron_diag, kron_dot
5155

5256
__all__ = [
@@ -405,6 +409,15 @@ def dist(cls, a, **kwargs):
405409

406410
return super().dist([a], **kwargs)
407411

412+
def get_moment(rv, size, a):
413+
norm_constant = at.sum(a, axis=-1)[..., None]
414+
moment = a / norm_constant
415+
if not rv_size_is_none(size):
416+
if isinstance(size, int):
417+
size = (size,)
418+
moment = at.full((*size, *a.shape), moment)
419+
return moment
420+
408421
def logp(value, a):
409422
"""
410423
Calculate log-probability of Dirichlet distribution

pymc/tests/test_distributions_moments.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Cauchy,
1111
ChiSquared,
1212
Constant,
13+
Dirichlet,
1314
DiscreteUniform,
1415
ExGaussian,
1516
Exponential,
@@ -611,4 +612,41 @@ def test_hyper_geometric_moment(N, k, n, size, expected):
611612
def test_discrete_uniform_moment(lower, upper, size, expected):
612613
with Model() as model:
613614
DiscreteUniform("x", lower=lower, upper=upper, size=size)
615+
616+
617+
@pytest.mark.parametrize(
618+
"a, size, expected",
619+
[
620+
(
621+
np.array([2, 3, 5, 7, 11]),
622+
None,
623+
np.array([2, 3, 5, 7, 11]) / 28,
624+
),
625+
(
626+
np.array([[1, 2, 3], [5, 6, 7]]),
627+
None,
628+
np.array([[1, 2, 3], [5, 6, 7]]) / np.array([6, 18])[..., np.newaxis],
629+
),
630+
(
631+
np.array([[1, 2, 3], [5, 6, 7]]),
632+
7,
633+
np.apply_along_axis(
634+
lambda x: np.divide(x, np.array([6, 18])),
635+
1,
636+
np.broadcast_to([[1, 2, 3], [5, 6, 7]], shape=[7, 2, 3]),
637+
),
638+
),
639+
(
640+
np.full(shape=np.array([7, 3]), fill_value=np.array([13, 17, 19])),
641+
(
642+
11,
643+
5,
644+
),
645+
np.broadcast_to([13, 17, 19], shape=[11, 5, 7, 3]) / 49,
646+
),
647+
],
648+
)
649+
def test_dirichlet_moment(a, size, expected):
650+
with Model() as model:
651+
Dirichlet("x", a=a, size=size)
614652
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)