@@ -129,11 +129,24 @@ class MeasurableMakeVector(MakeVector):
129
129
130
130
131
131
@_logprob .register (MeasurableMakeVector )
132
- def logprob_make_vector (op , values , * base_vars , ** kwargs ):
132
+ def logprob_make_vector (op , values , * base_rvs , ** kwargs ):
133
133
"""Compute the log-likelihood graph for a `MeasurableMakeVector`."""
134
+ # TODO: Sort out this circular dependency issue
135
+ from pymc .pytensorf import replace_rvs_by_values
136
+
134
137
(value ,) = values
135
138
136
- return at .stack ([logprob (base_var , value [i ]) for i , base_var in enumerate (base_vars )])
139
+ base_rvs_to_values = {base_rv : value [i ] for i , base_rv in enumerate (base_rvs )}
140
+ for i , (base_rv , value ) in enumerate (base_rvs_to_values .items ()):
141
+ base_rv .name = f"base_rv[{ i } ]"
142
+ value .name = f"value[{ i } ]"
143
+
144
+ logps = [logprob (base_rv , value ) for base_rv , value in base_rvs_to_values .items ()]
145
+
146
+ # If the stacked variables depend on each other, we have to replace them by the respective values
147
+ logps = replace_rvs_by_values (logps , rvs_to_values = base_rvs_to_values )
148
+
149
+ return at .stack (logps )
137
150
138
151
139
152
class MeasurableJoin (Join ):
@@ -144,27 +157,28 @@ class MeasurableJoin(Join):
144
157
145
158
146
159
@_logprob .register (MeasurableJoin )
147
- def logprob_join (op , values , axis , * base_vars , ** kwargs ):
160
+ def logprob_join (op , values , axis , * base_rvs , ** kwargs ):
148
161
"""Compute the log-likelihood graph for a `Join`."""
149
- (value ,) = values
162
+ # TODO: Find better way to avoid circular dependency
163
+ from pymc .pytensorf import constant_fold , replace_rvs_by_values
150
164
151
- base_var_shapes = [ base_var . shape [ axis ] for base_var in base_vars ]
165
+ ( value ,) = values
152
166
153
- # TODO: Find better way to avoid circular dependency
154
- from pymc .pytensorf import constant_fold
167
+ base_rv_shapes = [base_var .shape [axis ] for base_var in base_rvs ]
155
168
156
169
# We don't need the graph to be constant, just to have RandomVariables removed
157
- base_var_shapes = constant_fold (base_var_shapes , raise_not_constant = False )
170
+ base_rv_shapes = constant_fold (base_rv_shapes , raise_not_constant = False )
158
171
159
172
split_values = at .split (
160
173
value ,
161
- splits_size = base_var_shapes ,
162
- n_splits = len (base_vars ),
174
+ splits_size = base_rv_shapes ,
175
+ n_splits = len (base_rvs ),
163
176
axis = axis ,
164
177
)
165
178
179
+ base_rvs_to_split_values = {base_rv : value for base_rv , value in zip (base_rvs , split_values )}
166
180
logps = [
167
- logprob (base_var , split_value ) for base_var , split_value in zip ( base_vars , split_values )
181
+ logprob (base_var , split_value ) for base_var , split_value in base_rvs_to_split_values . items ( )
168
182
]
169
183
170
184
if len ({logp .ndim for logp in logps }) != 1 :
@@ -173,12 +187,12 @@ def logprob_join(op, values, axis, *base_vars, **kwargs):
173
187
"joining univariate and multivariate distributions" ,
174
188
)
175
189
190
+ # If the stacked variables depend on each other, we have to replace them by the respective values
191
+ logps = replace_rvs_by_values (logps , rvs_to_values = base_rvs_to_split_values )
192
+
176
193
base_vars_ndim_supp = split_values [0 ].ndim - logps [0 ].ndim
177
194
join_logprob = at .concatenate (
178
- [
179
- at .atleast_1d (logprob (base_var , split_value ))
180
- for base_var , split_value in zip (base_vars , split_values )
181
- ],
195
+ [at .atleast_1d (logp ) for logp in logps ],
182
196
axis = axis - base_vars_ndim_supp ,
183
197
)
184
198
@@ -199,6 +213,8 @@ def find_measurable_stacks(
199
213
if rv_map_feature is None :
200
214
return None # pragma: no cover
201
215
216
+ rvs_to_values = rv_map_feature .rv_values
217
+
202
218
stack_out = node .outputs [0 ]
203
219
204
220
is_join = isinstance (node .op , Join )
@@ -211,18 +227,40 @@ def find_measurable_stacks(
211
227
if not all (
212
228
base_var .owner
213
229
and isinstance (base_var .owner .op , MeasurableVariable )
214
- and base_var not in rv_map_feature . rv_values
230
+ and base_var not in rvs_to_values
215
231
for base_var in base_vars
216
232
):
217
233
return None # pragma: no cover
218
234
219
235
# Make base_vars unmeasurable
220
- base_vars = [assign_custom_measurable_outputs (base_var .owner ) for base_var in base_vars ]
236
+ base_to_unmeasurable_vars = {
237
+ base_var : assign_custom_measurable_outputs (base_var .owner ).outputs [
238
+ base_var .owner .outputs .index (base_var )
239
+ ]
240
+ for base_var in base_vars
241
+ }
242
+
243
+ def replacement_fn (var , replacements ):
244
+ if var in base_to_unmeasurable_vars :
245
+ replacements [var ] = base_to_unmeasurable_vars [var ]
246
+ # We don't want to clone valued nodes. Assigning a var to itself in the
247
+ # replacements prevents this
248
+ elif var in rvs_to_values :
249
+ replacements [var ] = var
250
+
251
+ return []
252
+
253
+ # TODO: Fix this import circularity!
254
+ from pymc .pytensorf import _replace_rvs_in_graphs
255
+
256
+ unmeasurable_base_vars , _ = _replace_rvs_in_graphs (
257
+ graphs = base_vars , replacement_fn = replacement_fn
258
+ )
221
259
222
260
if is_join :
223
- measurable_stack = MeasurableJoin ()(axis , * base_vars )
261
+ measurable_stack = MeasurableJoin ()(axis , * unmeasurable_base_vars )
224
262
else :
225
- measurable_stack = MeasurableMakeVector (node .op .dtype )(* base_vars )
263
+ measurable_stack = MeasurableMakeVector (node .op .dtype )(* unmeasurable_base_vars )
226
264
227
265
measurable_stack .name = stack_out .name
228
266
0 commit comments