Skip to content

Commit 90860e6

Browse files
authored
Add DensityDist moment (#5159)
* Add DensityDist moment * Specialize get_moment for multivariate density dists
1 parent c0c5a80 commit 90860e6

File tree

3 files changed

+112
-45
lines changed

3 files changed

+112
-45
lines changed

pymc/distributions/distribution.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import aesara
2525

2626
from aeppl.logprob import _logcdf, _logprob
27+
from aesara import tensor as at
2728
from aesara.tensor.basic import as_tensor_variable
2829
from aesara.tensor.random.op import RandomVariable
2930
from aesara.tensor.random.var import RandomStateSharedVariable
@@ -472,9 +473,9 @@ def __new__(
472473
as the first argument ``rv``. ``size`` is the random variable's size implied
473474
by the ``dims``, ``size`` and parameters supplied to the distribution. Finally,
474475
``rv_inputs`` is the sequence of the distribution parameters, in the same order
475-
as they were supplied when the DensityDist was created. If ``None``, a
476-
``NotImplemented`` error will be raised when trying to draw random samples from
477-
the distribution's prior or posterior predictive.
476+
as they were supplied when the DensityDist was created. If ``None``, a default
477+
``get_moment`` function will be assigned that will always return 0, or an array
478+
of zeros.
478479
ndim_supp : int
479480
The number of dimensions in the support of the distribution. Defaults to assuming
480481
a scalar distribution, i.e. ``ndim_supp = 0``.
@@ -550,12 +551,17 @@ def random(mu, rng=None, size=None):
550551
if logcdf is None:
551552
logcdf = default_not_implemented(name, "logcdf")
552553

554+
if get_moment is None:
555+
get_moment = functools.partial(
556+
default_get_moment,
557+
rv_name=name,
558+
has_fallback=random is not None,
559+
ndim_supp=ndim_supp,
560+
)
561+
553562
if random is None:
554563
random = default_not_implemented(name, "random")
555564

556-
if get_moment is None:
557-
get_moment = default_not_implemented(name, "get_moment")
558-
559565
rv_op = type(
560566
f"DensityDist_{name}",
561567
(DensityDistRV,),
@@ -614,3 +620,16 @@ def func(*args, **kwargs):
614620
raise NotImplementedError(message)
615621

616622
return func
623+
624+
625+
def default_get_moment(rv, size, *rv_inputs, rv_name=None, has_fallback=False, ndim_supp=0):
626+
if ndim_supp == 0:
627+
return at.zeros(size, dtype=rv.dtype)
628+
elif has_fallback:
629+
return at.zeros_like(rv)
630+
else:
631+
raise TypeError(
632+
"Cannot safely infer the size of a multivariate random variable's moment. "
633+
f"Please provide a get_moment function when instantiating the {rv_name} "
634+
"random variable."
635+
)

pymc/tests/test_distributions_moments.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
import aesara
12
import numpy as np
23
import pytest
34

5+
from aesara import tensor as at
46
from scipy import special
57

8+
import pymc as pm
9+
610
from pymc.distributions import (
711
AsymmetricLaplace,
812
Bernoulli,
@@ -13,6 +17,7 @@
1317
Cauchy,
1418
ChiSquared,
1519
Constant,
20+
DensityDist,
1621
Dirichlet,
1722
DiscreteUniform,
1823
ExGaussian,
@@ -47,8 +52,9 @@
4752
ZeroInflatedBinomial,
4853
ZeroInflatedPoisson,
4954
)
55+
from pymc.distributions.distribution import get_moment
5056
from pymc.distributions.multivariate import MvNormal
51-
from pymc.distributions.shape_utils import rv_size_is_none
57+
from pymc.distributions.shape_utils import rv_size_is_none, to_tuple
5258
from pymc.initial_point import make_initial_point_fn
5359
from pymc.model import Model
5460

@@ -898,3 +904,83 @@ def test_rice_moment(nu, sigma, size, expected):
898904
with Model() as model:
899905
Rice("x", nu=nu, sigma=sigma, size=size)
900906
assert_moment_is_expected(model, expected)
907+
908+
909+
@pytest.mark.parametrize(
910+
"get_moment, size, expected",
911+
[
912+
(None, None, 0.0),
913+
(None, 5, np.zeros(5)),
914+
("custom_moment", None, 5),
915+
("custom_moment", (2, 5), np.full((2, 5), 5)),
916+
],
917+
)
918+
def test_density_dist_default_moment_univariate(get_moment, size, expected):
919+
if get_moment == "custom_moment":
920+
get_moment = lambda rv, size, *rv_inputs: 5 * at.ones(size, dtype=rv.dtype)
921+
with Model() as model:
922+
DensityDist("x", get_moment=get_moment, size=size)
923+
assert_moment_is_expected(model, expected)
924+
925+
926+
@pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str)
927+
def test_density_dist_custom_moment_univariate(size):
928+
def moment(rv, size, mu):
929+
return (at.ones(size) * mu).astype(rv.dtype)
930+
931+
mu_val = np.array(np.random.normal(loc=2, scale=1)).astype(aesara.config.floatX)
932+
with pm.Model():
933+
mu = pm.Normal("mu")
934+
a = pm.DensityDist("a", mu, get_moment=moment, size=size)
935+
evaled_moment = get_moment(a).eval({mu: mu_val})
936+
assert evaled_moment.shape == to_tuple(size)
937+
assert np.all(evaled_moment == mu_val)
938+
939+
940+
@pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str)
941+
def test_density_dist_custom_moment_multivariate(size):
942+
def moment(rv, size, mu):
943+
return (at.ones(size)[..., None] * mu).astype(rv.dtype)
944+
945+
mu_val = np.random.normal(loc=2, scale=1, size=5).astype(aesara.config.floatX)
946+
with pm.Model():
947+
mu = pm.Normal("mu", size=5)
948+
a = pm.DensityDist("a", mu, get_moment=moment, ndims_params=[1], ndim_supp=1, size=size)
949+
evaled_moment = get_moment(a).eval({mu: mu_val})
950+
assert evaled_moment.shape == to_tuple(size) + (5,)
951+
assert np.all(evaled_moment == mu_val)
952+
953+
954+
@pytest.mark.parametrize(
955+
"with_random, size",
956+
[
957+
(True, ()),
958+
(True, (2,)),
959+
(True, (3, 2)),
960+
(False, ()),
961+
(False, (2,)),
962+
],
963+
)
964+
def test_density_dist_default_moment_multivariate(with_random, size):
965+
def _random(mu, rng=None, size=None):
966+
return rng.normal(mu, scale=1, size=to_tuple(size) + mu.shape)
967+
968+
if with_random:
969+
random = _random
970+
else:
971+
random = None
972+
973+
mu_val = np.random.normal(loc=2, scale=1, size=5).astype(aesara.config.floatX)
974+
with pm.Model():
975+
mu = pm.Normal("mu", size=5)
976+
a = pm.DensityDist("a", mu, random=random, ndims_params=[1], ndim_supp=1, size=size)
977+
if with_random:
978+
evaled_moment = get_moment(a).eval({mu: mu_val})
979+
assert evaled_moment.shape == to_tuple(size) + (5,)
980+
assert np.all(evaled_moment == 0)
981+
else:
982+
with pytest.raises(
983+
TypeError,
984+
match="Cannot safely infer the size of a multivariate random variable's moment.",
985+
):
986+
evaled_moment = get_moment(a).eval({mu: mu_val})

pymc/tests/test_moment.py

Lines changed: 0 additions & 38 deletions
This file was deleted.

0 commit comments

Comments
 (0)