Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 4d27b62

Browse files
Handle value mappings for Scans with taps and initial values
1 parent 375f098 commit 4d27b62

File tree

2 files changed

+104
-1
lines changed

2 files changed

+104
-1
lines changed

aeppl/scan.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,9 @@ def create_inner_out_logp(
297297
def find_measurable_scans(fgraph, node):
298298
r"""Finds `Scan`\s for which a `logprob` can be computed.
299299
300-
This will convert said `Scan`\s into `MeasurableScan`\s.
300+
This will convert said `Scan`\s into `MeasurableScan`\s. It also updates
301+
random variable and value variable mappings that have been specified for
302+
parts of a `Scan`\s outputs (e.g. everything except the initial values).
301303
"""
302304

303305
if not isinstance(node.op, Scan):
@@ -306,6 +308,11 @@ def find_measurable_scans(fgraph, node):
306308
if isinstance(node.op, MeasurableScan):
307309
return
308310

311+
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
312+
313+
if rv_map_feature is None:
314+
return None # pragma: no cover
315+
309316
curr_scanargs = ScanArgs.from_node(node)
310317

311318
# Find the un-output `MeasurableVariable`s created in the inner-graph
@@ -328,6 +335,71 @@ def find_measurable_scans(fgraph, node):
328335
# TODO: Why can't we make this a `MeasurableScan`?
329336
return None
330337

338+
if not any(out in rv_map_feature.rv_values for out in node.outputs):
339+
# We need to remap user inputs that have been specified in terms of
340+
# `Subtensor`s of this `Scan`'s node's outputs.
341+
#
342+
# For example, the output that the user got was something like
343+
# `out[1:]` for `outputs_info = [{"initial": x0, "taps": [-1]}]`, so
344+
# they likely passed `{out[1:]: x_1T_vv}` to `joint_logprob`.
345+
# Since `out[1:]` isn't really the output of a `Scan`, but a
346+
# `Subtensor` of the output `out` of a `Scan`, we need to account for
347+
# that.
348+
349+
from aesara.tensor.subtensor import Subtensor, indices_from_subtensor
350+
351+
# Get any `Subtensor` outputs that have been applied to outputs of this
352+
# `Scan` (and get the corresponding indices of the outputs from this
353+
# `Scan`)
354+
output_clients: List[Tuple[Variable, int]] = sum(
355+
[
356+
[
357+
# This is expected to work for `Subtensor` `Op`s,
358+
# because they only ever have one output
359+
(cl.default_output(), i)
360+
for cl, _ in fgraph.get_clients(out)
361+
if isinstance(cl.op, Subtensor)
362+
]
363+
for i, out in enumerate(node.outputs)
364+
],
365+
[],
366+
)
367+
368+
# The second items in these tuples are the value variables mapped to
369+
# the *user-specified* measurable variables (i.e. the first items) that
370+
# are `Subtensor`s of the outputs of this `Scan`. The second items are
371+
# the index of the corresponding output of this `Scan` node.
372+
indirect_rv_vars = [
373+
(out, rv_map_feature.rv_values[out], out_idx)
374+
for out, out_idx in output_clients
375+
if out in rv_map_feature.rv_values
376+
]
377+
378+
if not indirect_rv_vars:
379+
return None
380+
381+
# We're going to replace the user's random variable/value variable mappings
382+
# with ones that map directly to outputs of this `Scan`.
383+
for rv_var, val_var, out_idx in indirect_rv_vars:
384+
385+
# The full/un-`Subtensor`ed `Scan` output that we need to use
386+
full_out = node.outputs[out_idx]
387+
388+
assert rv_var.owner.inputs[0] == full_out
389+
390+
# A new value variable that spans the full output
391+
new_val_var = full_out.clone()
392+
# Set the parts of this new value variable that applied to the
393+
# user-specified value variable to the user's value variable
394+
subtensor_indices = indices_from_subtensor(
395+
rv_var.owner.inputs[1:], rv_var.owner.op.idx_list
396+
)
397+
new_val_var = at.set_subtensor(new_val_var[subtensor_indices], val_var)
398+
399+
# Replace the mapping
400+
del rv_map_feature.rv_values[rv_var]
401+
rv_map_feature.rv_values[full_out] = new_val_var
402+
331403
op = MeasurableScan(
332404
curr_scanargs.inner_inputs, curr_scanargs.inner_outputs, curr_scanargs.info
333405
)

tests/test_scan.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,3 +341,34 @@ def scan_fn(mus_t, sigma_t, Y_t_val, S_t_val, Gamma_t):
341341
y_logp_ref_val = y_logp_ref.eval(test_point)
342342

343343
assert np.allclose(y_logp_val, y_logp_ref_val)
344+
345+
346+
def test_initial_values():
347+
srng = at.random.RandomStream()
348+
349+
S_0_rv = srng.categorical(np.array([0.5, 0.5]), name="S_0")
350+
351+
s_0_vv = S_0_rv.clone()
352+
s_0_vv.name = "s_0"
353+
354+
def step_fn(S_tm1):
355+
S_t = srng.categorical(np.array([0.5, 0.5]), name="S_t")
356+
return S_t
357+
358+
S_1T_rv, _ = aesara.scan(
359+
fn=step_fn,
360+
outputs_info=[{"initial": S_0_rv, "taps": [-1]}],
361+
strict=True,
362+
n_steps=10,
363+
name="S_0T",
364+
)
365+
366+
S_1T_rv.name = "S_1T"
367+
s_1T_vv = S_1T_rv.clone()
368+
s_1T_vv.name = "s_1T"
369+
370+
S_0T_logp = joint_logprob({S_1T_rv: s_1T_vv, S_0_rv: s_0_vv})
371+
372+
assert S_0T_logp
373+
374+
raise AssertionError("Not finished")

0 commit comments

Comments
 (0)