
BUG: Unable to apply Interval forward transformation #7193

Closed
@tvwenger


Describe the issue:

Description

I wish to evaluate the log probability of a model at a specific point. Because some RVs may be transformed, I must apply each RV's forward transformation to the input point and pass the transformed point to model.logp. If one of the RVs uses the Interval transform, the call to forward() fails with an obscure IndexError (tuple index out of range).
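To make the failure concrete, here is a minimal pure-Python sketch of the two forward maps involved. It assumes the two-sided interval formula visible in the traceback below, log(value - a) - log(b - value), and the function names are hypothetical, not PyMC API. It illustrates the key difference: a log transform needs only the value, but an interval transform cannot be evaluated without the bounds a and b, which PyMC obtains from the extra inputs passed to forward().

```python
import math


def log_forward(value: float) -> float:
    """Forward map of a log transform: positive reals -> reals."""
    return math.log(value)


def interval_forward(value: float, lower: float, upper: float) -> float:
    """Forward map of a two-sided interval (log-odds) transform.

    Mirrors the formula in the traceback, log(value - a) - log(b - value),
    with a = lower and b = upper. Without the bounds nothing can be computed.
    """
    if not lower < value < upper:
        raise ValueError("value must lie strictly inside (lower, upper)")
    return math.log(value - lower) - math.log(upper - value)


# Bounds taken from the TruncatedNormal in the example (lower=-1, upper=1).
print(log_forward(0.5))                  # approximately -0.6931
print(interval_forward(0.0, -1.0, 1.0))  # 0.0
```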

Expected Behavior

Applying forward to an RV transformed via Interval should not raise an exception.

Actual Behavior

Applying forward to an RV transformed via Interval raises an IndexError.

Minimum working example

The following example demonstrates what I want to do. I have a model where x is not transformed, y has a LogTransform, and z has an Interval transform. I want to evaluate logp at a given x, y, z. Since logp requires the transformed point in order to evaluate the expression, I must apply the transformations and pass the transformed point to logp().eval().

The output of this MWE is below:

5.10.4
name: x param: x transform: None
value: 1.0 transformed_value: 1.0
name: y param: y_log__ transform: LogTransform
value: 0.5 transformed_value: -0.6931471824645996
name: z param: z_interval__ transform: Interval
<Exception>
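For reference, the expected transformed values can be worked out by hand, assuming the log and interval forward formulas shown in the traceback below (with a = -1 and b = 1 for z):

```latex
\begin{aligned}
y_{\log} &= \log(0.5) \approx -0.693147 \\
z_{\text{interval}} &= \log(z - a) - \log(b - z)
  = \log\bigl(0 - (-1)\bigr) - \log(1 - 0) = 0
\end{aligned}
```

The printed -0.6931471824645996 agrees with log(0.5) to about seven significant digits, which is consistent with the value having been computed in float32.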

Reproducible code example:

import pymc as pm

print(pm.__version__)

with pm.Model() as model:
    x = pm.Normal("x", mu=0.0, sigma=1.0)
    y = pm.LogNormal("y", mu=1.0, sigma=1.0)
    z = pm.TruncatedNormal("z", mu=0.0, sigma=1.0, lower=-1.0, upper=1.0)

point = {"x": 1.0, "y": 0.5, "z": 0.0}
point_transformed = {}
for rv in model.free_RVs:
    name = rv.name
    param = model.rvs_to_values[rv]
    transform = model.rvs_to_transforms[rv]
    print(f"name: {name} param: {param} transform: {transform}")
    if transform is None:
        point_transformed[param] = point[name]
    else:
        point_transformed[param] = transform.forward(point[name]).eval()
    print(f"value: {point[name]} transformed_value: {point_transformed[param]}")

print(f"Log prob: {model.logp().eval(point_transformed)}")

Error message:

IndexError                                Traceback (most recent call last)
Cell In[25], line 20
     18         point_transformed[param] = point[name]
     19     else:
---> 20         point_transformed[param] = transform.forward(point[name]).eval()
     21     print(f"value: {point[name]} transformed_value: {point_transformed[param]}")
     23 print(f"Log prob: {model.logp().eval(point_transformed)}")

File ~/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/logprob/transforms.py:853, in IntervalTransform.forward(self, value, *inputs)
    852 def forward(self, value, *inputs):
--> 853     a, b, lower_bounded, upper_bounded = self.get_a_and_b(inputs)
    855     log_lower_distance = pt.log(value - a)
    856     log_upper_distance = pt.log(b - value)

File ~/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/logprob/transforms.py:842, in IntervalTransform.get_a_and_b(self, inputs)
    836 def get_a_and_b(self, inputs):
    837     """Return interval bound values.
    838
    839     Also returns two boolean variables indicating whether the transform is known to be statically bounded.
    840     This is used to generate smaller graphs in the transform methods.
    841     """
--> 842     a, b = self.args_fn(*inputs)
    843     lower_bounded, upper_bounded = True, True
    844     if a is None:

File ~/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/distributions/continuous.py:180, in bounded_cont_transform.<locals>.transform_params(*args)
    178 lower, upper = None, None
    179 if bound_args_indices[0] is not None:
--> 180     lower = args[bound_args_indices[0]]
    181 if bound_args_indices[1] is not None:
    182     upper = args[bound_args_indices[1]]

IndexError: tuple index out of range
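The last frame suggests why the error is an IndexError rather than something more descriptive: the bounds function picks lower and upper out of its positional arguments by index, and forward() was called with only the value, so the tuple of extra inputs is empty. Below is a hypothetical, stdlib-only imitation of that lookup; make_bounds_fn, the input layout, and the indices 2 and 3 are assumptions for illustration, not PyMC's actual internals. Based on the forward(self, value, *inputs) signature in the traceback, one workaround worth trying is to pass the RV's own inputs after the value, e.g. transform.forward(point[name], *rv.owner.inputs); treat this as an assumption rather than a confirmed fix.

```python
import math
from typing import Callable, Optional


def make_bounds_fn(lower_index: Optional[int], upper_index: Optional[int]) -> Callable:
    """Hypothetical imitation of a bounds lookup by argument position."""

    def bounds_fn(*args):
        lower, upper = None, None
        if lower_index is not None:
            lower = args[lower_index]  # IndexError if args is too short
        if upper_index is not None:
            upper = args[upper_index]
        return lower, upper

    return bounds_fn


def forward(value, *inputs, bounds_fn):
    """Interval forward map whose bounds are read from the extra inputs."""
    a, b = bounds_fn(*inputs)
    return math.log(value - a) - math.log(b - value)


# Hypothetical input layout: (mu, sigma, lower, upper).
bounds_fn = make_bounds_fn(2, 3)
rv_inputs = (0.0, 1.0, -1.0, 1.0)

print(forward(0.0, *rv_inputs, bounds_fn=bounds_fn))  # 0.0

try:
    forward(0.0, bounds_fn=bounds_fn)  # no inputs, as in the failing call
except IndexError as err:
    print(err)  # tuple index out of range
```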

PyMC version information:

pymc: 5.10.4

Context for the issue:

I would like to be able to retrieve the transform of any RV and apply its forward method.
