
Commit 2e22e87

Introduce RandomVariable push-out optimization
1 parent e349607 commit 2e22e87

2 files changed: +125 -2 lines

symbolic_pymc/theano/opt.py

Lines changed: 55 additions & 1 deletion
@@ -5,7 +5,9 @@
 
 from functools import wraps
 
-from theano.gof.opt import LocalOptimizer
+from theano.gof.opt import LocalOptimizer, local_optimizer
+from theano.scan_module.scan_op import Scan
+from theano.scan_module.scan_utils import scan_args as ScanArgs
 
 from unification import var, variables
 
@@ -14,6 +16,7 @@
 from etuples.core import ExpressionTuple
 
 from .meta import MetaSymbol
+from .ops import RandomVariable
 
 
 def eval_and_reify_meta(x):
@@ -219,3 +222,54 @@ def transform(self, node):
             return new_node
         else:
             return False
+
+
+@local_optimizer([Scan])
+def push_out_rvs_from_scan(node):
+    """Push `RandomVariable`s out of `Scan` nodes.
+
+    When `RandomVariable`s are created within the inner-graph of a `Scan` and
+    are not output to the outer-graph, we "push" them out of the inner-graph.
+    This helps us produce an outer-graph in which all the relevant `RandomVariable`s
+    are accessible (e.g. for constructing a log-likelihood graph).
+    """
+    if not isinstance(node.op, Scan):
+        return False
+
+    scan_args = ScanArgs(node.inputs, node.outputs, node.op.inputs, node.op.outputs, node.op.info)
+
+    # Find the un-output `RandomVariable`s created in the inner-graph
+    clients = {}
+    local_fgraph_topo = theano.gof.graph.io_toposort(
+        scan_args.inner_inputs, scan_args.inner_outputs, clients=clients
+    )
+    unpushed_inner_rvs = []
+    for n in local_fgraph_topo:
+        if isinstance(n.op, RandomVariable):
+            unpushed_inner_rvs.extend([c for c in clients[n] if c not in scan_args.inner_outputs])
+
+    if len(unpushed_inner_rvs) == 0:
+        return False
+
+    # Add the new outputs to the inner and outer graphs
+    scan_args.inner_out_nit_sot.extend(unpushed_inner_rvs)
+
+    if len(scan_args.outer_in_nit_sot) == 0:
+        raise ValueError("No outer-graph nit-sot inputs!")
+
+    # Just like `theano.scan`, we simply copy/repeat the existing nit-sot
+    # outer-graph input value, which represents the actual size of the output
+    # tensors.  Apparently, the value needs to be duplicated for all nit-sots.
+    # FYI: This is what increments the nit-sot count in `scan_args.info`, as
+    # well.
+    # TODO: Can we just use `scan_args.n_steps`?
+    scan_args.outer_in_nit_sot.extend(scan_args.outer_in_nit_sot[0:1] * len(unpushed_inner_rvs))
+
+    op = Scan(scan_args.inner_inputs, scan_args.inner_outputs, scan_args.info)
+    outputs = list(op(*scan_args.outer_inputs))
+
+    # Return only the replacements for the original `node.outputs`
+    new_inner_out_idx = [scan_args.inner_outputs.index(i) for i in unpushed_inner_rvs]
+    _ = [outputs.pop(op.var_mappings["outer_out_from_inner_out"][i]) for i in new_inner_out_idx]
+
+    return dict(zip(node.outputs, outputs))
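
For orientation, here is a minimal usage sketch of the new optimizer. It mirrors the test added below, so everything except the assumed `Y_rv` variable comes from this commit: the local optimizer is wrapped in an `EquilibriumOptimizer` and run over a `FunctionGraph` built from the `Scan` output.

# Minimal usage sketch; `Y_rv` is assumed to be the outer-graph output of a
# `theano.scan` whose inner-graph draws `RandomVariable`s (see the test below).
from theano.gof.opt import EquilibriumOptimizer
from theano.gof.graph import inputs as tt_inputs

from symbolic_pymc.theano.opt import FunctionGraph, push_out_rvs_from_scan
from symbolic_pymc.theano.utils import optimize_graph

fgraph = FunctionGraph(tt_inputs([Y_rv]), [Y_rv], clone=True)
push_out_opt = EquilibriumOptimizer([push_out_rvs_from_scan], max_use_ratio=10)
fgraph_opt = optimize_graph(fgraph, push_out_opt, return_graph=True)

# The replacement `Scan` node now also outputs the inner-graph
# `RandomVariable`s that were previously hidden from the outer-graph.
new_scan = fgraph_opt.outputs[0].owner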

tests/theano/test_opt.py

Lines changed: 70 additions & 1 deletion
@@ -1,3 +1,5 @@
+import numpy as np
+import theano
 import theano.tensor as tt
 
 from unification import var
@@ -9,10 +11,16 @@
 
 from theano.gof.opt import EquilibriumOptimizer
 from theano.gof.graph import inputs as tt_inputs
+from theano.scan_module.scan_op import Scan
 
 from symbolic_pymc.theano.meta import mt
-from symbolic_pymc.theano.opt import KanrenRelationSub, FunctionGraph
+from symbolic_pymc.theano.opt import (
+    KanrenRelationSub,
+    FunctionGraph,
+    push_out_rvs_from_scan,
+)
 from symbolic_pymc.theano.utils import optimize_graph
+from symbolic_pymc.theano.random_variables import CategoricalRV, DirichletRV, NormalRV
 
 
 def test_kanren_opt():
@@ -58,3 +66,64 @@ def distributes(in_lv, out_lv):
     assert fgraph_opt.owner.inputs[1].owner.op == tt.add
     assert isinstance(fgraph_opt.owner.inputs[1].owner.inputs[0].owner.op, tt.Dot)
     assert isinstance(fgraph_opt.owner.inputs[1].owner.inputs[1].owner.op, tt.Dot)
+
+
+def test_push_out_rvs():
+    theano.config.cxx = ""
+    theano.config.mode = "FAST_COMPILE"
+    tt.config.compute_test_value = "warn"
+
+    rng_state = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234)))
+    rng_tt = theano.shared(rng_state, name="rng", borrow=True)
+    rng_tt.tag.is_rng = True
+    rng_tt.default_update = rng_tt
+
+    N_tt = tt.iscalar("N")
+    N_tt.tag.test_value = 10
+    M_tt = tt.iscalar("M")
+    M_tt.tag.test_value = 2
+
+    mus_tt = tt.matrix("mus_t")
+    mus_tt.tag.test_value = np.stack([np.arange(0.0, 10), np.arange(0.0, -10, -1)], axis=-1).astype(
+        theano.config.floatX
+    )
+
+    sigmas_tt = tt.ones((N_tt,))
+    Gamma_rv = DirichletRV(tt.ones((M_tt, M_tt)), rng=rng_tt, name="Gamma")
+
+    # In this case, `Y_t` depends on `S_t` and `S_t` is not output.  Our
+    # push-out optimization should create a new `Scan` that also outputs each
+    # `S_t`.
+    def scan_fn(mus_t, sigma_t, Gamma_t, rng):
+        S_t = CategoricalRV(Gamma_t[0], rng=rng, name="S_t")
+        Y_t = NormalRV(mus_t[S_t], sigma_t, rng=rng, name="Y_t")
+        return Y_t

+    Y_rv, _ = theano.scan(
+        fn=scan_fn,
+        sequences=[mus_tt, sigmas_tt],
+        non_sequences=[Gamma_rv, rng_tt],
+        outputs_info=[{}],
+        strict=True,
+        name="scan_rv",
+    )
+    Y_rv.name = "Y_rv"
+
+    orig_scan_op = Y_rv.owner.op
+    assert len(Y_rv.owner.outputs) == 2
+    assert isinstance(orig_scan_op, Scan)
+    assert len(orig_scan_op.outputs) == 2
+    assert orig_scan_op.outputs[0].owner.op == NormalRV
+    assert isinstance(orig_scan_op.outputs[1].type, tt.raw_random.RandomStateType)
+
+    fgraph = FunctionGraph(tt_inputs([Y_rv]), [Y_rv], clone=True)
+    pushoutrvs_opt = EquilibriumOptimizer([push_out_rvs_from_scan], max_use_ratio=10)
+    fgraph_opt = optimize_graph(fgraph, pushoutrvs_opt, return_graph=True)
+
+    # There should now be a new output for all the `S_t`
+    new_scan = fgraph_opt.outputs[0].owner
+    assert len(new_scan.outputs) == 3
+    assert isinstance(new_scan.op, Scan)
+    assert new_scan.op.outputs[0].owner.op == NormalRV
+    assert new_scan.op.outputs[1].owner.op == CategoricalRV
+    assert isinstance(new_scan.op.outputs[2].type, tt.raw_random.RandomStateType)
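
As a follow-up sketch (not part of the commit; it reuses `new_scan` and `CategoricalRV` from the test above), the pushed-out `S_t` draws can be located on the new outer-graph through the same `var_mappings` table that `push_out_rvs_from_scan` uses internally:

# Hedged sketch: map the inner-graph `CategoricalRV` output to its new
# outer-graph output via the Scan op's `var_mappings`, as the optimizer does.
s_t_inner_idx = next(
    i
    for i, out in enumerate(new_scan.op.outputs)
    if out.owner is not None and out.owner.op == CategoricalRV
)
s_t_outer_idx = new_scan.op.var_mappings["outer_out_from_inner_out"][s_t_inner_idx]
S_t_seq = new_scan.outputs[s_t_outer_idx]  # the full sequence of `S_t` draws

Having the `S_t` sequence available on the outer-graph is what makes, for example, a log-likelihood graph over those hidden draws constructible, which is the motivation stated in the optimizer's docstring.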
