
BUG: Unable to apply Interval forward transformation #7193


Closed
tvwenger opened this issue Mar 13, 2024 · 3 comments

@tvwenger
Contributor

tvwenger commented Mar 13, 2024

Describe the issue:

Description

I wish to evaluate the log probability of a model at a specific point. Because some RVs may be transformed, I must apply each RV's transformation to the input point and pass the transformed point to model.logp. If one of the RV transforms is Interval, the call to forward() fails with an obscure IndexError.

Expected Behavior

Applying forward to a RV transformed via Interval should not raise an exception.

Actual Behavior

Applying forward to a RV transformed via Interval raises an exception.

Minimum working example

The following example demonstrates what I want to do. I have a model where x is not transformed, y has a LogTransform, and z has an Interval transform. I want to evaluate logp at given values of x, y, and z. Since logp expects the transformed point, I must apply the transformations and pass the transformed point to logp().eval().

The output of this MWE is below:

5.10.4
name: x param: x transform: None
value: 1.0 transformed_value: 1.0
name: y param: y_log__ transform: LogTransform
value: 0.5 transformed_value: -0.6931471824645996
name: z param: z_interval__ transform: Interval
<Exception>

Reproducible code example:

import pymc as pm

print(pm.__version__)

with pm.Model() as model:
    x = pm.Normal("x", mu=0.0, sigma=1.0)
    y = pm.LogNormal("y", mu=1.0, sigma=1.0)
    z = pm.TruncatedNormal("z", mu=0.0, sigma=1.0, lower=-1.0, upper=1.0)

point = {"x": 1.0, "y": 0.5, "z": 0.0}
point_transformed = {}
for rv in model.free_RVs:
    name = rv.name
    param = model.rvs_to_values[rv]
    transform = model.rvs_to_transforms[rv]
    print(f"name: {name} param: {param} transform: {transform}")
    if transform is None:
        point_transformed[param] = point[name]
    else:
        point_transformed[param] = transform.forward(point[name]).eval()
    print(f"value: {point[name]} transformed_value: {point_transformed[param]}")

print(f"Log prob: {model.logp().eval(point_transformed)}")

Error message:

IndexError                                Traceback (most recent call last)
Cell In[25], line 20
     18         point_transformed[param] = point[name]
     19     else:
---> 20         point_transformed[param] = transform.forward(point[name]).eval()
     21     print(f"value: {point[name]} transformed_value: {point_transformed[param]}")
     23 print(f"Log prob: {model.logp().eval(point_transformed)}")

File ~/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/logprob/transforms.py:853, in IntervalTransform.forward(self, value, *inputs)
    852 def forward(self, value, *inputs):
--> 853     a, b, lower_bounded, upper_bounded = self.get_a_and_b(inputs)
    855     log_lower_distance = pt.log(value - a)
    856     log_upper_distance = pt.log(b - value)

File ~/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/logprob/transforms.py:842, in IntervalTransform.get_a_and_b(self, inputs)
    836 def get_a_and_b(self, inputs):
    837     """Return interval bound values.
    838
    839     Also returns two boolean variables indicating whether the transform is known to be statically bounded.
    840     This is used to generate smaller graphs in the transform methods.
    841     """
--> 842     a, b = self.args_fn(*inputs)
    843     lower_bounded, upper_bounded = True, True
    844     if a is None:

File ~/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/distributions/continuous.py:180, in bounded_cont_transform.<locals>.transform_params(*args)
    178 lower, upper = None, None
    179 if bound_args_indices[0] is not None:
--> 180     lower = args[bound_args_indices[0]]
    181 if bound_args_indices[1] is not None:
    182     upper = args[bound_args_indices[1]]

IndexError: tuple index out of range

PyMC version information:

pymc: 5.10.4

Context for the issue:

I would like to be able to retrieve the transform of any RV and apply its forward method.

tvwenger added the bug label Mar 13, 2024
@ricardoV94
Member

ricardoV94 commented Mar 13, 2024

The signature of forward, forward(value, *inputs), expects all inputs of the RV besides the value.
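The traceback shows why the value alone is not enough: IntervalTransform.forward recovers the bounds a and b from the extra inputs via args_fn, then computes log(value - a) and log(b - value). With no inputs there is nothing to index, hence the IndexError. Below is a minimal pure-Python sketch of that math (not PyMC code), assuming both bounds are finite and that the result is the difference of the two logs, the usual log-odds interval map. The function names are hypothetical.

```python
import math


def interval_forward(value: float, a: float, b: float) -> float:
    """Map value in the open interval (a, b) to the real line (log-odds style)."""
    log_lower_distance = math.log(value - a)
    log_upper_distance = math.log(b - value)
    return log_lower_distance - log_upper_distance


def interval_backward(y: float, a: float, b: float) -> float:
    """Inverse map from the real line back to (a, b): a + (b - a) * sigmoid(y)."""
    return a + (b - a) / (1.0 + math.exp(-y))


# z in the MWE: lower=-1.0, upper=1.0, evaluated at 0.0.
# The bounds are required arguments; the value alone is not enough.
z_transformed = interval_forward(0.0, a=-1.0, b=1.0)
print(z_transformed)  # 0.0: the midpoint of the interval maps to zero
print(interval_backward(z_transformed, a=-1.0, b=1.0))  # 0.0
```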

For your goal, if you don't care about the jacobian term you can remove the value transforms with https://www.pymc.io/projects/docs/en/stable/api/model/generated/pymc.model.transform.conditioning.remove_value_transforms.html
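For intuition about what "the jacobian term" refers to, here is a minimal pure-Python sketch (plain math, not PyMC's implementation) for y ~ LogNormal(mu=1.0, sigma=1.0) from the MWE. With a log transform, the log density of the unconstrained value u = log(y) is the log density of y at exp(u) plus log|d exp(u)/du| = u. model.logp() includes this term for transformed variables by default (its jacobian argument); it is what you give up after removing the value transforms. The function names are hypothetical.

```python
import math


def lognormal_logpdf(y: float, mu: float, sigma: float) -> float:
    """Log density of LogNormal(mu, sigma) at y > 0."""
    z = (math.log(y) - mu) / sigma
    return -0.5 * z * z - math.log(sigma * y * math.sqrt(2.0 * math.pi))


def logp_on_log_scale(u: float, mu: float, sigma: float) -> float:
    """Log density of u = log(y): logp of y at exp(u) plus the log-Jacobian."""
    log_jacobian = u  # log|d exp(u)/du| = log(exp(u)) = u
    return lognormal_logpdf(math.exp(u), mu, sigma) + log_jacobian


y = 0.5
u = math.log(y)  # forward log transform of y, as in the MWE output
with_jacobian = logp_on_log_scale(u, mu=1.0, sigma=1.0)
without_jacobian = lognormal_logpdf(y, mu=1.0, sigma=1.0)
# The two differ by the Jacobian term, i.e. by u (up to floating-point rounding).
print(with_jacobian - without_jacobian, u)
```

As a sanity check, with_jacobian equals the Normal(1, 1) log density at u, since log(y) of a LogNormal is Normal.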

Otherwise you need to define a function that pushes all the values forward. That's not straightforward and something we want to offer users: #6721
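A minimal pure-Python sketch of such a push-forward function, assuming each variable is described by a hypothetical spec (transform kind plus a function computing its bounds from the point). It is not PyMC API and not what #6721 proposes; it only illustrates the key requirement that interval bounds be computed from the values in the point being evaluated. The _log__ and _interval__ suffixes mirror the value-variable names printed in the MWE output.

```python
import math
from typing import Callable, Dict, Optional, Tuple

Point = Dict[str, float]
# Hypothetical spec: (transform kind, function computing (lower, upper) from the point).
Spec = Tuple[Optional[str], Optional[Callable[[Point], Tuple[float, float]]]]


def push_forward(point: Point, specs: Dict[str, Spec]) -> Point:
    """Map a point of constrained values to the transformed (unconstrained) space."""
    transformed: Point = {}
    for name, (kind, bounds_fn) in specs.items():
        value = point[name]
        if kind is None:
            transformed[name] = value
        elif kind == "log":
            transformed[f"{name}_log__"] = math.log(value)
        elif kind == "interval":
            assert bounds_fn is not None, "interval transform needs a bounds function"
            # Bounds come from the point itself, so they may depend on other values.
            lower, upper = bounds_fn(point)
            transformed[f"{name}_interval__"] = math.log(value - lower) - math.log(upper - value)
        else:
            raise ValueError(f"unsupported transform kind: {kind!r}")
    return transformed


specs: Dict[str, Spec] = {
    "x": (None, None),
    "y": ("log", None),
    "z": ("interval", lambda p: (-1.0, 1.0)),  # constant bounds, as in the MWE
}
print(push_forward({"x": 1.0, "y": 0.5, "z": 0.0}, specs))
```

Because every bound is read from the untransformed input point, the order in which variables are visited does not matter in this sketch.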

Help there would be very welcome.

Unless I missed something, I would close this issue as a duplicate. The code doesn't show a bug, but rather an incorrect use of the transform objects.

@tvwenger
Contributor Author

@ricardoV94 Thanks for the insight. Following your response, I was able to get this to work by also passing *rv.owner.inputs to forward:

import pymc as pm

print(pm.__version__)

with pm.Model() as model:
    x = pm.Normal("x", mu=0.0, sigma=1.0)
    y = pm.LogNormal("y", mu=1.0, sigma=1.0)
    z = pm.TruncatedNormal("z", mu=0.0, sigma=1.0, lower=-1.0, upper=1.0)

point = {"x": 1.0, "y": 0.5, "z": 0.0}
point_transformed = {}
for rv in model.free_RVs:
    name = rv.name
    param = model.rvs_to_values[rv]
    transform = model.rvs_to_transforms[rv]
    print(f"name: {name} param: {param} transform: {transform}")
    if transform is None:
        point_transformed[param] = point[name]
    else:
        point_transformed[param] = transform.forward(
            point[name], *rv.owner.inputs
        ).eval()
    print(f"value: {point[name]} transformed_value: {point_transformed[param]}")

print(f"Log prob: {model.logp().eval(point_transformed)}")

@ricardoV94
Member

ricardoV94 commented Mar 13, 2024

Note that this only works because lower/upper are constants. If they depended on other parameters, you would get wrong results.
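To illustrate in plain Python (no PyMC), consider a hypothetical model x ~ LogNormal(0, 1), z ~ Uniform(0, x). Presumably rv.owner.inputs for z would then contain the random variable x itself, so evaluating the forward graph would use a fresh draw of x as the upper bound instead of the value of x in point. The sketch below mimics that with random.Random and assumes the interval forward map log(z - a) - log(b - z).

```python
import math
import random


def interval_forward(value: float, a: float, b: float) -> float:
    """Interval map from (a, b) to the real line."""
    return math.log(value - a) - math.log(b - value)


# Hypothetical model: x ~ LogNormal(0, 1), z ~ Uniform(0, x).
point = {"x": 2.0, "z": 1.0}

# Correct: the upper bound is the value of x in the point being evaluated.
correct = interval_forward(point["z"], a=0.0, b=point["x"])

# Mimics reusing the RV's own inputs: the upper bound is a new random draw of x.
rng = random.Random(0)
x_draw = rng.lognormvariate(0.0, 1.0)
if 0.0 < point["z"] < x_draw:
    wrong = interval_forward(point["z"], a=0.0, b=x_draw)
else:
    wrong = math.nan  # z falls outside (0, x_draw), so the map is undefined

print(correct, x_draw, wrong)
```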
