Skip to content

Commit 45fd9c8

Browse files
Adding Interpolated moment (#5222)
* Added Interpolated moments and tests Co-authored-by: Ricardo Vieira <[email protected]>
1 parent f32c039 commit 45fd9c8

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

pymc/distributions/continuous.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3686,6 +3686,15 @@ def dist(cls, x_points, pdf_points, *args, **kwargs):
36863686

36873687
return super().dist([x_points, pdf_points, cdf_points], **kwargs)
36883688

3689+
def get_moment(rv, size, x_points, pdf_points, cdf_points):
3690+
# cdf_points argument is unused
3691+
moment = at.sum(at.mul(x_points, pdf_points))
3692+
3693+
if not rv_size_is_none(size):
3694+
moment = at.full(size, moment)
3695+
3696+
return moment
3697+
36893698
def logp(value, x_points, pdf_points, cdf_points):
36903699
"""
36913700
Calculate log-probability of Interpolated distribution at specified value.

pymc/tests/test_distributions_moments.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
HalfNormal,
3434
HalfStudentT,
3535
HyperGeometric,
36+
Interpolated,
3637
InverseGamma,
3738
Kumaraswamy,
3839
Laplace,
@@ -800,6 +801,36 @@ def test_categorical_moment(p, size, expected):
800801
assert_moment_is_expected(model, expected)
801802

802803

804+
@pytest.mark.parametrize(
805+
"x_points, pdf_points, size, expected",
806+
[
807+
(np.array([-1, 1]), np.array([0.4, 0.6]), None, 0.2),
808+
(
809+
np.array([-4, -1, 3, 9, 19]),
810+
np.array([0.1, 0.15, 0.2, 0.25, 0.3]),
811+
None,
812+
1.5458937198067635,
813+
),
814+
(
815+
np.array([-22, -4, 0, 8, 13]),
816+
np.tile(1 / 5, 5),
817+
(5, 3),
818+
np.full((5, 3), -0.14285714285714296),
819+
),
820+
(
821+
np.arange(-100, 10),
822+
np.arange(1, 111) / 6105,
823+
(2, 5, 3),
824+
np.full((2, 5, 3), -27.584097859327223),
825+
),
826+
],
827+
)
828+
def test_interpolated_moment(x_points, pdf_points, size, expected):
829+
with Model() as model:
830+
Interpolated("x", x_points=x_points, pdf_points=pdf_points, size=size)
831+
assert_moment_is_expected(model, expected)
832+
833+
803834
@pytest.mark.parametrize(
804835
"mu, cov, size, expected",
805836
[

0 commit comments

Comments
 (0)