-
-
Notifications
You must be signed in to change notification settings - Fork 20
Add support for Scan Ops in joint_logprob
#24
Conversation
brandonwillard
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might need to copy-over/create rewrites like push_out_rvs_from_scan and convert_outer_out_to_in so that the Scans can be further rewritten. I'm not sure how much of the original symbolic_pymc Scan code can be used here, but I know that a lot of important steps for building a log-probability graph from a Scan are in that old code.
If anything, I should see if I have any illustrations of what these old rewrites do; that would help a lot.
|
This PR depends on aesara-devs/aesara#510. |
57004a2 to
c603b26
Compare
3ac88c4 to
3549ece
Compare
9b3d10d to
cb5b396
Compare
|
I've added some missing tests from Also, I imagine we'll need aesara-devs/aesara#534 to go through before this will work, because the automatic cloning that |
Scan Ops in joint_logprob
Codecov Report
@@ Coverage Diff @@
## main #24 +/- ##
==========================================
+ Coverage 94.69% 94.83% +0.14%
==========================================
Files 7 8 +1
Lines 1017 1201 +184
Branches 122 160 +38
==========================================
+ Hits 963 1139 +176
- Misses 26 30 +4
- Partials 28 32 +4
Continue to review full report at Codecov.
|
b87bfc4 to
a917a58
Compare
|
And the current issue at hand is for some reason, the Line 117 in 37a5228
Is that intended ? |
e2f4307 to
4f7cec7
Compare
brandonwillard
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've pushed a simple fix and test updates.
4f7cec7 to
fa767cb
Compare
8cd0363 to
8af1ed4
Compare
|
I just pushed some changes that should put everything in general working order. The |
e92ddf5 to
eeebc2e
Compare
eeebc2e to
d8434f6
Compare
|
Let's say we wanted to compute the log-probability for a simple Markov chain like the following: import numpy as np
import aesara
import aesara.tensor as at
from aeppl.joint_logprob import joint_logprob
aesara.config.on_opt_error = "raise"
srng = at.random.RandomStream()
S_0_rv = srng.categorical(np.array([0.5, 0.5]), name="S_0")
s_0_vv = S_0_rv.clone()
s_0_vv.name = "s_0"
def step_fn(S_tm1):
S_t = srng.categorical(np.array([0.5, 0.5]), name="S_t")
return S_t
S_1T_rv, _ = aesara.scan(
fn=step_fn,
outputs_info=[{"initial": S_0_rv, "taps": [-1]}],
strict=True,
n_steps=10,
name="S_0T"
)
S_1T_rv.name = "S_1T"
s_1T_vv = S_1T_rv.clone()
s_1T_vv.name = "s_1T"If we try to construct the log-probability graph, we get an error: S_0T_logp = joint_logprob({S_1T_rv: s_1T_vv, S_0_rv: s_0_vv})---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
/tmp/user/1000/babel-9XUDYv/python-rnH85p in <module>
----> 1 S_0T_logp = joint_logprob({S_1T_rv: s_1T_vv, S_0_rv: s_0_vv})
~/projects/code/python/aeppl/aeppl/joint_logprob.py in joint_logprob(sum, *args, **kwargs)
218
219 """
--> 220 logprob = factorized_joint_logprob(*args, **kwargs)
221 if not logprob:
222 return None
~/projects/code/python/aeppl/aeppl/joint_logprob.py in factorized_joint_logprob(rv_values, warn_missing_rvs, extra_rewrites, **kwargs)
157 q_rv_value_vars = [
158 replacements[q_rv_var]
--> 159 for q_rv_var in outputs
160 if not getattr(q_rv_var.tag, "ignore_logprob", False)
161 ]
~/projects/code/python/aeppl/aeppl/joint_logprob.py in <listcomp>(.0)
158 replacements[q_rv_var]
159 for q_rv_var in outputs
--> 160 if not getattr(q_rv_var.tag, "ignore_logprob", False)
161 ]
162
KeyError: 'for{cpu,scan_fn}.0\nThe following two sections summarize the two main features we need to implement before merging. Add
|
d8434f6 to
18a92b9
Compare
18a92b9 to
375f098
Compare
|
I just added a commit that handles the first remaining issue (i.e. the section in #24 (comment) titled "Add |
4d27b62 to
5011a9d
Compare
|
I just added the functionality for the last missing feature. It needs more testing, but at least it seems to work for the one simple implemented test case. |
5011a9d to
7dd1b29
Compare
|
All right, I'm going to merge this for now; we can put in updates as we come across issues and such. |
This PR aims to implements
logprobfor Scan Ops. Linked Issue: #23Note : I am building this on top of replacement logic from #19
Scanreplacement framework fromsymbolic-pymc.logprob_ScanRV.convert_outer_out_to_infromlogprob_ScanRVnew_outer_input_varsmapping to act as values for log probability of respective inner-graph output nodes.valueof each node according to iterations of Scan. (i.e. taps logic)create_inner_out_logpinstead of a normal_logprobcall for each node involved.test_scan_logprobRandomStreamupdates in Aesara'sFunctionGraphsso that they do not require explicit handling inaeppl