Skip to content

Commit 64de50e

Browse files
Fixed tests - should work?
1 parent 7b10f88 commit 64de50e

File tree

3 files changed

+15
-8
lines changed

3 files changed

+15
-8
lines changed

pymc/distributions/continuous.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3689,7 +3689,13 @@ def dist(cls, x_points, pdf_points, *args, **kwargs):
36893689
def get_moment(rv, size, x_points, pdf_points, cdf_points):
36903690
# cdf_points argument is unused
36913691
# moment = at.as_tensor(0.)
3692+
print(x_points)
3693+
print("")
3694+
for pdf in pdf_points.eval():
3695+
print(pdf)
3696+
print("")
36923697
moment = at.sum(at.mul(x_points, pdf_points))
3698+
print(moment.eval())
36933699
if not rv_size_is_none(size):
36943700
moment = at.full(size, moment)
36953701

pymc/initial_point.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ def make_initial_point_expression(
272272
if isinstance(strategy, str):
273273
if strategy == "moment":
274274
value = get_moment(variable)
275+
print(value.eval())
275276
elif strategy == "prior":
276277
value = variable
277278
else:

pymc/tests/test_distributions_moments.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -805,14 +805,14 @@ def test_categorical_moment(p, size, expected):
805805
"x_points, pdf_points, size, expected",
806806
[
807807
(np.array([-1, 1]), np.array([0.4, 0.6]), None, 0.2),
808-
(np.array([-4, -1, 3, 9, 19]), np.array([0.1 , 0.15, 0.2 , 0.25, 0.3 ]), None, 8),
809-
# (np.array([-22, -4, 0, 8, 13]), np.tile(1 / 5, 5), (5, 3), -np.ones((5, 3))),
810-
# (
811-
# np.arange(-100, 10),
812-
# np.arange(1, 111) / 6105,
813-
# (2, 5, 3),
814-
# np.broadcast_to(-82 / 3, (2, 5, 3)),
815-
# ),
808+
(np.array([-4, -1, 3, 9, 19]), np.array([0.1 , 0.15, 0.2 , 0.25, 0.3]), None, 1.5458937198067635),
809+
(np.array([-22, -4, 0, 8, 13]), np.tile(1 / 5, 5), (5, 3), np.full((5, 3), -0.14285714285714296)),
810+
(
811+
np.arange(-100, 10),
812+
np.arange(1, 111) / 6105,
813+
(2, 5, 3),
814+
np.full((2, 5, 3), -27.584097859327223),
815+
),
816816
],
817817
)
818818
def test_interpolated_moment(x_points, pdf_points, size, expected):

0 commit comments

Comments
 (0)