From b040aa4db49cbc437cc6a2020de7afaab8ce83ee Mon Sep 17 00:00:00 2001 From: kc611 Date: Sun, 10 Oct 2021 10:09:28 +0530 Subject: [PATCH] Fix Interpolated Distribution interval transform initialization This also fixes a previous failure on float32 --- pymc/distributions/continuous.py | 11 +++++++++++ pymc/distributions/dist_math.py | 2 +- pymc/tests/test_distributions.py | 18 +++++++++++++++++- 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 8372df9ea9..97ad4035e1 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -3650,6 +3650,17 @@ class Interpolated(BoundedContinuous): rv_op = interpolated + def __new__(cls, *args, **kwargs): + transform = kwargs.get("transform", UNSET) + if transform is UNSET: + + def transform_params(*params): + _, _, _, x_points, _, _ = params + return floatX(x_points[0]), floatX(x_points[-1]) + + kwargs["transform"] = transforms.interval(transform_params) + return super().__new__(cls, *args, **kwargs) + @classmethod def dist(cls, x_points, pdf_points, *args, **kwargs): diff --git a/pymc/distributions/dist_math.py b/pymc/distributions/dist_math.py index 6a76b49427..9b4a2bb57b 100644 --- a/pymc/distributions/dist_math.py +++ b/pymc/distributions/dist_math.py @@ -341,7 +341,7 @@ def grad_op(self): def perform(self, node, inputs, output_storage): (x,) = inputs - output_storage[0][0] = np.asarray(self.spline(x)) + output_storage[0][0] = np.asarray(self.spline(x), dtype=x.dtype) def grad(self, inputs, grads): (x,) = inputs diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index ff18af7394..9cd1ef5424 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -20,6 +20,8 @@ import numpy as np import numpy.random as nr +from pymc.util import UNSET + try: from polyagamma import polyagamma_cdf, polyagamma_pdf @@ -2667,7 +2669,6 @@ def test_moyal_logcdf(self): if aesara.config.floatX == "float32": raise Exception("Flaky test: It passed this time, but XPASS is not allowed.") - @pytest.mark.skipif(condition=(aesara.config.floatX == "float32"), reason="Fails on float32") def test_interpolated(self): for mu in R.vals: for sigma in Rplus.vals: @@ -2695,6 +2696,21 @@ def ref_pdf(value): self.check_logp(TestedInterpolated, R, {}, ref_pdf) + @pytest.mark.parametrize("transform", [UNSET, None]) + def test_interpolated_transform(self, transform): + # Issue: https://github.com/pymc-devs/pymc/issues/5048 + x_points = np.linspace(0, 10, 10) + pdf_points = sp.norm.pdf(x_points, loc=1, scale=1) + with pm.Model() as m: + x = pm.Interpolated("x", x_points, pdf_points, transform=transform) + + if transform is UNSET: + assert np.isfinite(m.logp({"x_interval__": -1.0})) + assert np.isfinite(m.logp({"x_interval__": 11.0})) + else: + assert not np.isfinite(m.logp({"x": -1.0})) + assert not np.isfinite(m.logp({"x": 11.0})) + class TestBound: """Tests for pm.Bound distribution"""