Skip to content

Commit 6079ffe

Browse files
author
juan.lopez.arriaza
committed
Adding moments for AsymmetricLaplace and SkewNormal and corresponding tests
1 parent 275c145 commit 6079ffe

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

pymc/distributions/continuous.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1599,6 +1599,13 @@ def dist(cls, b, kappa, mu=0, *args, **kwargs):
15991599

16001600
return super().dist([b, kappa, mu], *args, **kwargs)
16011601

1602+
def get_moment(rv, size, b, kappa, mu):
1603+
mean = mu - (kappa - 1 / kappa) / b
1604+
1605+
if not rv_size_is_none(size):
1606+
mean = at.full(size, mean)
1607+
return mean
1608+
16021609
def logp(value, b, kappa, mu):
16031610
"""
16041611
Calculate log-probability of Asymmetric-Laplace distribution at specified value.
@@ -3012,6 +3019,12 @@ def dist(cls, alpha=1, mu=0.0, sigma=None, tau=None, sd=None, *args, **kwargs):
30123019

30133020
return super().dist([mu, sigma, alpha], *args, **kwargs)
30143021

3022+
def get_moment(rv, size, mu, sigma, alpha):
3023+
mean = mu + sigma * (2 / np.pi) ** 0.5 * alpha / (1 + alpha ** 2) ** 0.5
3024+
if not rv_size_is_none(size):
3025+
mean = at.full(size, mean)
3026+
return mean
3027+
30153028
def logp(value, mu, sigma, alpha):
30163029
"""
30173030
Calculate log-probability of SkewNormal distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from scipy import special
55

66
from pymc.distributions import (
7+
AsymmetricLaplace,
78
Bernoulli,
89
Beta,
910
Binomial,
@@ -29,6 +30,7 @@
2930
Normal,
3031
Pareto,
3132
Poisson,
33+
SkewNormal,
3234
StudentT,
3335
TruncatedNormal,
3436
Uniform,
@@ -612,3 +614,59 @@ def test_discrete_uniform_moment(lower, upper, size, expected):
612614
with Model() as model:
613615
DiscreteUniform("x", lower=lower, upper=upper, size=size)
614616
assert_moment_is_expected(model, expected)
617+
618+
619+
@pytest.mark.parametrize(
620+
"alpha, mu, sigma, size, expected",
621+
[
622+
(1.0, 1.0, 1.0, None, 1.56418958),
623+
(1, np.ones(5), 1, None, np.full(5, 1.56418958)),
624+
(np.ones(5), 1, np.ones(5), None, np.full(5, 1.56418958)),
625+
(
626+
np.arange(5),
627+
np.arange(1, 6),
628+
np.arange(1, 6),
629+
None,
630+
(1.0, 3.12837917, 5.14094894, 7.02775903, 8.87030861),
631+
),
632+
(
633+
np.arange(5),
634+
np.arange(1, 6),
635+
np.arange(1, 6),
636+
(2, 5),
637+
np.full((2, 5), (1.0, 3.12837917, 5.14094894, 7.02775903, 8.87030861)),
638+
),
639+
],
640+
)
641+
def test_skewnormal_moment(alpha, mu, sigma, size, expected):
642+
with Model() as model:
643+
SkewNormal("x", alpha=alpha, mu=mu, sigma=sigma, size=size)
644+
assert_moment_is_expected(model, expected)
645+
646+
647+
@pytest.mark.parametrize(
648+
"b, kappa, mu, size, expected",
649+
[
650+
(1.0, 1.0, 1.0, None, 1.0),
651+
(1.0, np.ones(5), 1.0, None, np.full(5, 1.0)),
652+
(np.arange(1, 6), 1.0, np.ones(5), None, np.full(5, 1.0)),
653+
(
654+
np.arange(1, 6),
655+
np.arange(1, 6),
656+
np.arange(1, 6),
657+
None,
658+
(1.0, 1.25, 2.111111111111111, 3.0625, 4.04),
659+
),
660+
(
661+
np.arange(1, 6),
662+
np.arange(1, 6),
663+
np.arange(1, 6),
664+
(2, 5),
665+
np.full((2, 5), (1.0, 1.25, 2.111111111111111, 3.0625, 4.04)),
666+
),
667+
],
668+
)
669+
def test_asymmetriclaplace_moment(b, kappa, mu, size, expected):
670+
with Model() as model:
671+
AsymmetricLaplace("x", b=b, kappa=kappa, mu=mu, size=size)
672+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)