Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -2427,6 +2427,12 @@ def dist(cls, nu, *args, **kwargs):
nu = at.as_tensor_variable(floatX(nu))
return super().dist([nu], *args, **kwargs)

def get_moment(rv, size, nu):
moment = nu
if not rv_size_is_none(size):
moment = at.full(size, moment)
return moment

def logcdf(value, nu):
"""
Compute the log of the cumulative distribution function for ChiSquared distribution
Expand Down
15 changes: 15 additions & 0 deletions pymc/tests/test_distributions_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pymc.distributions import (
Beta,
Cauchy,
ChiSquared,
Exponential,
Gamma,
HalfCauchy,
Expand Down Expand Up @@ -173,6 +174,20 @@ def test_beta_moment(alpha, beta, size, expected):
assert_moment_is_expected(model, expected)


@pytest.mark.parametrize(
"nu, size, expected",
[
(1, None, 1),
(1, 5, np.full(5, 1)),
(np.arange(1, 6), None, np.arange(1, 6)),
],
)
def test_chisquared_moment(nu, size, expected):
with Model() as model:
ChiSquared("x", nu=nu, size=size)
assert_moment_is_expected(model, expected)


@pytest.mark.parametrize(
"lam, size, expected",
[
Expand Down