Skip to content

Commit 5107dae

Browse files
kc611ricardoV94
authored andcommitted
Fix Interpolated Distribution interval transform initialization
This also fixes a previous failure on float32
1 parent 99ec0ff commit 5107dae

File tree

3 files changed

+29
-2
lines changed

3 files changed

+29
-2
lines changed

pymc/distributions/continuous.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3650,6 +3650,17 @@ class Interpolated(BoundedContinuous):
36503650

36513651
rv_op = interpolated
36523652

3653+
def __new__(cls, *args, **kwargs):
3654+
transform = kwargs.get("transform", UNSET)
3655+
if transform is UNSET:
3656+
3657+
def transform_params(*params):
3658+
_, _, _, x_points, _, _ = params
3659+
return floatX(x_points[0]), floatX(x_points[-1])
3660+
3661+
kwargs["transform"] = transforms.interval(transform_params)
3662+
return super().__new__(cls, *args, **kwargs)
3663+
36533664
@classmethod
36543665
def dist(cls, x_points, pdf_points, *args, **kwargs):
36553666

pymc/distributions/dist_math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def grad_op(self):
341341

342342
def perform(self, node, inputs, output_storage):
343343
(x,) = inputs
344-
output_storage[0][0] = np.asarray(self.spline(x))
344+
output_storage[0][0] = np.asarray(self.spline(x), dtype=x.dtype)
345345

346346
def grad(self, inputs, grads):
347347
(x,) = inputs

pymc/tests/test_distributions.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import numpy as np
2121
import numpy.random as nr
2222

23+
from pymc.util import UNSET
24+
2325
try:
2426
from polyagamma import polyagamma_cdf, polyagamma_pdf
2527

@@ -2667,7 +2669,6 @@ def test_moyal_logcdf(self):
26672669
if aesara.config.floatX == "float32":
26682670
raise Exception("Flaky test: It passed this time, but XPASS is not allowed.")
26692671

2670-
@pytest.mark.skipif(condition=(aesara.config.floatX == "float32"), reason="Fails on float32")
26712672
def test_interpolated(self):
26722673
for mu in R.vals:
26732674
for sigma in Rplus.vals:
@@ -2695,6 +2696,21 @@ def ref_pdf(value):
26952696

26962697
self.check_logp(TestedInterpolated, R, {}, ref_pdf)
26972698

2699+
@pytest.mark.parametrize("transform", [UNSET, None])
2700+
def test_interpolated_transform(self, transform):
2701+
# Issue: https://github.com/pymc-devs/pymc/issues/5048
2702+
x_points = np.linspace(0, 10, 10)
2703+
pdf_points = sp.norm.pdf(x_points, loc=1, scale=1)
2704+
with pm.Model() as m:
2705+
x = pm.Interpolated("x", x_points, pdf_points, transform=transform)
2706+
2707+
if transform is UNSET:
2708+
assert np.isfinite(m.logp({"x_interval__": -1.0}))
2709+
assert np.isfinite(m.logp({"x_interval__": 11.0}))
2710+
else:
2711+
assert not np.isfinite(m.logp({"x": -1.0}))
2712+
assert not np.isfinite(m.logp({"x": 11.0}))
2713+
26982714

26992715
class TestBound:
27002716
"""Tests for pm.Bound distribution"""

0 commit comments

Comments
 (0)