Skip to content

Commit c22859d

Browse files
Adds moments for inverse gamma distribution (#5199)
Adds moments for inverse gamma distribution, returns the mode when the mean is undefined Co-authored-by: Ricardo Vieira <[email protected]>
1 parent 90423a9 commit c22859d

File tree

2 files changed

+24
-12
lines changed

2 files changed

+24
-12
lines changed

pymc/distributions/continuous.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2334,23 +2334,19 @@ def dist(cls, alpha=None, beta=None, mu=None, sigma=None, sd=None, *args, **kwar
23342334
alpha = at.as_tensor_variable(floatX(alpha))
23352335
beta = at.as_tensor_variable(floatX(beta))
23362336

2337-
# m = beta / (alpha - 1.0)
2338-
# try:
2339-
# mean = (alpha > 1) * m or np.inf
2340-
# except ValueError: # alpha is an array
2341-
# m[alpha <= 1] = np.inf
2342-
# mean = m
2343-
2344-
# mode = beta / (alpha + 1.0)
2345-
# variance = at.switch(
2346-
# at.gt(alpha, 2), (beta ** 2) / ((alpha - 2) * (alpha - 1.0) ** 2), np.inf
2347-
# )
2348-
23492337
assert_negative_support(alpha, "alpha", "InverseGamma")
23502338
assert_negative_support(beta, "beta", "InverseGamma")
23512339

23522340
return super().dist([alpha, beta], **kwargs)
23532341

2342+
def get_moment(rv, size, alpha, beta):
2343+
mean = beta / (alpha - 1.0)
2344+
mode = beta / (alpha + 1.0)
2345+
moment = at.switch(alpha > 1, mean, mode)
2346+
if not rv_size_is_none(size):
2347+
moment = at.full(size, moment)
2348+
return moment
2349+
23542350
@classmethod
23552351
def _get_alpha_beta(cls, alpha, beta, mu, sigma):
23562352
if alpha is not None:

pymc/tests/test_distributions_moments.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
HalfNormal,
3232
HalfStudentT,
3333
HyperGeometric,
34+
InverseGamma,
3435
Kumaraswamy,
3536
Laplace,
3637
Logistic,
@@ -396,6 +397,21 @@ def test_gamma_moment(alpha, beta, size, expected):
396397
assert_moment_is_expected(model, expected)
397398

398399

400+
@pytest.mark.parametrize(
401+
"alpha, beta, size, expected",
402+
[
403+
(5, 1, None, 1 / 4),
404+
(0.5, 1, None, 1 / 1.5),
405+
(5, 1, 5, np.full(5, 1 / (5 - 1))),
406+
(np.arange(1, 6), 1, None, np.array([0.5, 1, 1 / 2, 1 / 3, 1 / 4])),
407+
],
408+
)
409+
def test_inverse_gamma_moment(alpha, beta, size, expected):
410+
with Model() as model:
411+
InverseGamma("x", alpha=alpha, beta=beta, size=size)
412+
assert_moment_is_expected(model, expected)
413+
414+
399415
@pytest.mark.parametrize(
400416
"alpha, m, size, expected",
401417
[

0 commit comments

Comments
 (0)