Skip to content

Commit e349607

Browse files
Merge pull request #113 from brandonwillard/introduce-rv-logp-conversion
Introduce a function that converts RandomVariables to log-likelihoods
2 parents b0d8db8 + aff0206 commit e349607

File tree

3 files changed

+184
-9
lines changed

3 files changed

+184
-9
lines changed

symbolic_pymc/theano/pymc3.py

Lines changed: 127 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
from theano.gof.op import get_test_value
1919
from theano.gof.graph import Apply, inputs as tt_inputs
20+
from theano.scan_module.scan_op import Scan
21+
from theano.scan_module.scan_utils import clone
2022

2123
from .random_variables import (
2224
observed,
@@ -96,6 +98,122 @@ def convert_rv_to_dist(node, obs):
9698
return dist_type(rv.name, shape=shape, observed=obs, **dist_params)
9799

98100

101+
@dispatch(tt.TensorVariable, object)
102+
def logp(var, obs):
103+
104+
node = var.owner
105+
106+
if hasattr(node, "fgraph") and hasattr(node.fgraph, "shape_feature"):
107+
shape = list(node.fgraph.shape_feature.shape_tuple(var))
108+
else:
109+
shape = list(var.shape)
110+
111+
for i, s in enumerate(shape):
112+
try:
113+
shape[i] = tt.get_scalar_constant_value(s)
114+
except tt.NotScalarConstantError:
115+
shape[i] = s.tag.test_value
116+
117+
logp_fn = _logp_fn(node.op, node, shape)
118+
return logp_fn(obs)
119+
120+
121+
@dispatch(RandomVariable, Apply, object)
122+
def _logp_fn(op, node, shape=None):
123+
dist_type, dist_params = _convert_rv_to_dist(op, node)
124+
if shape is not None:
125+
dist_params["shape"] = shape
126+
res = dist_type.dist(**dist_params)
127+
# Add extra information to the PyMC3 `Distribution` object
128+
res.dist_params = dist_params
129+
res.ndim_supp = op.ndim_supp
130+
# TODO: Need to maintain the order of these so that they correspond with
131+
# the `Distribution`'s parameters
132+
res.ndims_params = op.ndims_params
133+
return res.logp
134+
135+
136+
@_logp_fn.register(Scan, Apply, object)
137+
def _logp_fn_Scan(op, scan_node, shape=None):
138+
139+
scan_inner_inputs = scan_node.op.inputs
140+
scan_inner_outputs = scan_node.op.outputs
141+
142+
def create_obs_var(i, x):
143+
obs = x.type()
144+
obs.name = f"{x.name or x.owner.op.name}_obs_{i}"
145+
if hasattr(x.tag, "test_value"):
146+
obs.tag.test_value = x.tag.test_value
147+
return obs
148+
149+
rv_outs = [
150+
(i, x, create_obs_var(i, x))
151+
for i, x in enumerate(scan_inner_outputs)
152+
if x.owner and isinstance(x.owner.op, RandomVariable)
153+
]
154+
rv_inner_out_idx, rv_out_vars, rv_out_obs = zip(*rv_outs)
155+
# rv_outer_out_idx = [scan_node.op.get_oinp_iinp_iout_oout_mappings()['outer_out_from_inner_out'][i] for i in rv_inner_out_idx]
156+
# rv_outer_outputs = [scan_node.outputs[i] for i in rv_outer_out_idx]
157+
158+
logp_inner_outputs = [clone(logp(rv, obs)) for i, rv, obs in rv_outs]
159+
assert all(o in tt.gof.graph.inputs(logp_inner_outputs) for o in rv_out_obs)
160+
161+
logp_inner_outputs_inputs = tt.gof.graph.inputs(logp_inner_outputs)
162+
rv_relevant_inner_input_idx, rv_relevant_inner_inputs = zip(
163+
*[(n, i) for n, i in enumerate(scan_inner_inputs) if i in logp_inner_outputs_inputs]
164+
)
165+
logp_inner_inputs = list(rv_out_obs) + list(rv_relevant_inner_inputs)
166+
167+
# We need to create outer-inputs that represent arrays of observations
168+
# for each random variable.
169+
# To do that, we're going to use each random variable's outer-output term,
170+
# since they necessarily have the same shape and type as the observations
171+
# arrays.
172+
173+
# Just like we did for the inner-inputs, we need to get only the outer-inputs
174+
# that are relevant to the new logp graphs.
175+
# We can do that by removing the irrelevant outer-inputs using the known relevant inner-inputs
176+
removed_inner_inputs = set(range(len(scan_inner_inputs))) - set(rv_relevant_inner_input_idx)
177+
old_in_out_mappings = scan_node.op.get_oinp_iinp_iout_oout_mappings()
178+
rv_removed_outer_input_idx = [
179+
old_in_out_mappings["outer_inp_from_inner_inp"][i] for i in removed_inner_inputs
180+
]
181+
rv_removed_outer_inputs = [scan_node.inputs[i] for i in rv_removed_outer_input_idx]
182+
183+
rv_relevant_outer_inputs = [r for r in scan_node.inputs if r not in rv_removed_outer_inputs]
184+
185+
# Now, we can create a new op with our new inner-graph inputs and outputs.
186+
# Also, since our inner graph has new placeholder terms representing
187+
# an observed value for each random variable, we need to update the
188+
# "info" `dict`.
189+
logp_info = scan_node.op.info.copy()
190+
logp_info["tap_array"] = []
191+
logp_info["n_seqs"] += len(rv_out_obs)
192+
logp_info["n_mit_mot"] = 0
193+
logp_info["n_mit_mot_outs"] = 0
194+
logp_info["mit_mot_out_slices"] = []
195+
logp_info["n_mit_sot"] = 0
196+
logp_info["n_sit_sot"] = 0
197+
logp_info["n_shared_outs"] = 0
198+
logp_info["n_nit_sot"] += len(rv_out_obs) - 1
199+
logp_info["name"] = None
200+
logp_info["strict"] = True
201+
202+
# These are the tensor variables corresponding to each random variable's
203+
# array of observations.
204+
def logp_fn(*obs):
205+
logp_obs_outer_inputs = list(obs) # [r.clone() for r in rv_outer_outputs]
206+
logp_outer_inputs = (
207+
[rv_relevant_outer_inputs[0]] + logp_obs_outer_inputs + rv_relevant_outer_inputs[1:]
208+
)
209+
logp_op = Scan(logp_inner_inputs, logp_inner_outputs, logp_info)
210+
scan_logp = logp_op(*logp_outer_inputs)
211+
return scan_logp
212+
213+
# logp_fn = OpFromGraph(logp_obs_outer_inputs, [scan_logp])
214+
return logp_fn
215+
216+
99217
@dispatch(pm.Uniform, object)
100218
def convert_dist_to_rv(dist, rng):
101219
size = dist.shape.astype(int)[UniformRV.ndim_supp :]
@@ -467,7 +585,7 @@ def model_graph(pymc_model, output_vars=None, rand_state=None, attach_memo=True)
467585
return model_fg
468586

469587

470-
def graph_model(graph, *model_args, **model_kwargs):
588+
def graph_model(graph, *model_args, generate_names=False, **model_kwargs):
471589
"""Create a PyMC3 model from a Theano graph with `RandomVariable` nodes."""
472590
model = pm.Model(*model_args, **model_kwargs)
473591

@@ -478,13 +596,14 @@ def graph_model(graph, *model_args, **model_kwargs):
478596
nodes = [n for n in fgraph.toposort() if isinstance(n.op, RandomVariable)]
479597
rv_replacements = {}
480598

599+
node_id = 0
600+
481601
for node in nodes:
482602

483603
obs = get_rv_observation(node)
484604

485605
if obs is not None:
486606
obs = obs.inputs[0]
487-
488607
obs = tt_get_values(obs)
489608

490609
old_rv_var = node.default_output()
@@ -500,6 +619,12 @@ def graph_model(graph, *model_args, **model_kwargs):
500619
if op != node
501620
)
502621

622+
if generate_names and rv_var.name is None:
623+
node_name = "{}_{}".format(node.op.name, node_id)
624+
# warn("Name {} generated for node {}.".format(node, node_name))
625+
node_id += 1
626+
rv_var.name = node_name
627+
503628
with model:
504629
rv = convert_rv_to_dist(node, obs)
505630

symbolic_pymc/theano/random_variables.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -447,11 +447,8 @@ def make_node(self, val, rv=None):
447447
The distribution from which `val` is assumed to be a sample value.
448448
"""
449449
val = tt.as_tensor_variable(val)
450-
if rv:
451-
if rv.owner and not isinstance(rv.owner.op, RandomVariable):
452-
raise ValueError(f"`rv` must be a RandomVariable type: {rv}")
453-
454-
if rv.type.convert_variable(val) is None:
450+
if rv is not None:
451+
if not hasattr(rv, "type") or rv.type.convert_variable(val) is None:
455452
raise ValueError(
456453
("`rv` and `val` do not have compatible types:" f" rv={rv}, val={val}")
457454
)

tests/theano/test_pymc3.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
# from theano.configparser import change_flags
1414
from theano.gof.graph import inputs as tt_inputs
1515

16-
from symbolic_pymc.theano.random_variables import MvNormalRV, Observed, observed
16+
from symbolic_pymc.theano.random_variables import NormalRV, MvNormalRV, Observed, observed
1717
from symbolic_pymc.theano.ops import RandomVariable
1818
from symbolic_pymc.theano.opt import FunctionGraph
19-
from symbolic_pymc.theano.pymc3 import model_graph, graph_model
19+
from symbolic_pymc.theano.pymc3 import model_graph, graph_model, logp
2020
from symbolic_pymc.theano.utils import canonicalize
2121
from symbolic_pymc.theano.meta import mt
2222

@@ -60,6 +60,12 @@ def test_pymc3_convert_dists():
6060
new_pymc_rv_names = {n.name for n in pymc_model.observed_RVs}
6161
pymc_rv_names == new_pymc_rv_names
6262

63+
with pytest.raises(TypeError):
64+
graph_model(NormalRV(0, 1), generate_names=False)
65+
66+
res = graph_model(NormalRV(0, 1), generate_names=True)
67+
assert res.vars[0].name == "normal_0"
68+
6369

6470
def test_pymc3_normal_model():
6571
"""Conduct a more in-depth test of PyMC3/Theano conversions for a specific model."""
@@ -246,3 +252,50 @@ def test_pymc3_broadcastable():
246252
Z_rv_meta = canonicalize(Z_rv_obs_.reify(), return_graph=False)
247253

248254
assert mt(Z_rv_tt) == mt(Z_rv_meta)
255+
256+
257+
def test_logp():
258+
test_rv = NormalRV(0, tt.arange(1, 3))
259+
test_logp = logp(test_rv, 0)
260+
261+
assert np.all(test_logp.eval() == pm.Normal.dist(0, np.arange(1, 3)).logp(0).eval())
262+
263+
fgraph = FunctionGraph(tt_inputs([test_rv]), [test_rv], features=[tt.opt.ShapeFeature()])
264+
test_rv.owner.fgraph = fgraph
265+
test_logp = logp(test_rv, 0)
266+
267+
assert np.all(test_logp.eval() == pm.Normal.dist(0, np.arange(1, 3)).logp(0).eval())
268+
269+
rng_state = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234)))
270+
rng_tt = theano.shared(rng_state, name="rng", borrow=True)
271+
rng_tt.tag.is_rng = True
272+
rng_tt.default_update = rng_tt
273+
274+
# TODO: Scan of univariate normals.
275+
N_tt = tt.iscalar("N")
276+
N_tt.tag.test_value = 10
277+
278+
mus_tt = tt.arange(N_tt)
279+
mus_tt.tag.test_value
280+
281+
sigmas_tt = tt.ones((N_tt,))
282+
sigmas_tt.tag.test_value
283+
284+
def scan_fn(mu_t, sigma_t, rng):
285+
# mix = np.stack([NormalRV(mu_t, sigma_t, rng=rng), GammaRV(mu_t**2 / sigma_t**2, mu_t / sigma_t)])
286+
return NormalRV(mu_t, sigma_t, rng=rng)
287+
288+
scan_rv, _ = theano.scan(
289+
fn=scan_fn,
290+
sequences=[mus_tt, sigmas_tt],
291+
non_sequences=[rng_tt],
292+
outputs_info=[{},],
293+
strict=True,
294+
name="scan_rv",
295+
)
296+
297+
scan_logp = logp(scan_rv, tt.zeros((N_tt,)))
298+
res = scan_logp.eval({N_tt: 10})
299+
exp_res = pm.Normal.dist(np.arange(10), np.ones(10)).logp(np.zeros(10)).eval()
300+
301+
assert np.array_equal(res, exp_res)

0 commit comments

Comments
 (0)