Skip to content

Commit acb3196

Browse files
juan.lopez.arriazaricardoV94
juan.lopez.arriaza
authored andcommitted
Adding PolyaGamma moment and corresponding tests
1 parent 90860e6 commit acb3196

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

pymc/distributions/continuous.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from aesara.graph.op import Op
3131
from aesara.tensor import gammaln
3232
from aesara.tensor.extra_ops import broadcast_shape
33+
from aesara.tensor.math import tanh
3334
from aesara.tensor.random.basic import (
3435
BetaRV,
3536
WeibullRV,
@@ -3985,6 +3986,12 @@ def dist(cls, h=1.0, z=0.0, **kwargs):
39853986

39863987
return super().dist([h, z], **kwargs)
39873988

3989+
def get_moment(rv, size, h, z):
3990+
mean = at.switch(at.eq(z, 0), h / 4, tanh(z / 2) * (h / (2 * z)))
3991+
if not rv_size_is_none(size):
3992+
mean = at.full(size, mean)
3993+
return mean
3994+
39883995
def logp(value, h, z):
39893996
"""
39903997
Calculate log-probability of Polya-Gamma distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
Normal,
4242
Pareto,
4343
Poisson,
44+
PolyaGamma,
4445
Rice,
4546
SkewNormal,
4647
StudentT,
@@ -984,3 +985,54 @@ def _random(mu, rng=None, size=None):
984985
match="Cannot safely infer the size of a multivariate random variable's moment.",
985986
):
986987
evaled_moment = get_moment(a).eval({mu: mu_val})
988+
989+
990+
@pytest.mark.parametrize(
991+
"h, z, size, expected",
992+
[
993+
(1.0, 0.0, None, 0.25),
994+
(
995+
1.0,
996+
np.arange(5),
997+
None,
998+
(
999+
0.25,
1000+
0.23105857863000487,
1001+
0.1903985389889412,
1002+
0.1508580422741444,
1003+
0.12050344750947711,
1004+
),
1005+
),
1006+
(
1007+
np.arange(1, 6),
1008+
np.arange(5),
1009+
None,
1010+
(
1011+
0.25,
1012+
0.46211715726000974,
1013+
0.5711956169668236,
1014+
0.6034321690965776,
1015+
0.6025172375473855,
1016+
),
1017+
),
1018+
(
1019+
np.arange(1, 6),
1020+
np.arange(5),
1021+
(2, 5),
1022+
np.full(
1023+
(2, 5),
1024+
(
1025+
0.25,
1026+
0.46211715726000974,
1027+
0.5711956169668236,
1028+
0.6034321690965776,
1029+
0.6025172375473855,
1030+
),
1031+
),
1032+
),
1033+
],
1034+
)
1035+
def test_polyagamma_moment(h, z, size, expected):
1036+
with Model() as model:
1037+
PolyaGamma("x", h=h, z=z, size=size)
1038+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)