Skip to content

Commit 09dc9d0

Browse files
committed
Raise warning if RVs are present in derived probability graphs
1 parent 30c9179 commit 09dc9d0

File tree

2 files changed

+102
-9
lines changed

2 files changed

+102
-9
lines changed

pymc/logprob/basic.py

+47-7
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,33 @@
6565
TensorLike: TypeAlias = Union[Variable, float, np.ndarray]
6666

6767

68-
def logp(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
68+
def _warn_rvs_in_inferred_graph(
    graph: Union[TensorVariable, Sequence[TensorVariable]],
) -> None:
    """Issue a warning if any RVs are found in graph.

    RVs are usually an (implicit) conditional input of the derived probability expression,
    and meant to be replaced by respective value variables before evaluation.
    However, when the IR graph is built, any non-input nodes (including RVs) are cloned,
    breaking the link with the original ones.
    This makes it impossible (or difficult) to replace it by the respective values afterward,
    so we instruct users to do it beforehand.

    Parameters
    ----------
    graph
        The derived expression to inspect. It is handed unchanged to
        `pymc.testing.assert_no_rvs`. The callers in this module (`logp`, `logcdf`
        and `icdf`) pass the single expression they are about to return.

    Warns
    -----
    UserWarning
        If `assert_no_rvs` raises an `AssertionError` for `graph`.
    """
    # Imported at call time rather than at module level.
    # NOTE(review): presumably to avoid a circular import with `pymc.testing` -- confirm.
    from pymc.testing import assert_no_rvs

    try:
        assert_no_rvs(graph)
    except AssertionError:
        # The assertion failure is deliberately downgraded to a warning: the derived
        # expression is still returned to the user by the caller.
        warnings.warn(
            "RandomVariables were found in the derived graph. "
            "These variables are a clone and do not match the original ones on identity.\n"
            "If you are deriving a quantity that depends on model RVs, use `model.replace_rvs_by_values` first. For example: "
            "`logp(model.replace_rvs_by_values([rv])[0], value)`",
            # Frame 1 is this helper, frame 2 the public function that called it,
            # frame 3 the user code that called the public function.
            stacklevel=3,
        )
90+
91+
92+
def logp(
93+
rv: TensorVariable, value: TensorLike, warn_missing_rvs: bool = True, **kwargs
94+
) -> TensorVariable:
6995
"""Return the log-probability graph of a Random Variable"""
7096

7197
value = pt.as_tensor_variable(value, dtype=rv.dtype)
@@ -74,10 +100,15 @@ def logp(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
74100
except NotImplementedError:
75101
fgraph, _, _ = construct_ir_fgraph({rv: value})
76102
[(ir_rv, ir_value)] = fgraph.preserve_rv_mappings.rv_values.items()
77-
return _logprob_helper(ir_rv, ir_value, **kwargs)
103+
expr = _logprob_helper(ir_rv, ir_value, **kwargs)
104+
if warn_missing_rvs:
105+
_warn_rvs_in_inferred_graph(expr)
106+
return expr
78107

79108

80-
def logcdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
109+
def logcdf(
110+
rv: TensorVariable, value: TensorLike, warn_missing_rvs: bool = True, **kwargs
111+
) -> TensorVariable:
81112
"""Create a graph for the log-CDF of a Random Variable."""
82113
value = pt.as_tensor_variable(value, dtype=rv.dtype)
83114
try:
@@ -86,10 +117,15 @@ def logcdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
86117
# Try to rewrite rv
87118
fgraph, rv_values, _ = construct_ir_fgraph({rv: value})
88119
[ir_rv] = fgraph.outputs
89-
return _logcdf_helper(ir_rv, value, **kwargs)
120+
expr = _logcdf_helper(ir_rv, value, **kwargs)
121+
if warn_missing_rvs:
122+
_warn_rvs_in_inferred_graph(expr)
123+
return expr
90124

91125

92-
def icdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
126+
def icdf(
127+
rv: TensorVariable, value: TensorLike, warn_missing_rvs: bool = True, **kwargs
128+
) -> TensorVariable:
93129
"""Create a graph for the inverse CDF of a Random Variable."""
94130
value = pt.as_tensor_variable(value, dtype=rv.dtype)
95131
try:
@@ -98,7 +134,10 @@ def icdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
98134
# Try to rewrite rv
99135
fgraph, rv_values, _ = construct_ir_fgraph({rv: value})
100136
[ir_rv] = fgraph.outputs
101-
return _icdf_helper(ir_rv, value, **kwargs)
137+
expr = _icdf_helper(ir_rv, value, **kwargs)
138+
if warn_missing_rvs:
139+
_warn_rvs_in_inferred_graph(expr)
140+
return expr
102141

103142

104143
def factorized_joint_logprob(
@@ -215,7 +254,8 @@ def factorized_joint_logprob(
215254
if warn_missing_rvs:
216255
warnings.warn(
217256
"Found a random variable that was neither among the observations "
218-
f"nor the conditioned variables: {node.outputs}"
257+
f"nor the conditioned variables: {outputs}.\n"
258+
"This variables is a clone and does not match the original one on identity."
219259
)
220260
continue
221261

tests/logprob/test_basic.py

+55-2
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@
5656
import pymc as pm
5757

5858
from pymc.logprob.basic import factorized_joint_logprob, icdf, joint_logp, logcdf, logp
59+
from pymc.logprob.transforms import LogTransform
5960
from pymc.logprob.utils import rvs_to_value_vars, walk_model
61+
from pymc.pytensorf import replace_rvs_by_values
6062
from pymc.testing import assert_no_rvs
6163
from tests.logprob.utils import joint_logprob
6264

@@ -248,16 +250,25 @@ def test_persist_inputs():
248250
y_vv_2 = y_vv * 2
249251
logp_2 = joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2})
250252

253+
assert y_vv in ancestors([logp_2])
254+
assert y_vv_2 in ancestors([logp_2])
255+
256+
# Even when they are random
257+
y_vv = pt.random.normal(name="y_vv2")
258+
y_vv_2 = y_vv * 2
259+
logp_2 = joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2})
260+
261+
assert y_vv in ancestors([logp_2])
251262
assert y_vv_2 in ancestors([logp_2])
252263

253264

254-
def test_warn_random_not_found():
265+
def test_warn_random_found_factorized_joint_logprob():
255266
x_rv = pt.random.normal(name="x")
256267
y_rv = pt.random.normal(x_rv, 1, name="y")
257268

258269
y_vv = y_rv.clone()
259270

260-
with pytest.warns(UserWarning):
271+
with pytest.warns(UserWarning, match="Found a random variable that was neither among"):
261272
factorized_joint_logprob({y_rv: y_vv})
262273

263274
with warnings.catch_warnings():
@@ -457,3 +468,45 @@ def test_probability_inference_fails(func, func_name):
457468
match=f"{func_name} method not implemented for Elemwise{{cos,no_inplace}}",
458469
):
459470
func(pt.cos(pm.Normal.dist()), 1)
471+
472+
473+
@pytest.mark.parametrize(
    "func, scipy_func, test_value",
    [
        (logp, "logpdf", 5.0),
        (logcdf, "logcdf", 5.0),
        (icdf, "ppf", 0.7),
    ],
)
def test_warn_random_found_probability_inference(func, scipy_func, test_value):
    """Check the "RandomVariables were found in the derived graph" warning.

    For each of `logp`, `logcdf` and `icdf` the test:

    1. expects the warning when the derived expression depends on another RV;
    2. calls the function again with `warn_missing_rvs=False` and checks that the
       original `rv` is not among the ancestors of the result;
    3. replaces the input RV by a value variable first and compares the evaluated
       result against the matching `scipy.stats.lognorm` method.
    """
    # Fail if unexpected warning is issued
    with warnings.catch_warnings():
        warnings.simplefilter("error")

        input_rv = pm.Normal.dist(0, name="input")
        # Note: This graph could correspond to a convolution of two normals
        # In which case the inference should either return that or fail explicitly
        # For now, the logprob submodule treats the input as a stochastic value.
        rv = pt.exp(pm.Normal.dist(input_rv))
        with pytest.warns(UserWarning, match="RandomVariables were found in the derived graph"):
            assert func(rv, 0.0)

        res = func(rv, 0.0, warn_missing_rvs=False)
        # This is the problem we are warning about, as now we can no longer identify the original rv in the graph
        # or replace it by the respective value
        assert rv not in ancestors([res])

        # Test that the prescribed solution does not raise a warning and works as expected
        input_vv = input_rv.clone()
        [new_rv] = replace_rvs_by_values(
            [rv],
            rvs_to_values={input_rv: input_vv},
            rvs_to_transforms={input_rv: LogTransform()},
        )
        input_vv_test = 1.3
        np.testing.assert_almost_equal(
            func(new_rv, test_value).eval({input_vv: input_vv_test}),
            getattr(sp.lognorm(s=1, loc=0, scale=np.exp(np.exp(input_vv_test))), scipy_func)(
                test_value
            ),
        )

0 commit comments

Comments
 (0)