Description
Describe the issue:
I wish to evaluate the log probability of a model at a specific point. Since some RVs may be transformed, I must apply the transformation to the input point in order to pass the transformed point to `model.logp`. If one of the RV transforms is `Interval`, then the call to `forward()` fails with an obscure error.
Expected Behavior
Applying `forward` to an RV transformed via `Interval` should not raise an exception.
Actual Behavior
Applying `forward` to an RV transformed via `Interval` raises an `IndexError` (`tuple index out of range`).
Minimum working example
The following example demonstrates what I want to do. I have a model where `x` is not transformed, `y` has a `LogTransform`, and `z` has an `Interval` transform. I want to evaluate `logp` at a given `x`, `y`, `z`. Since `logp` requires the transformed point in order to evaluate the expression, I must apply the transformations and pass the transformed point to `logp().eval()`.
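To make "apply the transformations" concrete, here is a minimal plain-Python sketch of what the three transforms compute for this point. The helper functions are hypothetical, not PyMC API. `interval_forward` mirrors the `log(value - a)` and `log(b - value)` terms that appear in the traceback and assumes both bounds are finite. Unlike the log transform, it cannot be evaluated without the bounds, which for `z` come from the `TruncatedNormal` definition.

```python
import math

# Hypothetical plain-float helpers (not PyMC API) illustrating what each
# transform in the example computes.


def identity_forward(value):
    # x has no transform: the value variable equals the RV value
    return value


def log_forward(value):
    # LogTransform: maps (0, inf) onto the real line; needs no extra inputs
    return math.log(value)


def interval_forward(value, lower, upper):
    # Interval with two finite bounds: maps (lower, upper) onto the real line.
    # The bounds are required, so the value alone is not enough.
    return math.log(value - lower) - math.log(upper - value)


def interval_backward(y, lower, upper):
    # Inverse of interval_forward: a logistic curve rescaled to (lower, upper)
    return lower + (upper - lower) / (1.0 + math.exp(-y))


point = {"x": 1.0, "y": 0.5, "z": 0.0}

# Keys follow the value-variable names printed by the example; the bounds are
# copied by hand from the TruncatedNormal definition (lower=-1.0, upper=1.0).
transformed = {
    "x": identity_forward(point["x"]),
    "y_log__": log_forward(point["y"]),
    "z_interval__": interval_forward(point["z"], -1.0, 1.0),
}
print(transformed)
```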
The output of this MWE is below:
```
5.10.4
name: x param: x transform: None
value: 1.0 transformed_value: 1.0
name: y param: y_log__ transform: LogTransform
value: 0.5 transformed_value: -0.6931471824645996
name: z param: z_interval__ transform: Interval
<Exception>
```
Reproducible code example:
```python
import pymc as pm

print(pm.__version__)

with pm.Model() as model:
    x = pm.Normal("x", mu=0.0, sigma=1.0)  # no transform
    y = pm.LogNormal("y", mu=1.0, sigma=1.0)  # LogTransform
    z = pm.TruncatedNormal("z", mu=0.0, sigma=1.0, lower=-1.0, upper=1.0)  # Interval

point = {"x": 1.0, "y": 0.5, "z": 0.0}
point_transformed = {}

# Map each free RV's value to the space of its value variable
for rv in model.free_RVs:
    name = rv.name
    param = model.rvs_to_values[rv]
    transform = model.rvs_to_transforms[rv]
    print(f"name: {name} param: {param} transform: {transform}")
    if transform is None:
        point_transformed[param] = point[name]
    else:
        # Raises IndexError when `transform` is the Interval transform of z
        point_transformed[param] = transform.forward(point[name]).eval()
    print(f"value: {point[name]} transformed_value: {point_transformed[param]}")

print(f"Log prob: {model.logp().eval(point_transformed)}")
```
Error message:
```
IndexError                                Traceback (most recent call last)
Cell In[25], line 20
     18         point_transformed[param] = point[name]
     19     else:
---> 20         point_transformed[param] = transform.forward(point[name]).eval()
     21     print(f"value: {point[name]} transformed_value: {point_transformed[param]}")
     23 print(f"Log prob: {model.logp().eval(point_transformed)}")

File ~/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/logprob/transforms.py:853, in IntervalTransform.forward(self, value, *inputs)
    852 def forward(self, value, *inputs):
--> 853     a, b, lower_bounded, upper_bounded = self.get_a_and_b(inputs)
    855     log_lower_distance = pt.log(value - a)
    856     log_upper_distance = pt.log(b - value)

File ~/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/logprob/transforms.py:842, in IntervalTransform.get_a_and_b(self, inputs)
    836 def get_a_and_b(self, inputs):
    837     """Return interval bound values.
    838
    839     Also returns two boolean variables indicating whether the transform is known to be statically bounded.
    840     This is used to generate smaller graphs in the transform methods.
    841     """
--> 842     a, b = self.args_fn(*inputs)
    843     lower_bounded, upper_bounded = True, True
    844     if a is None:

File ~/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/distributions/continuous.py:180, in bounded_cont_transform.<locals>.transform_params(*args)
    178 lower, upper = None, None
    179 if bound_args_indices[0] is not None:
--> 180     lower = args[bound_args_indices[0]]
    181 if bound_args_indices[1] is not None:
    182     upper = args[bound_args_indices[1]]

IndexError: tuple index out of range
```
PyMC version information: 5.10.4 (as printed by the example), Python 3.11 (per the paths in the traceback).
Context for the issue:
I would like to be able to capture and apply `forward` to any transform.