Skip to content

Commit fe40965

Browse files
Updated tests + fix broken link
1 parent 6f677c2 commit fe40965

File tree

3 files changed

+31
-12
lines changed

3 files changed

+31
-12
lines changed

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ docker exec -it pymc jupyter notebook list
146146
## Style guide
147147

148148
We have configured a pre-commit hook that checks for `black`-compliant code style.
149-
We encourage you to configure the pre-commit hook as described in the [PyMC Python Code Style Wiki Page](https://github.com/pymc-devs/pymc/wiki/PyMC-Python-Code-Style), because it will automatically enforce the code style on your commits.
149+
We encourage you to configure the pre-commit hook as described in the [PyMC Python Code Style Wiki Page](https://github.com/pymc-devs/pymc/wiki/Python-Code-Style), because it will automatically enforce the code style on your commits.
150150

151151
Similarly, consult the [PyMC's Jupyter Notebook Style](https://github.com/pymc-devs/pymc/wiki/PyMC-Jupyter-Notebook-Style-Guide) guides for notebooks.
152152

pymc/distributions/multivariate.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,11 @@
4646
from pymc.distributions.continuous import ChiSquared, Normal, assert_negative_support
4747
from pymc.distributions.dist_math import bound, factln, logpow, multigammaln
4848
from pymc.distributions.distribution import Continuous, Discrete
49-
from pymc.distributions.shape_utils import broadcast_dist_samples_to, rv_size_is_none, to_tuple
49+
from pymc.distributions.shape_utils import (
50+
broadcast_dist_samples_to,
51+
rv_size_is_none,
52+
to_tuple,
53+
)
5054
from pymc.math import kron_diag, kron_dot
5155

5256
__all__ = [
@@ -407,9 +411,11 @@ def dist(cls, a, **kwargs):
407411

408412
def get_moment(rv, size, a):
409413
norm_constant = at.sum(a, axis=-1)[..., None]
410-
moment = a/norm_constant
414+
moment = a / norm_constant
411415
if not rv_size_is_none(size):
412-
return at.full(size, moment)
416+
if isinstance(size, int):
417+
size = (size,)
418+
moment = at.full((*size, *a.shape), moment)
413419
return moment
414420

415421
def logp(value, a):

pymc/tests/test_distributions_moments.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
Cauchy,
1111
ChiSquared,
1212
Constant,
13+
Dirichlet,
1314
DiscreteUniform,
1415
ExGaussian,
15-
Dirichlet,
1616
Exponential,
1717
Flat,
1818
Gamma,
@@ -613,25 +613,38 @@ def test_discrete_uniform_moment(lower, upper, size, expected):
613613
with Model() as model:
614614
DiscreteUniform("x", lower=lower, upper=upper, size=size)
615615

616+
616617
@pytest.mark.parametrize(
617618
"a, size, expected",
618619
[
619620
(
620-
np.array([2, 3, 5, 7, 11]),
621-
None,
622-
np.array([2, 3, 5, 7, 11])/28,
621+
np.array([2, 3, 5, 7, 11]),
622+
None,
623+
np.array([2, 3, 5, 7, 11]) / 28,
623624
),
624625
(
625626
np.array([[1, 2, 3], [5, 6, 7]]),
626627
None,
627-
np.array([[1, 2, 3], [5, 6, 7]])/np.array([6, 18])[..., np.newaxis],
628+
np.array([[1, 2, 3], [5, 6, 7]]) / np.array([6, 18])[..., np.newaxis],
629+
),
630+
(
631+
np.array([[1, 2, 3], [5, 6, 7]]),
632+
7,
633+
np.apply_along_axis(
634+
lambda x: np.divide(x, np.array([6, 18])),
635+
1,
636+
np.broadcast_to([[1, 2, 3], [5, 6, 7]], shape=[7, 2, 3]),
637+
),
628638
),
629639
(
630640
np.full(shape=np.array([7, 3]), fill_value=np.array([13, 17, 19])),
631-
(11, 5,),
632-
np.broadcast_to([13, 17, 19], shape=[11, 5, 7, 3]),
641+
(
642+
11,
643+
5,
644+
),
645+
np.broadcast_to([13, 17, 19], shape=[11, 5, 7, 3]) / 49,
633646
),
634-
]
647+
],
635648
)
636649
def test_dirichlet_moment(a, size, expected):
637650
with Model() as model:

0 commit comments

Comments
 (0)