Skip to content

Commit 7ddc9d4

Browse files
committed
Add tests for distribution moments
Fixes some pre-existing moments. Adds moments for the HalfNormal and TruncatedNormal distributions. Adds the helper function rv_size_is_none.
1 parent 05aa247 commit 7ddc9d4

File tree

6 files changed

+208
-34
lines changed

6 files changed

+208
-34
lines changed

.github/workflows/pytest.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ jobs:
155155
- |
156156
pymc/tests/test_initial_point.py
157157
pymc/tests/test_distributions_random.py
158+
pymc/tests/test_distributions_moments.py
158159
pymc/tests/test_distributions_timeseries.py
159160
- |
160161
pymc/tests/test_parallel_sampling.py

pymc/distributions/continuous.py

Lines changed: 49 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def polyagamma_cdf(*args, **kwargs):
8585
zvalue,
8686
)
8787
from pymc.distributions.distribution import Continuous
88+
from pymc.distributions.shape_utils import rv_size_is_none
8889
from pymc.math import logdiffexp, logit
8990
from pymc.util import UNSET
9091

@@ -290,6 +291,13 @@ def dist(cls, lower=0, upper=1, **kwargs):
290291
upper = at.as_tensor_variable(floatX(upper))
291292
return super().dist([lower, upper], **kwargs)
292293

294+
def get_moment(rv, size, lower, upper):
    """Return the midpoint between ``lower`` and ``upper`` as the moment.

    The two bounds are broadcast against each other first. When an explicit
    ``size`` was given, the midpoint is expanded to that size with ``at.full``.
    """
    lower, upper = at.broadcast_arrays(lower, upper)
    midpoint = (lower + upper) / 2
    if rv_size_is_none(size):
        return midpoint
    return at.full(size, midpoint)
300+
293301
def logcdf(value, lower, upper):
294302
"""
295303
Compute the log of the cumulative distribution function for Uniform distribution
@@ -315,11 +323,6 @@ def logcdf(value, lower, upper):
315323
),
316324
)
317325

318-
def get_moment(value, size, lower, upper):
319-
lower = at.full(size, lower, dtype=aesara.config.floatX)
320-
upper = at.full(size, upper, dtype=aesara.config.floatX)
321-
return (lower + upper) / 2
322-
323326

324327
class FlatRV(RandomVariable):
325328
name = "flat"
@@ -353,8 +356,8 @@ def dist(cls, *, size=None, **kwargs):
353356
res = super().dist([], size=size, **kwargs)
354357
return res
355358

356-
def get_moment(rv, size, *rv_inputs):
357-
return at.zeros(size, dtype=aesara.config.floatX)
359+
def get_moment(rv, size):
    """Return a tensor of zeros with the requested ``size`` as the moment."""
    zero_moment = at.zeros(size)
    return zero_moment
358361

359362
def logp(value):
360363
"""
@@ -421,8 +424,8 @@ def dist(cls, *, size=None, **kwargs):
421424
res = super().dist([], size=size, **kwargs)
422425
return res
423426

424-
def get_moment(value_var, size, *rv_inputs):
425-
return at.ones(size, dtype=aesara.config.floatX)
427+
def get_moment(rv, size):
    """Return a tensor of ones with the requested ``size`` as the moment."""
    unit_moment = at.ones(size)
    return unit_moment
426429

427430
def logp(value):
428431
"""
@@ -540,6 +543,12 @@ def dist(cls, mu=0, sigma=None, tau=None, sd=None, no_assert=False, **kwargs):
540543

541544
return super().dist([mu, sigma], **kwargs)
542545

546+
def get_moment(rv, size, mu, sigma):
    """Return ``mu`` as the moment, broadcast against ``sigma``.

    When an explicit ``size`` was given, the broadcast ``mu`` is expanded to
    that size with ``at.full``.
    """
    # Only the broadcast mean is needed; sigma just contributes its shape.
    location = at.broadcast_arrays(mu, sigma)[0]
    if rv_size_is_none(size):
        return location
    return at.full(size, location)
551+
543552
def logcdf(value, mu, sigma):
544553
"""
545554
Compute the log of the cumulative distribution function for Normal distribution
@@ -560,9 +569,6 @@ def logcdf(value, mu, sigma):
560569
0 < sigma,
561570
)
562571

563-
def get_moment(value_var, size, mu, sigma):
564-
return at.full(size, mu, dtype=aesara.config.floatX)
565-
566572

567573
class TruncatedNormalRV(RandomVariable):
568574
name = "truncated_normal"
@@ -691,19 +697,35 @@ def dist(
691697
assert_negative_support(sigma, "sigma", "TruncatedNormal")
692698
assert_negative_support(tau, "tau", "TruncatedNormal")
693699

694-
# if lower is None and upper is None:
695-
# initval = mu
696-
# elif lower is None and upper is not None:
697-
# initval = upper - 1.0
698-
# elif lower is not None and upper is None:
699-
# initval = lower + 1.0
700-
# else:
701-
# initval = (lower + upper) / 2
702-
703700
lower = at.as_tensor_variable(floatX(lower)) if lower is not None else at.constant(-np.inf)
704701
upper = at.as_tensor_variable(floatX(upper)) if upper is not None else at.constant(np.inf)
705702
return super().dist([mu, sigma, lower, upper], **kwargs)
706703

704+
def get_moment(rv, size, mu, sigma, lower, upper):
    """Pick a moment that depends on which truncation bounds are infinite.

    All four parameters are broadcast together, then element-wise:

    * both bounds infinite  -> ``mu``
    * only lower infinite   -> ``upper - 1``
    * only upper infinite   -> ``lower + 1``
    * both bounds finite    -> ``(lower + upper) / 2``

    When an explicit ``size`` was given, the result is expanded to that size
    with ``at.full``.
    """
    mu, _, lower, upper = at.broadcast_arrays(mu, sigma, lower, upper)

    lower_is_inf = at.isinf(lower)
    upper_is_inf = at.isinf(upper)

    # Branch taken where lower is infinite: choose between mu and upper - 1.
    when_lower_inf = at.switch(upper_is_inf, mu, upper - 1)
    # Branch taken where lower is finite: choose between lower + 1 and the midpoint.
    when_lower_finite = at.switch(upper_is_inf, lower + 1, (lower + upper) / 2)

    moment = at.switch(lower_is_inf, when_lower_inf, when_lower_finite)

    if rv_size_is_none(size):
        return moment
    return at.full(size, moment)
728+
707729
def logp(
708730
value,
709731
mu: Union[float, np.ndarray, TensorVariable],
@@ -828,6 +850,12 @@ def dist(cls, sigma=None, tau=None, sd=None, *args, **kwargs):
828850

829851
return super().dist([0.0, sigma], **kwargs)
830852

853+
def get_moment(rv, size, loc, sigma):
    """Return ``loc + sigma`` as the moment.

    When an explicit ``size`` was given, the sum is expanded to that size
    with ``at.full``.
    """
    shifted = loc + sigma
    if rv_size_is_none(size):
        return shifted
    return at.full(size, shifted)
858+
831859
def logcdf(value, loc, sigma):
832860
"""
833861
Compute the log of the cumulative distribution function for HalfNormal distribution
@@ -850,9 +878,6 @@ def logcdf(value, loc, sigma):
850878
0 < sigma,
851879
)
852880

853-
def _distr_parameters_for_repr(self):
854-
return ["sigma"]
855-
856881

857882
class WaldRV(RandomVariable):
858883
name = "wald"

pymc/distributions/discrete.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
)
4444
from pymc.distributions.distribution import Discrete
4545
from pymc.distributions.logprob import _logcdf
46+
from pymc.distributions.shape_utils import rv_size_is_none
4647
from pymc.math import sigmoid
4748

4849
__all__ = [
@@ -352,6 +353,11 @@ def dist(cls, p=None, logit_p=None, *args, **kwargs):
352353
p = at.as_tensor_variable(floatX(p))
353354
return super().dist([p], **kwargs)
354355

356+
def get_moment(rv, size, p):
    """Return 0 where ``p < 0.5`` and 1 elsewhere.

    When an explicit ``size`` was given, ``p`` is first expanded to that size
    with ``at.full`` so the result has the same shape.
    """
    prob = p if rv_size_is_none(size) else at.full(size, p)
    return at.switch(prob < 0.5, 0, 1)
360+
355361
def logp(value, p):
356362
r"""
357363
Calculate log-probability of Bernoulli distribution at specified value.
@@ -402,13 +408,6 @@ def logcdf(value, p):
402408
p <= 1,
403409
)
404410

405-
def get_moment(value, size, p):
406-
p = at.full(size, p)
407-
return at.switch(p < 0.5, at.zeros_like(value), at.ones_like(value))
408-
409-
def _distr_parameters_for_repr(self):
410-
return ["p"]
411-
412411

413412
class DiscreteWeibullRV(RandomVariable):
414413
name = "discrete_weibull"

pymc/distributions/distribution.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from typing import Callable, Optional, Sequence
2323

2424
import aesara
25-
import numpy as np
2625

2726
from aeppl.logprob import _logprob
2827
from aesara.tensor.basic import as_tensor_variable
@@ -371,7 +370,7 @@ def get_moment(rv: TensorVariable) -> TensorVariable:
371370
for which the value is to be derived.
372371
"""
373372
size = rv.owner.inputs[1]
374-
return _get_moment(rv.owner.op, rv, size, *rv.owner.inputs[3:])
373+
return _get_moment(rv.owner.op, rv, size, *rv.owner.inputs[3:]).astype(rv.dtype)
375374

376375

377376
class Discrete(Distribution):

pymc/distributions/shape_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
import numpy as np
2626

27-
from aesara.graph.basic import Variable
27+
from aesara.graph.basic import Constant, Variable
2828
from aesara.tensor.var import TensorVariable
2929

3030
from pymc.aesaraf import change_rv_size, pandas_to_array
@@ -37,6 +37,7 @@
3737
"get_broadcastable_dist_samples",
3838
"broadcast_distribution_samples",
3939
"broadcast_dist_samples_to",
40+
"rv_size_is_none",
4041
]
4142

4243

@@ -674,3 +675,8 @@ def maybe_resize(
674675
)
675676

676677
return rv_out
678+
679+
680+
def rv_size_is_none(size: Variable) -> bool:
    """Check whether an rv size is None (i.e., ``at.Constant([])``).

    Returns True only when ``size`` is a ``Constant`` whose data holds zero
    elements; any non-constant variable is reported as not-None.
    """
    if not isinstance(size, Constant):
        return False
    return size.data.size == 0
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
import numpy as np
2+
import pytest
3+
4+
from pymc import Bernoulli, Flat, HalfFlat, Normal, TruncatedNormal, Uniform
5+
from pymc.distributions import HalfNormal
6+
from pymc.distributions.shape_utils import rv_size_is_none
7+
from pymc.initial_point import make_initial_point_fn
8+
from pymc.model import Model
9+
10+
11+
def test_rv_size_is_none():
    """Only ``size=None`` is reported as a None size by ``rv_size_is_none``."""
    unsized_rv = Normal.dist(0, 1, size=None)
    assert rv_size_is_none(unsized_rv.owner.inputs[1])

    # A concrete integer size is not None.
    int_sized_rv = Normal.dist(0, 1, size=1)
    assert not rv_size_is_none(int_sized_rv.owner.inputs[1])

    # A size given by another random variable is not None.
    random_size = Bernoulli.dist(0.5)
    rv_sized_rv = Normal.dist(0, 1, size=random_size)
    assert not rv_size_is_none(rv_sized_rv.owner.inputs[1])

    # A size taken from the ``.size`` attribute of a variable is not None.
    symbolic_size = Normal.dist(0, 1).size
    symbolic_sized_rv = Normal.dist(0, 1, size=symbolic_size)
    assert not rv_size_is_none(symbolic_sized_rv.owner.inputs[1])
25+
26+
27+
def assert_moment_is_expected(model, expected):
    """Assert that the "moment" initial point of ``model["x"]`` equals ``expected``.

    The initial point is computed with the ``"moment"`` default strategy and
    compared to ``expected`` both in shape and (via ``np.allclose``) in value.
    Its shape is also compared to the shape of a random draw of ``x``; when
    drawing raises ``NotImplementedError`` the moment's own shape is used
    instead.
    """
    initial_point_fn = make_initial_point_fn(
        model=model,
        return_transformed=False,
        default_strategy="moment",
    )
    moment = initial_point_fn(0)["x"]
    target = np.asarray(expected)

    try:
        draw_shape = model["x"].eval().shape
    except NotImplementedError:
        draw_shape = moment.shape

    assert moment.shape == target.shape == draw_shape
    assert np.allclose(moment, target)
41+
42+
43+
@pytest.mark.parametrize(
    ("size", "expected"),
    [
        (None, 0),
        (5, np.zeros(5)),
        ((2, 5), np.zeros((2, 5))),
    ],
)
def test_flat_moment(size, expected):
    """The Flat moment is zero, shaped by ``size``."""
    with Model() as flat_model:
        Flat("x", size=size)
    assert_moment_is_expected(flat_model, expected)
55+
56+
57+
@pytest.mark.parametrize(
    ("size", "expected"),
    [
        (None, 1),
        (5, np.ones(5)),
        ((2, 5), np.ones((2, 5))),
    ],
)
def test_halfflat_moment(size, expected):
    """The HalfFlat moment is one, shaped by ``size``."""
    with Model() as halfflat_model:
        HalfFlat("x", size=size)
    assert_moment_is_expected(halfflat_model, expected)
69+
70+
71+
@pytest.mark.parametrize(
    ("lower", "upper", "size", "expected"),
    [
        (-1, 1, None, 0),
        (-1, 1, 5, np.zeros(5)),
        (0, np.arange(1, 6), None, np.arange(1, 6) / 2),
        (0, np.arange(1, 6), (2, 5), np.full((2, 5), np.arange(1, 6) / 2)),
    ],
)
def test_uniform_moment(lower, upper, size, expected):
    """The Uniform moment is the midpoint of the bounds, shaped by ``size``."""
    with Model() as uniform_model:
        Uniform("x", lower=lower, upper=upper, size=size)
    assert_moment_is_expected(uniform_model, expected)
84+
85+
86+
@pytest.mark.parametrize(
    ("mu", "sigma", "size", "expected"),
    [
        (0, 1, None, 0),
        (0, np.ones(5), None, np.zeros(5)),
        (np.arange(5), 1, None, np.arange(5)),
        (np.arange(5), np.arange(1, 6), (2, 5), np.full((2, 5), np.arange(5))),
    ],
)
def test_normal_moment(mu, sigma, size, expected):
    """The Normal moment is ``mu``, broadcast with ``sigma`` and shaped by ``size``."""
    with Model() as normal_model:
        Normal("x", mu=mu, sigma=sigma, size=size)
    assert_moment_is_expected(normal_model, expected)
99+
100+
101+
@pytest.mark.parametrize(
    ("sigma", "size", "expected"),
    [
        (1, None, 1),
        (1, 5, np.ones(5)),
        (np.arange(5), None, np.arange(5)),
        (np.arange(5), (2, 5), np.full((2, 5), np.arange(5))),
    ],
)
def test_halfnormal_moment(sigma, size, expected):
    """Check the HalfNormal moment against ``expected`` for each ``sigma``/``size`` pair."""
    with Model() as halfnormal_model:
        HalfNormal("x", sigma=sigma, size=size)
    assert_moment_is_expected(halfnormal_model, expected)
114+
115+
116+
@pytest.mark.skip(reason="aeppl interval transform fails when both edges are None")
@pytest.mark.parametrize(
    "mu, sigma, lower, upper, size, expected",
    [
        # Both bounds finite: midpoint of the bounds.
        (0.9, 1, -1, 1, None, 0),
        # Both bounds infinite: mu.
        (0.9, 1, -np.inf, np.inf, 5, np.full(5, 0.9)),
        # Only the upper bound is finite: upper - 1.
        (np.arange(5), 1, None, 10, (2, 5), np.full((2, 5), 9)),
        # Only the lower bound is finite: lower + 1. With size=None the shape
        # comes from broadcasting the parameters, and sigma=np.ones(5) makes
        # that (5,), so the expected value must be (5,) as well, not (2, 5).
        (1, np.ones(5), -10, np.inf, None, np.full(5, -9)),
    ],
)
def test_truncatednormal_moment(mu, sigma, lower, upper, size, expected):
    """Check the TruncatedNormal moment for each combination of finite/infinite bounds."""
    with Model() as model:
        TruncatedNormal("x", mu=mu, sigma=sigma, lower=lower, upper=upper, size=size)
    assert_moment_is_expected(model, expected)
130+
131+
132+
@pytest.mark.parametrize(
    ("p", "size", "expected"),
    [
        (0.3, None, 0),
        (0.9, 5, np.ones(5)),
        (np.linspace(0, 1, 4), None, [0, 0, 1, 1]),
        (np.linspace(0, 1, 4), (2, 4), np.full((2, 4), [0, 0, 1, 1])),
    ],
)
def test_bernoulli_moment(p, size, expected):
    """The Bernoulli moment is 0 where ``p < 0.5`` and 1 elsewhere, shaped by ``size``."""
    with Model() as bernoulli_model:
        Bernoulli("x", p=p, size=size)
    assert_moment_is_expected(bernoulli_model, expected)

0 commit comments

Comments
 (0)