|
| 1 | +import aesara |
1 | 2 | import numpy as np
|
2 | 3 | import pytest
|
3 | 4 |
|
| 5 | +from aesara import tensor as at |
4 | 6 | from scipy import special
|
5 | 7 |
|
| 8 | +import pymc as pm |
| 9 | + |
6 | 10 | from pymc.distributions import (
|
7 | 11 | AsymmetricLaplace,
|
8 | 12 | Bernoulli,
|
|
13 | 17 | Cauchy,
|
14 | 18 | ChiSquared,
|
15 | 19 | Constant,
|
| 20 | + DensityDist, |
16 | 21 | Dirichlet,
|
17 | 22 | DiscreteUniform,
|
18 | 23 | ExGaussian,
|
|
48 | 53 | ZeroInflatedBinomial,
|
49 | 54 | ZeroInflatedPoisson,
|
50 | 55 | )
|
| 56 | +from pymc.distributions.distribution import get_moment |
51 | 57 | from pymc.distributions.multivariate import MvNormal
|
52 |
| -from pymc.distributions.shape_utils import rv_size_is_none |
| 58 | +from pymc.distributions.shape_utils import rv_size_is_none, to_tuple |
53 | 59 | from pymc.initial_point import make_initial_point_fn
|
54 | 60 | from pymc.model import Model
|
55 | 61 |
|
@@ -919,3 +925,83 @@ def test_rice_moment(nu, sigma, size, expected):
|
919 | 925 | with Model() as model:
|
920 | 926 | Rice("x", nu=nu, sigma=sigma, size=size)
|
921 | 927 | assert_moment_is_expected(model, expected)
|
| 928 | + |
| 929 | + |
| 930 | +@pytest.mark.parametrize( |
| 931 | + "get_moment, size, expected", |
| 932 | + [ |
| 933 | + (None, None, 0.0), |
| 934 | + (None, 5, np.zeros(5)), |
| 935 | + ("custom_moment", None, 5), |
| 936 | + ("custom_moment", (2, 5), np.full((2, 5), 5)), |
| 937 | + ], |
| 938 | +) |
| 939 | +def test_density_dist_default_moment_univariate(get_moment, size, expected): |
| 940 | + if get_moment == "custom_moment": |
| 941 | + get_moment = lambda rv, size, *rv_inputs: 5 * at.ones(size, dtype=rv.dtype) |
| 942 | + with Model() as model: |
| 943 | + DensityDist("x", get_moment=get_moment, size=size) |
| 944 | + assert_moment_is_expected(model, expected) |
| 945 | + |
| 946 | + |
| 947 | +@pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str) |
| 948 | +def test_density_dist_custom_moment_univariate(size): |
| 949 | + def moment(rv, size, mu): |
| 950 | + return (at.ones(size) * mu).astype(rv.dtype) |
| 951 | + |
| 952 | + mu_val = np.array(np.random.normal(loc=2, scale=1)).astype(aesara.config.floatX) |
| 953 | + with pm.Model(): |
| 954 | + mu = pm.Normal("mu") |
| 955 | + a = pm.DensityDist("a", mu, get_moment=moment, size=size) |
| 956 | + evaled_moment = get_moment(a).eval({mu: mu_val}) |
| 957 | + assert evaled_moment.shape == to_tuple(size) |
| 958 | + assert np.all(evaled_moment == mu_val) |
| 959 | + |
| 960 | + |
| 961 | +@pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str) |
| 962 | +def test_density_dist_custom_moment_multivariate(size): |
| 963 | + def moment(rv, size, mu): |
| 964 | + return (at.ones(size)[..., None] * mu).astype(rv.dtype) |
| 965 | + |
| 966 | + mu_val = np.random.normal(loc=2, scale=1, size=5).astype(aesara.config.floatX) |
| 967 | + with pm.Model(): |
| 968 | + mu = pm.Normal("mu", size=5) |
| 969 | + a = pm.DensityDist("a", mu, get_moment=moment, ndims_params=[1], ndim_supp=1, size=size) |
| 970 | + evaled_moment = get_moment(a).eval({mu: mu_val}) |
| 971 | + assert evaled_moment.shape == to_tuple(size) + (5,) |
| 972 | + assert np.all(evaled_moment == mu_val) |
| 973 | + |
| 974 | + |
| 975 | +@pytest.mark.parametrize( |
| 976 | + "with_random, size", |
| 977 | + [ |
| 978 | + (True, ()), |
| 979 | + (True, (2,)), |
| 980 | + (True, (3, 2)), |
| 981 | + (False, ()), |
| 982 | + (False, (2,)), |
| 983 | + ], |
| 984 | +) |
| 985 | +def test_density_dist_default_moment_multivariate(with_random, size): |
| 986 | + def _random(mu, rng=None, size=None): |
| 987 | + return rng.normal(mu, scale=1, size=to_tuple(size) + mu.shape) |
| 988 | + |
| 989 | + if with_random: |
| 990 | + random = _random |
| 991 | + else: |
| 992 | + random = None |
| 993 | + |
| 994 | + mu_val = np.random.normal(loc=2, scale=1, size=5).astype(aesara.config.floatX) |
| 995 | + with pm.Model(): |
| 996 | + mu = pm.Normal("mu", size=5) |
| 997 | + a = pm.DensityDist("a", mu, random=random, ndims_params=[1], ndim_supp=1, size=size) |
| 998 | + if with_random: |
| 999 | + evaled_moment = get_moment(a).eval({mu: mu_val}) |
| 1000 | + assert evaled_moment.shape == to_tuple(size) + (5,) |
| 1001 | + assert np.all(evaled_moment == 0) |
| 1002 | + else: |
| 1003 | + with pytest.raises( |
| 1004 | + TypeError, |
| 1005 | + match="Cannot safely infer the size of a multivariate random variable's moment.", |
| 1006 | + ): |
| 1007 | + evaled_moment = get_moment(a).eval({mu: mu_val}) |
0 commit comments