@@ -297,7 +297,9 @@ def create_inner_out_logp(
297297def find_measurable_scans (fgraph , node ):
298298 r"""Finds `Scan`\s for which a `logprob` can be computed.
299299
300- This will convert said `Scan`\s into `MeasurableScan`\s.
300+ This will convert said `Scan`\s into `MeasurableScan`\s. It also updates
301+ random variable and value variable mappings that have been specified for
302+ parts of a `Scan`\s outputs (e.g. everything except the initial values).
301303 """
302304
303305 if not isinstance (node .op , Scan ):
@@ -306,6 +308,11 @@ def find_measurable_scans(fgraph, node):
306308 if isinstance (node .op , MeasurableScan ):
307309 return
308310
311+ rv_map_feature = getattr (fgraph , "preserve_rv_mappings" , None )
312+
313+ if rv_map_feature is None :
314+ return None # pragma: no cover
315+
309316 curr_scanargs = ScanArgs .from_node (node )
310317
311318 # Find the un-output `MeasurableVariable`s created in the inner-graph
@@ -328,6 +335,71 @@ def find_measurable_scans(fgraph, node):
328335 # TODO: Why can't we make this a `MeasurableScan`?
329336 return None
330337
338+ if not any (out in rv_map_feature .rv_values for out in node .outputs ):
339+ # We need to remap user inputs that have been specified in terms of
340+ # `Subtensor`s of this `Scan`'s node's outputs.
341+ #
342+ # For example, the output that the user got was something like
343+ # `out[1:]` for `outputs_info = [{"initial": x0, "taps": [-1]}]`, so
344+ # they likely passed `{out[1:]: x_1T_vv}` to `joint_logprob`.
345+ # Since `out[1:]` isn't really the output of a `Scan`, but a
346+ # `Subtensor` of the output `out` of a `Scan`, we need to account for
347+ # that.
348+
349+ from aesara .tensor .subtensor import Subtensor , indices_from_subtensor
350+
351+ # Get any `Subtensor` outputs that have been applied to outputs of this
352+ # `Scan` (and get the corresponding indices of the outputs from this
353+ # `Scan`)
354+ output_clients : List [Tuple [Variable , int ]] = sum (
355+ [
356+ [
357+ # This is expected to work for `Subtensor` `Op`s,
358+ # because they only ever have one output
359+ (cl .default_output (), i )
360+ for cl , _ in fgraph .get_clients (out )
361+ if isinstance (cl .op , Subtensor )
362+ ]
363+ for i , out in enumerate (node .outputs )
364+ ],
365+ [],
366+ )
367+
368+ # The second items in these tuples are the value variables mapped to
369+ # the *user-specified* measurable variables (i.e. the first items) that
370+ # are `Subtensor`s of the outputs of this `Scan`. The second items are
371+ # the index of the corresponding output of this `Scan` node.
372+ indirect_rv_vars = [
373+ (out , rv_map_feature .rv_values [out ], out_idx )
374+ for out , out_idx in output_clients
375+ if out in rv_map_feature .rv_values
376+ ]
377+
378+ if not indirect_rv_vars :
379+ return None
380+
381+ # We're going to replace the user's random variable/value variable mappings
382+ # with ones that map directly to outputs of this `Scan`.
383+ for rv_var , val_var , out_idx in indirect_rv_vars :
384+
385+ # The full/un-`Subtensor`ed `Scan` output that we need to use
386+ full_out = node .outputs [out_idx ]
387+
388+ assert rv_var .owner .inputs [0 ] == full_out
389+
390+ # A new value variable that spans the full output
391+ new_val_var = full_out .clone ()
392+ # Set the parts of this new value variable that applied to the
393+ # user-specified value variable to the user's value variable
394+ subtensor_indices = indices_from_subtensor (
395+ rv_var .owner .inputs [1 :], rv_var .owner .op .idx_list
396+ )
397+ new_val_var = at .set_subtensor (new_val_var [subtensor_indices ], val_var )
398+
399+ # Replace the mapping
400+ del rv_map_feature .rv_values [rv_var ]
401+ rv_map_feature .rv_values [full_out ] = new_val_var
402+
331403 op = MeasurableScan (
332404 curr_scanargs .inner_inputs , curr_scanargs .inner_outputs , curr_scanargs .info
333405 )
0 commit comments