65
65
TensorLike : TypeAlias = Union [Variable , float , np .ndarray ]
66
66
67
67
68
- def logp (rv : TensorVariable , value : TensorLike , ** kwargs ) -> TensorVariable :
68
+ def _warn_rvs_in_inferred_graph (graph : Sequence [TensorVariable ]):
69
+ """Issue warning if any RVs are found in graph.
70
+
71
+ RVs are usually an (implicit) conditional input of the derived probability expression,
72
+ and meant to be replaced by respective value variables before evaluation.
73
+ However, when the IR graph is built, any non-input nodes (including RVs) are cloned,
74
+ breaking the link with the original ones.
75
+ This makes it impossible (or difficult) to replace it by the respective values afterward,
76
+ so we instruct users to do it beforehand.
77
+ """
78
+ from pymc .testing import assert_no_rvs
79
+
80
+ try :
81
+ assert_no_rvs (graph )
82
+ except AssertionError :
83
+ warnings .warn (
84
+ "RandomVariables were found in the derived graph. "
85
+ "These variables are a clone and do not match the original ones on identity.\n "
86
+ "If you are deriving a quantity that depends on model RVs, use `model.replace_rvs_by_values` first. For example: "
87
+ "`logp(model.replace_rvs_by_values([rv])[0], value)`" ,
88
+ stacklevel = 3 ,
89
+ )
90
+
91
+
92
+ def logp (
93
+ rv : TensorVariable , value : TensorLike , warn_missing_rvs : bool = True , ** kwargs
94
+ ) -> TensorVariable :
69
95
"""Return the log-probability graph of a Random Variable"""
70
96
71
97
value = pt .as_tensor_variable (value , dtype = rv .dtype )
@@ -74,10 +100,15 @@ def logp(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
74
100
except NotImplementedError :
75
101
fgraph , _ , _ = construct_ir_fgraph ({rv : value })
76
102
[(ir_rv , ir_value )] = fgraph .preserve_rv_mappings .rv_values .items ()
77
- return _logprob_helper (ir_rv , ir_value , ** kwargs )
103
+ expr = _logprob_helper (ir_rv , ir_value , ** kwargs )
104
+ if warn_missing_rvs :
105
+ _warn_rvs_in_inferred_graph (expr )
106
+ return expr
78
107
79
108
80
- def logcdf (rv : TensorVariable , value : TensorLike , ** kwargs ) -> TensorVariable :
109
+ def logcdf (
110
+ rv : TensorVariable , value : TensorLike , warn_missing_rvs : bool = True , ** kwargs
111
+ ) -> TensorVariable :
81
112
"""Create a graph for the log-CDF of a Random Variable."""
82
113
value = pt .as_tensor_variable (value , dtype = rv .dtype )
83
114
try :
@@ -86,10 +117,15 @@ def logcdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
86
117
# Try to rewrite rv
87
118
fgraph , rv_values , _ = construct_ir_fgraph ({rv : value })
88
119
[ir_rv ] = fgraph .outputs
89
- return _logcdf_helper (ir_rv , value , ** kwargs )
120
+ expr = _logcdf_helper (ir_rv , value , ** kwargs )
121
+ if warn_missing_rvs :
122
+ _warn_rvs_in_inferred_graph (expr )
123
+ return expr
90
124
91
125
92
- def icdf (rv : TensorVariable , value : TensorLike , ** kwargs ) -> TensorVariable :
126
+ def icdf (
127
+ rv : TensorVariable , value : TensorLike , warn_missing_rvs : bool = True , ** kwargs
128
+ ) -> TensorVariable :
93
129
"""Create a graph for the inverse CDF of a Random Variable."""
94
130
value = pt .as_tensor_variable (value , dtype = rv .dtype )
95
131
try :
@@ -98,7 +134,10 @@ def icdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
98
134
# Try to rewrite rv
99
135
fgraph , rv_values , _ = construct_ir_fgraph ({rv : value })
100
136
[ir_rv ] = fgraph .outputs
101
- return _icdf_helper (ir_rv , value , ** kwargs )
137
+ expr = _icdf_helper (ir_rv , value , ** kwargs )
138
+ if warn_missing_rvs :
139
+ _warn_rvs_in_inferred_graph (expr )
140
+ return expr
102
141
103
142
104
143
def factorized_joint_logprob (
@@ -215,7 +254,8 @@ def factorized_joint_logprob(
215
254
if warn_missing_rvs :
216
255
warnings .warn (
217
256
"Found a random variable that was neither among the observations "
218
- f"nor the conditioned variables: { node .outputs } "
257
+ f"nor the conditioned variables: { outputs } .\n "
258
+ "This variables is a clone and does not match the original one on identity."
219
259
)
220
260
continue
221
261
0 commit comments