Skip to content

Commit a099292

Browse files
Add back some missing moments (#5147)
* Exponential * Laplace * StudentT * Cauchy * Kumaraswamy
1 parent e11969a commit a099292

File tree

2 files changed

+114
-4
lines changed

2 files changed

+114
-4
lines changed

pymc/distributions/continuous.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,6 +1313,12 @@ def dist(cls, a, b, *args, **kwargs):
13131313

13141314
return super().dist([a, b], *args, **kwargs)
13151315

1316+
def get_moment(rv, size, a, b):
1317+
mean = at.exp(at.log(b) + at.gammaln(1 + 1 / a) + at.gammaln(b) - at.gammaln(1 + 1 / a + b))
1318+
if not rv_size_is_none(size):
1319+
mean = at.full(size, mean)
1320+
return mean
1321+
13161322
def logp(value, a, b):
13171323
"""
13181324
Calculate log-probability of Kumaraswamy distribution at specified value.
@@ -1399,6 +1405,11 @@ def dist(cls, lam, *args, **kwargs):
13991405
# Aesara exponential op is parametrized in terms of mu (1/lam)
14001406
return super().dist([at.inv(lam)], **kwargs)
14011407

1408+
def get_moment(rv, size, mu):
1409+
if not rv_size_is_none(size):
1410+
mu = at.full(size, mu)
1411+
return mu
1412+
14021413
def logcdf(value, mu):
14031414
r"""
14041415
Compute the log of cumulative distribution function for the Exponential distribution
@@ -1475,6 +1486,12 @@ def dist(cls, mu, b, *args, **kwargs):
14751486
assert_negative_support(b, "b", "Laplace")
14761487
return super().dist([mu, b], *args, **kwargs)
14771488

1489+
def get_moment(rv, size, mu, b):
1490+
mu, _ = at.broadcast_arrays(mu, b)
1491+
if not rv_size_is_none(size):
1492+
mu = at.full(size, mu)
1493+
return mu
1494+
14781495
def logcdf(value, mu, b):
14791496
"""
14801497
Compute the log of the cumulative distribution function for Laplace distribution
@@ -1800,6 +1817,12 @@ def dist(cls, nu, mu=0, lam=None, sigma=None, sd=None, *args, **kwargs):
18001817

18011818
return super().dist([nu, mu, sigma], **kwargs)
18021819

1820+
def get_moment(rv, size, nu, mu, sigma):
1821+
mu, _, _ = at.broadcast_arrays(mu, nu, sigma)
1822+
if not rv_size_is_none(size):
1823+
mu = at.full(size, mu)
1824+
return mu
1825+
18031826
def logp(value, nu, mu, sigma):
18041827
"""
18051828
Calculate log-probability of StudentT distribution at specified value.
@@ -2001,12 +2024,15 @@ def dist(cls, alpha, beta, *args, **kwargs):
20012024
alpha = at.as_tensor_variable(floatX(alpha))
20022025
beta = at.as_tensor_variable(floatX(beta))
20032026

2004-
# median = alpha
2005-
# mode = alpha
2006-
20072027
assert_negative_support(beta, "beta", "Cauchy")
20082028
return super().dist([alpha, beta], **kwargs)
20092029

2030+
def get_moment(rv, size, alpha, beta):
2031+
alpha, _ = at.broadcast_arrays(alpha, beta)
2032+
if not rv_size_is_none(size):
2033+
alpha = at.full(size, alpha)
2034+
return alpha
2035+
20102036
def logcdf(value, alpha, beta):
20112037
"""
20122038
Compute the log of the cumulative distribution function for Cauchy distribution

pymc/tests/test_distributions_moments.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,15 @@
22
import pytest
33

44
from pymc import Bernoulli, Flat, HalfFlat, Normal, TruncatedNormal, Uniform
5-
from pymc.distributions import Beta, HalfNormal
5+
from pymc.distributions import (
6+
Beta,
7+
Cauchy,
8+
Exponential,
9+
HalfNormal,
10+
Kumaraswamy,
11+
Laplace,
12+
StudentT,
13+
)
614
from pymc.distributions.shape_utils import rv_size_is_none
715
from pymc.initial_point import make_initial_point_fn
816
from pymc.model import Model
@@ -157,3 +165,79 @@ def test_beta_moment(alpha, beta, size, expected):
157165
with Model() as model:
158166
Beta("x", alpha=alpha, beta=beta, size=size)
159167
assert_moment_is_expected(model, expected)
168+
169+
170+
@pytest.mark.parametrize(
171+
"lam, size, expected",
172+
[
173+
(2, None, 0.5),
174+
(2, 5, np.full(5, 0.5)),
175+
(np.arange(1, 5), None, 1 / np.arange(1, 5)),
176+
(np.arange(1, 5), (2, 4), np.full((2, 4), 1 / np.arange(1, 5))),
177+
],
178+
)
179+
def test_exponential_moment(lam, size, expected):
180+
with Model() as model:
181+
Exponential("x", lam=lam, size=size)
182+
assert_moment_is_expected(model, expected)
183+
184+
185+
@pytest.mark.parametrize(
186+
"mu, b, size, expected",
187+
[
188+
(0, 1, None, 0),
189+
(0, np.ones(5), None, np.zeros(5)),
190+
(np.arange(5), 1, None, np.arange(5)),
191+
(np.arange(5), np.arange(1, 6), (2, 5), np.full((2, 5), np.arange(5))),
192+
],
193+
)
194+
def test_laplace_moment(mu, b, size, expected):
195+
with Model() as model:
196+
Laplace("x", mu=mu, b=b, size=size)
197+
assert_moment_is_expected(model, expected)
198+
199+
200+
@pytest.mark.parametrize(
201+
"mu, nu, sigma, size, expected",
202+
[
203+
(0, 1, 1, None, 0),
204+
(0, np.ones(5), 1, None, np.zeros(5)),
205+
(np.arange(5), 10, np.arange(1, 6), None, np.arange(5)),
206+
(np.arange(5), 10, np.arange(1, 6), (2, 5), np.full((2, 5), np.arange(5))),
207+
],
208+
)
209+
def test_studentt_moment(mu, nu, sigma, size, expected):
210+
with Model() as model:
211+
StudentT("x", mu=mu, nu=nu, sigma=sigma, size=size)
212+
assert_moment_is_expected(model, expected)
213+
214+
215+
@pytest.mark.parametrize(
216+
"alpha, beta, size, expected",
217+
[
218+
(0, 1, None, 0),
219+
(0, np.ones(5), None, np.zeros(5)),
220+
(np.arange(5), 1, None, np.arange(5)),
221+
(np.arange(5), np.arange(1, 6), (2, 5), np.full((2, 5), np.arange(5))),
222+
],
223+
)
224+
def test_cauchy_moment(alpha, beta, size, expected):
225+
with Model() as model:
226+
Cauchy("x", alpha=alpha, beta=beta, size=size)
227+
assert_moment_is_expected(model, expected)
228+
229+
230+
@pytest.mark.parametrize(
231+
"a, b, size, expected",
232+
[
233+
(1, 1, None, 0.5),
234+
(1, 1, 5, np.full(5, 0.5)),
235+
(1, np.arange(1, 6), None, 1 / np.arange(2, 7)),
236+
(np.arange(1, 6), 1, None, np.arange(1, 6) / np.arange(2, 7)),
237+
(1, np.arange(1, 6), (2, 5), np.full((2, 5), 1 / np.arange(2, 7))),
238+
],
239+
)
240+
def test_kumaraswamy_moment(a, b, size, expected):
241+
with Model() as model:
242+
Kumaraswamy("x", a=a, b=b, size=size)
243+
assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)