This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Conversation

@kc611
Collaborator

@kc611 kc611 commented Jul 3, 2021

This PR aims to implement logprob for Scan Ops. Linked issue: #23

import aesara
import aesara.tensor as at

from aeppl.joint_logprob import joint_logprob


k = at.iscalar("k")

Y_rv, _ = aesara.scan(
    fn=lambda Y_tm1: Y_tm1 + at.random.normal(name="Y_t"),
    outputs_info=[{"initial": at.as_tensor([0.0]), "taps": [-1]}],
    n_steps=k,
    name="Y"
)

y_vv = Y_rv.clone()
y_vv.name = "y"

logp = joint_logprob(Y_rv, {Y_rv: y_vv})

Note: I am building this on top of the replacement logic from #19.
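For reference, the quantity this example should compute is the log-probability of a Gaussian random walk: each increment Y_t - Y_tm1 is standard normal. A minimal NumPy sketch of that target density (the helper name is hypothetical, and this is independent of Aesara):

```python
import numpy as np


def random_walk_logprob(y, y0=0.0):
    """Log-probability of a Gaussian random walk with standard-normal steps.

    Each step satisfies y[t] = y[t-1] + Normal(0, 1), so the joint
    log-probability is the sum of standard-normal log-densities of the
    increments.
    """
    y = np.asarray(y, dtype=float)
    increments = np.diff(np.concatenate([[y0], y]))
    # Standard-normal log-density of each increment, summed over steps
    return float(np.sum(-0.5 * increments**2 - 0.5 * np.log(2 * np.pi)))
```

For example, random_walk_logprob([0.0, 0.0, 0.0]) is three standard-normal log-densities evaluated at zero.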


  • Setting up a general Scan replacement framework from symbolic-pymc.
  • Calculating log-probability for Scan's output arguments 'all-at-once' in logprob_ScanRV.
    • Passing proper values to convert_outer_out_to_in from logprob_ScanRV
    • Using the new_outer_input_vars mapping to act as values for log probability of respective inner-graph output nodes.
      • Figure out how to split the value of each node according to the iterations of the Scan (i.e. the taps logic).
      • Calculating joint-log-probability in create_inner_out_logp instead of a normal _logprob call for each node involved.
  • Investigating and handling test failures/unexpected behaviors.
    • Find out the cause of the test failure in test_scan_logprob
    • Figuring out how to handle RandomStream updates in Aesara's FunctionGraphs so that they do not require explicit handling in aeppl
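The "taps logic" item above boils down to slicing the full value sequence so that each step of the inner graph sees the correct lagged entries. A rough pure-Python sketch of that splitting for a single sit-sot/mit-sot output (the helper name and shape are hypothetical, not aeppl's API):

```python
def split_value_by_taps(values, taps):
    """Yield (step_inputs, step_output) pairs for a single-output Scan.

    `values` is the full sequence including the initial values; `taps`
    are the negative lags (e.g. [-1] for a sit-sot).  The first
    ``max(-t for t in taps)`` entries are initial values, and step ``i``
    reads ``values[i + len0 + t]`` for each tap ``t``.
    """
    len0 = max(-t for t in taps)  # number of initial values
    for i in range(len(values) - len0):
        step_inputs = tuple(values[i + len0 + t] for t in taps)
        yield step_inputs, values[i + len0]
```

For values = [0, 1, 2, 3] and taps = [-1], this yields the (input, output) pairs ((0,), 1), ((1,), 2), ((2,), 3).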

Member

@brandonwillard brandonwillard left a comment


We might need to copy-over/create rewrites like push_out_rvs_from_scan and convert_outer_out_to_in so that the Scans can be further rewritten. I'm not sure how much of the original symbolic_pymc Scan code can be used here, but I know that a lot of important steps for building a log-probability graph from a Scan are in that old code.

If anything, I should see if I have any illustrations of what these old rewrites do; that would help a lot.

@brandonwillard brandonwillard added the labels enhancement (New feature or request) and important (This label is used to indicate priority over things not given this label) Jul 7, 2021
@brandonwillard
Member

This PR depends on aesara-devs/aesara#510.

@brandonwillard brandonwillard force-pushed the add_scan_log branch 3 times, most recently from 57004a2 to c603b26 Compare July 17, 2021 23:43
@kc611 kc611 force-pushed the add_scan_log branch 2 times, most recently from 3ac88c4 to 3549ece Compare July 21, 2021 07:18
@brandonwillard brandonwillard force-pushed the add_scan_log branch 2 times, most recently from 9b3d10d to cb5b396 Compare July 25, 2021 23:13
@brandonwillard
Member

brandonwillard commented Jul 25, 2021

I've added some missing tests from symbolic-pymc. These should help clarify how the Scan rewrite functions are expected to perform, and whether or not they're in working order.

Also, I imagine we'll need aesara-devs/aesara#534 to go through before this will work, because the automatic cloning that ScanArgs currently performs is not what the borrowed symbolic-pymc code expects.

@brandonwillard brandonwillard changed the title to Add support for Scan Ops in joint_logprob Jul 25, 2021
@codecov

codecov bot commented Jul 26, 2021

Codecov Report

Merging #24 (7dd1b29) into main (d24eee8) will increase coverage by 0.14%.
The diff coverage is 95.10%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main      #24      +/-   ##
==========================================
+ Coverage   94.69%   94.83%   +0.14%     
==========================================
  Files           7        8       +1     
  Lines        1017     1201     +184     
  Branches      122      160      +38     
==========================================
+ Hits          963     1139     +176     
- Misses         26       30       +4     
- Partials       28       32       +4     
Impacted Files Coverage Δ
aeppl/scan.py 95.10% <95.10%> (ø)
aeppl/joint_logprob.py 100.00% <0.00%> (+1.44%) ⬆️


@kc611
Collaborator Author

kc611 commented Aug 11, 2021

And the current issue at hand is that, for some reason, the rv_remapper (the PreserveRVMappings feature) doesn't update variables other than var itself. (When running the newly added test_scan_logprob_basic, the push_out_rvs optimization returns mappings for both S_r and Y_rv, but the only one that's updated in this dictionary is Y_rv, which corresponds to the var for which we're getting the log-probability.)

lifted_rv_values = rv_remapper.rv_values

Is that intended?

Member

@brandonwillard brandonwillard left a comment


I've pushed a simple fix and test updates.

@brandonwillard brandonwillard force-pushed the add_scan_log branch 6 times, most recently from 8cd0363 to 8af1ed4 Compare September 5, 2021 05:07
@brandonwillard
Member

I just pushed some changes that should put everything in general working order.

The convert_outer_out_to_in tests need to be rewritten so that they test the output of that function directly, and a couple more Scan cases would be nice (e.g. using joint_logprob on the model in test_convert_outer_out_to_in_sit_sot).

@brandonwillard
Member

brandonwillard commented Sep 18, 2021

Let's say we wanted to compute the log-probability for a simple Markov chain like the following:

import numpy as np

import aesara
import aesara.tensor as at

from aeppl.joint_logprob import joint_logprob


aesara.config.on_opt_error = "raise"

srng = at.random.RandomStream()

S_0_rv = srng.categorical(np.array([0.5, 0.5]), name="S_0")

s_0_vv = S_0_rv.clone()
s_0_vv.name = "s_0"


def step_fn(S_tm1):
    S_t = srng.categorical(np.array([0.5, 0.5]), name="S_t")
    return S_t


S_1T_rv, _ = aesara.scan(
    fn=step_fn,
    outputs_info=[{"initial": S_0_rv, "taps": [-1]}],
    strict=True,
    n_steps=10,
    name="S_0T"
)

S_1T_rv.name = "S_1T"
s_1T_vv = S_1T_rv.clone()
s_1T_vv.name = "s_1T"

If we try to construct the log-probability graph, we get an error:

S_0T_logp = joint_logprob({S_1T_rv: s_1T_vv, S_0_rv: s_0_vv})
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
/tmp/user/1000/babel-9XUDYv/python-rnH85p in <module>
----> 1 S_0T_logp = joint_logprob({S_1T_rv: s_1T_vv, S_0_rv: s_0_vv})

~/projects/code/python/aeppl/aeppl/joint_logprob.py in joint_logprob(sum, *args, **kwargs)
    218
    219     """
--> 220     logprob = factorized_joint_logprob(*args, **kwargs)
    221     if not logprob:
    222         return None

~/projects/code/python/aeppl/aeppl/joint_logprob.py in factorized_joint_logprob(rv_values, warn_missing_rvs, extra_rewrites, **kwargs)
    157         q_rv_value_vars = [
    158             replacements[q_rv_var]
--> 159             for q_rv_var in outputs
    160             if not getattr(q_rv_var.tag, "ignore_logprob", False)
    161         ]

~/projects/code/python/aeppl/aeppl/joint_logprob.py in <listcomp>(.0)
    158             replacements[q_rv_var]
    159             for q_rv_var in outputs
--> 160             if not getattr(q_rv_var.tag, "ignore_logprob", False)
    161         ]
    162

KeyError: 'for{cpu,scan_fn}.0\n

The following two sections summarize the two main features we need to implement before merging.

Add MeasurableVariable detection for Subtensor aesara.scan output

The first issue is that S_1T_rv is actually the output of a Subtensor:

print(S_1T_rv.owner)
Subtensor{int64::}(for{cpu,S_0T}.0, ScalarConstant{1})

In other words, S_1T_rv = S_0T_rv[1:] for some S_0T_rv that effectively represents at.concatenate([at.atleast_1d(S_0_rv), S_1T_rv]).
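That relationship can be sanity-checked numerically; with hypothetical concrete values standing in for the symbolic variables:

```python
import numpy as np

# Hypothetical concrete values standing in for the symbolic variables:
# S_0 is the initial state, S_1T the states produced by the Scan steps.
S_0 = np.array(1)
S_1T = np.array([0, 1, 1, 0])

# S_0T is the full buffer the Scan actually writes: the initial value
# followed by the per-step outputs.
S_0T = np.concatenate([np.atleast_1d(S_0), S_1T])

# The `Subtensor{int64::}` output is exactly the per-step outputs.
assert np.array_equal(S_0T[1:], S_1T)
```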

It is possible to (partially) work around the error above by manually extracting the direct output of the Scan node, i.e. the output corresponding to S_0T_rv (e.g. via S_1T_rv.owner.inputs[0]); however, we still need to support the joint_logprob arguments/usage demonstrated in the example above, since it's considerably more intuitive. Also, there are some issues with this work-around that will be described below.

We should be able to make some updates to find_measurable_scans so that it checks for Subtensors acting on Scan outputs and somehow checks these Subtensors against the Scan's TAPS configuration (i.e. attempt to confirm that the Subtensor is the output of the aesara.scan helper function).
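In other words, the check amounts to comparing the Subtensor's start index against the number of initial values implied by the taps. A standalone sketch of that comparison (a hypothetical helper, not the actual find_measurable_scans code):

```python
def subtensor_matches_taps(subtensor_start, taps):
    """Return True if slicing off `subtensor_start` leading entries is
    consistent with a Scan whose negative taps are `taps`.

    The `aesara.scan` helper returns ``full_output[n0:]``, where ``n0``
    is the number of initial values, i.e. the largest lag among the taps.
    """
    n_initial = max(-t for t in taps)
    return subtensor_start == n_initial
```

Here subtensor_matches_taps(1, [-1]) holds for the sit-sot example above, since aesara.scan returns full_output[1:].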

Support initial values

The second issue is that the initial TAPS values (i.e. the "initial" values specified in a call to aesara.scan) are not currently included in the resulting log-probability graphs.

If we extract the underlying S_0T_rv variables and use that in a call to joint_logprob, we can get a log-probability graph:

S_0T_rv = S_1T_rv.owner.inputs[0]

s_0T_vv = S_0T_rv.clone()
s_0T_vv.name = "s_0T"

S_0T_logp = joint_logprob({S_0T_rv: s_0T_vv, S_0_rv: s_0_vv})


aesara.dprint(S_0T_logp, depth=4)
Sum{acc_dtype=float64} [id A] ''
 |MakeVector{dtype='float64'} [id B] ''
   |Sum{acc_dtype=float64} [id C] ''
   | |Assert{msg='0 <= p <= 1'} [id D] 's_0_logprob'
   |Sum{acc_dtype=float64} [id E] ''
     |for{cpu,scan_fn} [id F] 's_0T_logprob'

Inner graphs of the scan ops:

for{cpu,scan_fn} [id F] 's_0T_logprob'
 >Assert{msg='0 <= p <= 1'} [id G] 'S_t_vv_logprob'
 > |Elemwise{switch,no_inplace} [id H] ''
 > | |Elemwise{and_,no_inplace} [id I] ''
 > | | |Elemwise{le,no_inplace} [id J] ''
 > | | |Elemwise{lt,no_inplace} [id K] ''
 > | |Elemwise{log,no_inplace} [id L] ''
 > | | |Subtensor{int64} [id M] ''
 > | |TensorConstant{-inf} [id N]
 > |All [id O] ''
 > | |Elemwise{ge,no_inplace} [id P] ''
 > |   |Elemwise{true_div,no_inplace} [id Q] ''
 > |   |InplaceDimShuffle{x} [id R] ''
 > |All [id S] ''
 >   |Elemwise{le,no_inplace} [id T] ''
 >     |Elemwise{true_div,no_inplace} [id Q] ''
 >     |InplaceDimShuffle{x} [id U] ''

As this output demonstrates, there is a log-probability term for the S_0_rv prior (i.e. the sub-graph labeled "s_0_logprob") and another log-probability term for the Scan/state sequence (labeled "s_0T_logprob") that depends on the s_0_vv vector; however, the corresponding (outer-)input for the Scan doesn't depend on s_0_vv in any way.
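Numerically, the joint log-probability being assembled is just the initial-state term plus one term per Scan step. For this toy chain of independent categorical draws with p = [0.5, 0.5], a NumPy sketch of the target value (the function name is hypothetical):

```python
import numpy as np


def markov_chain_logprob(s_0, s_1T, p):
    """Joint log-probability of the example chain: an initial categorical
    draw plus one categorical draw per Scan step.  (In this toy model the
    steps don't actually depend on the previous state, matching `step_fn`.)
    """
    p = np.asarray(p, dtype=float)
    logp_init = np.log(p[s_0])                # the "s_0_logprob" term
    logp_steps = np.log(p[np.asarray(s_1T)])  # the "s_0T_logprob" terms
    return float(logp_init + logp_steps.sum())
```

With s_0 = 0 and ten step values, this gives eleven log(0.5) terms.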

Let's take a look at the inputs of the model's Scan node:

import pprint

from aesara.scan.utils import ScanArgs

## Prevent `dict` sorting
pprint.sorted = lambda arg, *a, **kw: arg

scan_args = ScanArgs.from_node(S_1T_rv.owner.inputs[0].owner)


def labeled_outer_inputs(scan_args):
    """Get ordered, labeled outer inputs for a `Scan`."""
    outer_inputs = [
        "n_steps",
        "outer_in_seqs",
        "outer_in_mit_mot",
        "outer_in_mit_sot",
        "outer_in_sit_sot",
        "outer_in_shared",
        "outer_in_nit_sot",
        "outer_in_non_seqs",
    ]

    return {
        name: getattr(scan_args, name)
        for name in outer_inputs
        if getattr(scan_args, name, None)
    }


model_outer_inputs = labeled_outer_inputs(scan_args)
pprint.pprint(model_outer_inputs)
{'n_steps': TensorConstant{10},
 'outer_in_sit_sot': [IncSubtensor{Set;:int64:}.0],
 'outer_in_shared': [RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F9D44833D70>)]}

The 'outer_in_sit_sot' input is the one that corresponds to the state sequence, and it is a vector whose first entry is set to the initial value (here S_0_rv, as the graph below shows):

aesara.dprint(model_outer_inputs["outer_in_sit_sot"])
IncSubtensor{Set;:int64:} [id A] ''
 |AllocEmpty{dtype='int64'} [id B] ''
 | |Elemwise{add,no_inplace} [id C] ''
 |   |TensorConstant{10} [id D]
 |   |Subtensor{int64} [id E] ''
 |     |Shape [id F] ''
 |     | |Rebroadcast{0} [id G] ''
 |     |   |InplaceDimShuffle{x} [id H] ''
 |     |     |categorical_rv{0, (1,), int64, False}.1 [id I] 'S_0'
 |     |       |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F9D44A1B6E0>) [id J]
 |     |       |TensorConstant{[]} [id K]
 |     |       |TensorConstant{4} [id L]
 |     |       |TensorConstant{(2,) of 0.5} [id M]
 |     |ScalarConstant{0} [id N]
 |Rebroadcast{0} [id G] ''
 |ScalarFromTensor [id O] ''
   |Subtensor{int64} [id E] ''

Our convert_outer_out_to_in needs to identify the original full-sequence variable (i.e. the AllocEmpty with ID B) and replace it with its value variable.
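In NumPy terms, the IncSubtensor{Set;:int64:}/AllocEmpty pair computes something like the following buffer construction, which is what needs to be swapped for the value variable (the concrete values are hypothetical):

```python
import numpy as np

n_steps = 10
s_0 = np.array([1])  # hypothetical initial value

# What `IncSubtensor{Set;:int64:}(AllocEmpty(...), s_0, 1)` computes:
# an uninitialized buffer of length n_steps + len(s_0) whose leading
# entries are set to the initial value.
buffer = np.empty(n_steps + len(s_0), dtype=s_0.dtype)
buffer[: len(s_0)] = s_0

# convert_outer_out_to_in must find the AllocEmpty underneath this
# IncSubtensor and substitute the value variable (e.g. s_0T_vv) so the
# log-probability graph reads the observed sequence instead of an
# uninitialized buffer.
```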

@brandonwillard
Member

brandonwillard commented Sep 22, 2021

I just added a commit that handles the first remaining issue (i.e. the section in #24 (comment) titled "Add MeasurableVariable detection for Subtensor aesara.scan output").

@brandonwillard
Member

brandonwillard commented Oct 1, 2021

I just added the functionality for the last missing feature. It needs more testing, but at least it seems to work for the one simple implemented test case.

@brandonwillard
Member

All right, I'm going to merge this for now; we can put in updates as we come across issues and such.

@brandonwillard brandonwillard merged commit 49d32bc into aesara-devs:main Oct 6, 2021

Development

Successfully merging this pull request may close these issues.

Add general timeseries support for joint_logprob
