|
| 1 | +import aesara |
1 | 2 | import numpy as np
|
2 | 3 | import pytest
|
3 | 4 |
|
| 5 | +from aesara import tensor as at |
4 | 6 | from scipy import special
|
5 | 7 |
|
| 8 | +import pymc as pm |
| 9 | + |
6 | 10 | from pymc.distributions import (
|
7 | 11 | AsymmetricLaplace,
|
8 | 12 | Bernoulli,
|
|
13 | 17 | Cauchy,
|
14 | 18 | ChiSquared,
|
15 | 19 | Constant,
|
| 20 | + DensityDist, |
16 | 21 | Dirichlet,
|
17 | 22 | DiscreteUniform,
|
18 | 23 | ExGaussian,
|
|
47 | 52 | ZeroInflatedBinomial,
|
48 | 53 | ZeroInflatedPoisson,
|
49 | 54 | )
|
| 55 | +from pymc.distributions.distribution import get_moment |
50 | 56 | from pymc.distributions.multivariate import MvNormal
|
51 |
| -from pymc.distributions.shape_utils import rv_size_is_none |
| 57 | +from pymc.distributions.shape_utils import rv_size_is_none, to_tuple |
52 | 58 | from pymc.initial_point import make_initial_point_fn
|
53 | 59 | from pymc.model import Model
|
54 | 60 |
|
@@ -898,3 +904,83 @@ def test_rice_moment(nu, sigma, size, expected):
|
898 | 904 | with Model() as model:
|
899 | 905 | Rice("x", nu=nu, sigma=sigma, size=size)
|
900 | 906 | assert_moment_is_expected(model, expected)
|
| 907 | + |
| 908 | + |
| 909 | +@pytest.mark.parametrize( |
| 910 | + "get_moment, size, expected", |
| 911 | + [ |
| 912 | + (None, None, 0.0), |
| 913 | + (None, 5, np.zeros(5)), |
| 914 | + ("custom_moment", None, 5), |
| 915 | + ("custom_moment", (2, 5), np.full((2, 5), 5)), |
| 916 | + ], |
| 917 | +) |
| 918 | +def test_density_dist_default_moment_univariate(get_moment, size, expected): |
| 919 | + if get_moment == "custom_moment": |
| 920 | + get_moment = lambda rv, size, *rv_inputs: 5 * at.ones(size, dtype=rv.dtype) |
| 921 | + with Model() as model: |
| 922 | + DensityDist("x", get_moment=get_moment, size=size) |
| 923 | + assert_moment_is_expected(model, expected) |
| 924 | + |
| 925 | + |
| 926 | +@pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str) |
| 927 | +def test_density_dist_custom_moment_univariate(size): |
| 928 | + def moment(rv, size, mu): |
| 929 | + return (at.ones(size) * mu).astype(rv.dtype) |
| 930 | + |
| 931 | + mu_val = np.array(np.random.normal(loc=2, scale=1)).astype(aesara.config.floatX) |
| 932 | + with pm.Model(): |
| 933 | + mu = pm.Normal("mu") |
| 934 | + a = pm.DensityDist("a", mu, get_moment=moment, size=size) |
| 935 | + evaled_moment = get_moment(a).eval({mu: mu_val}) |
| 936 | + assert evaled_moment.shape == to_tuple(size) |
| 937 | + assert np.all(evaled_moment == mu_val) |
| 938 | + |
| 939 | + |
| 940 | +@pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str) |
| 941 | +def test_density_dist_custom_moment_multivariate(size): |
| 942 | + def moment(rv, size, mu): |
| 943 | + return (at.ones(size)[..., None] * mu).astype(rv.dtype) |
| 944 | + |
| 945 | + mu_val = np.random.normal(loc=2, scale=1, size=5).astype(aesara.config.floatX) |
| 946 | + with pm.Model(): |
| 947 | + mu = pm.Normal("mu", size=5) |
| 948 | + a = pm.DensityDist("a", mu, get_moment=moment, ndims_params=[1], ndim_supp=1, size=size) |
| 949 | + evaled_moment = get_moment(a).eval({mu: mu_val}) |
| 950 | + assert evaled_moment.shape == to_tuple(size) + (5,) |
| 951 | + assert np.all(evaled_moment == mu_val) |
| 952 | + |
| 953 | + |
| 954 | +@pytest.mark.parametrize( |
| 955 | + "with_random, size", |
| 956 | + [ |
| 957 | + (True, ()), |
| 958 | + (True, (2,)), |
| 959 | + (True, (3, 2)), |
| 960 | + (False, ()), |
| 961 | + (False, (2,)), |
| 962 | + ], |
| 963 | +) |
| 964 | +def test_density_dist_default_moment_multivariate(with_random, size): |
| 965 | + def _random(mu, rng=None, size=None): |
| 966 | + return rng.normal(mu, scale=1, size=to_tuple(size) + mu.shape) |
| 967 | + |
| 968 | + if with_random: |
| 969 | + random = _random |
| 970 | + else: |
| 971 | + random = None |
| 972 | + |
| 973 | + mu_val = np.random.normal(loc=2, scale=1, size=5).astype(aesara.config.floatX) |
| 974 | + with pm.Model(): |
| 975 | + mu = pm.Normal("mu", size=5) |
| 976 | + a = pm.DensityDist("a", mu, random=random, ndims_params=[1], ndim_supp=1, size=size) |
| 977 | + if with_random: |
| 978 | + evaled_moment = get_moment(a).eval({mu: mu_val}) |
| 979 | + assert evaled_moment.shape == to_tuple(size) + (5,) |
| 980 | + assert np.all(evaled_moment == 0) |
| 981 | + else: |
| 982 | + with pytest.raises( |
| 983 | + TypeError, |
| 984 | + match="Cannot safely infer the size of a multivariate random variable's moment.", |
| 985 | + ): |
| 986 | + evaled_moment = get_moment(a).eval({mu: mu_val}) |
0 commit comments