Skip to content

Commit d36b16d

Browse files
lucianopazmorganstrom
authored andcommitted
Add DensityDist moment (pymc-devs#5159)
* Add DensityDist moment * Specialize get_moment for multivariate density dists
1 parent 6d75084 commit d36b16d

File tree

3 files changed

+112
-45
lines changed

3 files changed

+112
-45
lines changed

pymc/distributions/distribution.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import aesara
2525

2626
from aeppl.logprob import _logcdf, _logprob
27+
from aesara import tensor as at
2728
from aesara.tensor.basic import as_tensor_variable
2829
from aesara.tensor.random.op import RandomVariable
2930
from aesara.tensor.random.var import RandomStateSharedVariable
@@ -472,9 +473,9 @@ def __new__(
472473
as the first argument ``rv``. ``size`` is the random variable's size implied
473474
by the ``dims``, ``size`` and parameters supplied to the distribution. Finally,
474475
``rv_inputs`` is the sequence of the distribution parameters, in the same order
475-
as they were supplied when the DensityDist was created. If ``None``, a
476-
``NotImplemented`` error will be raised when trying to draw random samples from
477-
the distribution's prior or posterior predictive.
476+
as they were supplied when the DensityDist was created. If ``None``, a default
477+
``get_moment`` function will be assigned that will always return 0, or an array
478+
of zeros.
478479
ndim_supp : int
479480
The number of dimensions in the support of the distribution. Defaults to assuming
480481
a scalar distribution, i.e. ``ndim_supp = 0``.
@@ -550,12 +551,17 @@ def random(mu, rng=None, size=None):
550551
if logcdf is None:
551552
logcdf = default_not_implemented(name, "logcdf")
552553

554+
if get_moment is None:
555+
get_moment = functools.partial(
556+
default_get_moment,
557+
rv_name=name,
558+
has_fallback=random is not None,
559+
ndim_supp=ndim_supp,
560+
)
561+
553562
if random is None:
554563
random = default_not_implemented(name, "random")
555564

556-
if get_moment is None:
557-
get_moment = default_not_implemented(name, "get_moment")
558-
559565
rv_op = type(
560566
f"DensityDist_{name}",
561567
(DensityDistRV,),
@@ -614,3 +620,16 @@ def func(*args, **kwargs):
614620
raise NotImplementedError(message)
615621

616622
return func
623+
624+
625+
def default_get_moment(rv, size, *rv_inputs, rv_name=None, has_fallback=False, ndim_supp=0):
626+
if ndim_supp == 0:
627+
return at.zeros(size, dtype=rv.dtype)
628+
elif has_fallback:
629+
return at.zeros_like(rv)
630+
else:
631+
raise TypeError(
632+
"Cannot safely infer the size of a multivariate random variable's moment. "
633+
f"Please provide a get_moment function when instantiating the {rv_name} "
634+
"random variable."
635+
)

pymc/tests/test_distributions_moments.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
import aesara
12
import numpy as np
23
import pytest
34

5+
from aesara import tensor as at
46
from scipy import special
57

8+
import pymc as pm
9+
610
from pymc.distributions import (
711
AsymmetricLaplace,
812
Bernoulli,
@@ -13,6 +17,7 @@
1317
Cauchy,
1418
ChiSquared,
1519
Constant,
20+
DensityDist,
1621
Dirichlet,
1722
DiscreteUniform,
1823
ExGaussian,
@@ -48,8 +53,9 @@
4853
ZeroInflatedBinomial,
4954
ZeroInflatedPoisson,
5055
)
56+
from pymc.distributions.distribution import get_moment
5157
from pymc.distributions.multivariate import MvNormal
52-
from pymc.distributions.shape_utils import rv_size_is_none
58+
from pymc.distributions.shape_utils import rv_size_is_none, to_tuple
5359
from pymc.initial_point import make_initial_point_fn
5460
from pymc.model import Model
5561

@@ -919,3 +925,83 @@ def test_rice_moment(nu, sigma, size, expected):
919925
with Model() as model:
920926
Rice("x", nu=nu, sigma=sigma, size=size)
921927
assert_moment_is_expected(model, expected)
928+
929+
930+
@pytest.mark.parametrize(
931+
"get_moment, size, expected",
932+
[
933+
(None, None, 0.0),
934+
(None, 5, np.zeros(5)),
935+
("custom_moment", None, 5),
936+
("custom_moment", (2, 5), np.full((2, 5), 5)),
937+
],
938+
)
939+
def test_density_dist_default_moment_univariate(get_moment, size, expected):
940+
if get_moment == "custom_moment":
941+
get_moment = lambda rv, size, *rv_inputs: 5 * at.ones(size, dtype=rv.dtype)
942+
with Model() as model:
943+
DensityDist("x", get_moment=get_moment, size=size)
944+
assert_moment_is_expected(model, expected)
945+
946+
947+
@pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str)
948+
def test_density_dist_custom_moment_univariate(size):
949+
def moment(rv, size, mu):
950+
return (at.ones(size) * mu).astype(rv.dtype)
951+
952+
mu_val = np.array(np.random.normal(loc=2, scale=1)).astype(aesara.config.floatX)
953+
with pm.Model():
954+
mu = pm.Normal("mu")
955+
a = pm.DensityDist("a", mu, get_moment=moment, size=size)
956+
evaled_moment = get_moment(a).eval({mu: mu_val})
957+
assert evaled_moment.shape == to_tuple(size)
958+
assert np.all(evaled_moment == mu_val)
959+
960+
961+
@pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str)
962+
def test_density_dist_custom_moment_multivariate(size):
963+
def moment(rv, size, mu):
964+
return (at.ones(size)[..., None] * mu).astype(rv.dtype)
965+
966+
mu_val = np.random.normal(loc=2, scale=1, size=5).astype(aesara.config.floatX)
967+
with pm.Model():
968+
mu = pm.Normal("mu", size=5)
969+
a = pm.DensityDist("a", mu, get_moment=moment, ndims_params=[1], ndim_supp=1, size=size)
970+
evaled_moment = get_moment(a).eval({mu: mu_val})
971+
assert evaled_moment.shape == to_tuple(size) + (5,)
972+
assert np.all(evaled_moment == mu_val)
973+
974+
975+
@pytest.mark.parametrize(
976+
"with_random, size",
977+
[
978+
(True, ()),
979+
(True, (2,)),
980+
(True, (3, 2)),
981+
(False, ()),
982+
(False, (2,)),
983+
],
984+
)
985+
def test_density_dist_default_moment_multivariate(with_random, size):
986+
def _random(mu, rng=None, size=None):
987+
return rng.normal(mu, scale=1, size=to_tuple(size) + mu.shape)
988+
989+
if with_random:
990+
random = _random
991+
else:
992+
random = None
993+
994+
mu_val = np.random.normal(loc=2, scale=1, size=5).astype(aesara.config.floatX)
995+
with pm.Model():
996+
mu = pm.Normal("mu", size=5)
997+
a = pm.DensityDist("a", mu, random=random, ndims_params=[1], ndim_supp=1, size=size)
998+
if with_random:
999+
evaled_moment = get_moment(a).eval({mu: mu_val})
1000+
assert evaled_moment.shape == to_tuple(size) + (5,)
1001+
assert np.all(evaled_moment == 0)
1002+
else:
1003+
with pytest.raises(
1004+
TypeError,
1005+
match="Cannot safely infer the size of a multivariate random variable's moment.",
1006+
):
1007+
evaled_moment = get_moment(a).eval({mu: mu_val})

pymc/tests/test_moment.py

Lines changed: 0 additions & 38 deletions
This file was deleted.

0 commit comments

Comments
 (0)