Skip to content

Fixed Interpolated Distribution's interval transform initialization #5067

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -3650,6 +3650,17 @@ class Interpolated(BoundedContinuous):

rv_op = interpolated

def __new__(cls, *args, **kwargs):
transform = kwargs.get("transform", UNSET)
if transform is UNSET:

def transform_params(*params):
_, _, _, x_points, _, _ = params
return floatX(x_points[0]), floatX(x_points[-1])

kwargs["transform"] = transforms.interval(transform_params)
return super().__new__(cls, *args, **kwargs)

@classmethod
def dist(cls, x_points, pdf_points, *args, **kwargs):

Expand Down
2 changes: 1 addition & 1 deletion pymc/distributions/dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def grad_op(self):

def perform(self, node, inputs, output_storage):
(x,) = inputs
output_storage[0][0] = np.asarray(self.spline(x))
output_storage[0][0] = np.asarray(self.spline(x), dtype=x.dtype)

def grad(self, inputs, grads):
(x,) = inputs
Expand Down
18 changes: 17 additions & 1 deletion pymc/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import numpy as np
import numpy.random as nr

from pymc.util import UNSET

try:
from polyagamma import polyagamma_cdf, polyagamma_pdf

Expand Down Expand Up @@ -2667,7 +2669,6 @@ def test_moyal_logcdf(self):
if aesara.config.floatX == "float32":
raise Exception("Flaky test: It passed this time, but XPASS is not allowed.")

@pytest.mark.skipif(condition=(aesara.config.floatX == "float32"), reason="Fails on float32")
def test_interpolated(self):
for mu in R.vals:
for sigma in Rplus.vals:
Expand Down Expand Up @@ -2695,6 +2696,21 @@ def ref_pdf(value):

self.check_logp(TestedInterpolated, R, {}, ref_pdf)

@pytest.mark.parametrize("transform", [UNSET, None])
def test_interpolated_transform(self, transform):
# Issue: https://github.com/pymc-devs/pymc/issues/5048
x_points = np.linspace(0, 10, 10)
pdf_points = sp.norm.pdf(x_points, loc=1, scale=1)
with pm.Model() as m:
x = pm.Interpolated("x", x_points, pdf_points, transform=transform)

if transform is UNSET:
assert np.isfinite(m.logp({"x_interval__": -1.0}))
assert np.isfinite(m.logp({"x_interval__": 11.0}))
else:
assert not np.isfinite(m.logp({"x": -1.0}))
assert not np.isfinite(m.logp({"x": 11.0}))


class TestBound:
"""Tests for pm.Bound distribution"""
Expand Down