Skip to content

Commit 4da614d

Browse files
authored
Add ChiSquared moment (#5154)
1 parent 777622a commit 4da614d

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

pymc/distributions/continuous.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2427,6 +2427,12 @@ def dist(cls, nu, *args, **kwargs):
24272427
nu = at.as_tensor_variable(floatX(nu))
24282428
return super().dist([nu], *args, **kwargs)
24292429

2430+
def get_moment(rv, size, nu):
2431+
moment = nu
2432+
if not rv_size_is_none(size):
2433+
moment = at.full(size, moment)
2434+
return moment
2435+
24302436
def logcdf(value, nu):
24312437
"""
24322438
Compute the log of the cumulative distribution function for ChiSquared distribution

pymc/tests/test_distributions_moments.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pymc.distributions import (
88
Beta,
99
Cauchy,
10+
ChiSquared,
1011
Exponential,
1112
Gamma,
1213
HalfCauchy,
@@ -173,6 +174,20 @@ def test_beta_moment(alpha, beta, size, expected):
173174
assert_moment_is_expected(model, expected)
174175

175176

177+
@pytest.mark.parametrize(
178+
"nu, size, expected",
179+
[
180+
(1, None, 1),
181+
(1, 5, np.full(5, 1)),
182+
(np.arange(1, 6), None, np.arange(1, 6)),
183+
],
184+
)
185+
def test_chisquared_moment(nu, size, expected):
186+
with Model() as model:
187+
ChiSquared("x", nu=nu, size=size)
188+
assert_moment_is_expected(model, expected)
189+
190+
176191
@pytest.mark.parametrize(
177192
"lam, size, expected",
178193
[

0 commit comments

Comments
 (0)