Skip to content

Commit 073e26b

Browse files
authored
Added moments for gumbel, triangular and logitnormal distributions for issue #5078 (#5180)
1 parent 8d1708a commit 073e26b

File tree

2 files changed

+81
-1
lines changed

2 files changed

+81
-1
lines changed

pymc/distributions/continuous.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def polyagamma_cdf(*args, **kwargs):
8686
)
8787
from pymc.distributions.distribution import Continuous
8888
from pymc.distributions.shape_utils import rv_size_is_none
89-
from pymc.math import logdiffexp, logit
89+
from pymc.math import invlogit, logdiffexp, logit
9090
from pymc.util import UNSET
9191

9292
__all__ = [
@@ -3101,6 +3101,12 @@ def dist(cls, lower=0, upper=1, c=0.5, *args, **kwargs):
31013101

31023102
return super().dist([lower, c, upper], *args, **kwargs)
31033103

3104+
def get_moment(rv, size, lower, c, upper):
3105+
mean = (lower + upper + c) / 3
3106+
if not rv_size_is_none(size):
3107+
mean = at.full(size, mean)
3108+
return mean
3109+
31043110
def logcdf(value, lower, c, upper):
31053111
"""
31063112
Compute the log of the cumulative distribution function for Triangular distribution
@@ -3198,6 +3204,12 @@ def dist(
31983204

31993205
return super().dist([mu, beta], **kwargs)
32003206

3207+
def get_moment(rv, size, mu, beta):
3208+
mean = mu + beta * np.euler_gamma
3209+
if not rv_size_is_none(size):
3210+
mean = at.full(size, mean)
3211+
return mean
3212+
32013213
def _distr_parameters_for_repr(self):
32023214
return ["mu", "beta"]
32033215

@@ -3501,6 +3513,12 @@ def dist(cls, mu=0, sigma=None, tau=None, sd=None, **kwargs):
35013513

35023514
return super().dist([mu, sigma], **kwargs)
35033515

3516+
def get_moment(rv, size, mu, sigma):
3517+
median, _ = at.broadcast_arrays(invlogit(mu), sigma)
3518+
if not rv_size_is_none(size):
3519+
median = at.full(size, median)
3520+
return median
3521+
35043522
def logp(value, mu, sigma):
35053523
"""
35063524
Calculate log-probability of LogitNormal distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
Flat,
1818
Gamma,
1919
Geometric,
20+
Gumbel,
2021
HalfCauchy,
2122
HalfFlat,
2223
HalfNormal,
@@ -25,12 +26,14 @@
2526
Kumaraswamy,
2627
Laplace,
2728
Logistic,
29+
LogitNormal,
2830
LogNormal,
2931
NegativeBinomial,
3032
Normal,
3133
Pareto,
3234
Poisson,
3335
StudentT,
36+
Triangular,
3437
TruncatedNormal,
3538
Uniform,
3639
Wald,
@@ -650,3 +653,62 @@ def test_dirichlet_moment(a, size, expected):
650653
with Model() as model:
651654
Dirichlet("x", a=a, size=size)
652655
assert_moment_is_expected(model, expected)
656+
657+
658+
@pytest.mark.parametrize(
659+
"mu, beta, size, expected",
660+
[
661+
(0, 2, None, 2 * np.euler_gamma),
662+
(1, np.arange(1, 4), None, 1 + np.arange(1, 4) * np.euler_gamma),
663+
(np.arange(5), 2, None, np.arange(5) + 2 * np.euler_gamma),
664+
(1, 2, 5, np.full(5, 1 + 2 * np.euler_gamma)),
665+
(
666+
np.arange(5),
667+
np.arange(1, 6),
668+
(2, 5),
669+
np.full((2, 5), np.arange(5) + np.arange(1, 6) * np.euler_gamma),
670+
),
671+
],
672+
)
673+
def test_gumbel_moment(mu, beta, size, expected):
674+
with Model() as model:
675+
Gumbel("x", mu=mu, beta=beta, size=size)
676+
assert_moment_is_expected(model, expected)
677+
678+
679+
@pytest.mark.parametrize(
680+
"c, lower, upper, size, expected",
681+
[
682+
(1, 0, 5, None, 2),
683+
(3, np.arange(-3, 6, 3), np.arange(3, 12, 3), None, np.array([1, 3, 5])),
684+
(np.arange(-3, 6, 3), -3, 3, None, np.array([-1, 0, 1])),
685+
(3, -3, 6, 5, np.full(5, 2)),
686+
(
687+
np.arange(-3, 6, 3),
688+
np.arange(-9, -2, 3),
689+
np.arange(3, 10, 3),
690+
(2, 3),
691+
np.full((2, 3), np.array([-3, 0, 3])),
692+
),
693+
],
694+
)
695+
def test_triangular_moment(c, lower, upper, size, expected):
696+
with Model() as model:
697+
Triangular("x", c=c, lower=lower, upper=upper, size=size)
698+
assert_moment_is_expected(model, expected)
699+
700+
701+
@pytest.mark.parametrize(
702+
"mu, sigma, size, expected",
703+
[
704+
(1, 2, None, special.expit(1)),
705+
(0, np.arange(1, 5), None, special.expit(np.zeros(4))),
706+
(np.arange(4), 1, None, special.expit(np.arange(4))),
707+
(1, 5, 4, special.expit(np.ones(4))),
708+
(np.arange(4), np.arange(1, 5), (2, 4), np.full((2, 4), special.expit(np.arange(4)))),
709+
],
710+
)
711+
def test_logitnormal_moment(mu, sigma, size, expected):
712+
with Model() as model:
713+
LogitNormal("x", mu=mu, sigma=sigma, size=size)
714+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)