From aff020642e0959c6a6159b6e7c33006b524ffdfb Mon Sep 17 00:00:00 2001
From: "Brandon T. Willard"
Date: Thu, 9 Jul 2020 01:53:32 -0500
Subject: [PATCH] Introduce a function that converts RandomVariables to
 log-likelihoods

---
 symbolic_pymc/theano/pymc3.py            | 129 ++++++++++++++++++++++-
 symbolic_pymc/theano/random_variables.py |   7 +-
 tests/theano/test_pymc3.py               |  57 +++++++++-
 3 files changed, 184 insertions(+), 9 deletions(-)

diff --git a/symbolic_pymc/theano/pymc3.py b/symbolic_pymc/theano/pymc3.py
index 15c2912..b7d0c60 100644
--- a/symbolic_pymc/theano/pymc3.py
+++ b/symbolic_pymc/theano/pymc3.py
@@ -17,6 +17,8 @@
 from theano.gof.op import get_test_value
 from theano.gof.graph import Apply, inputs as tt_inputs
 
+from theano.scan_module.scan_op import Scan
+from theano.scan_module.scan_utils import clone
 
 from .random_variables import (
     observed,
@@ -96,6 +98,122 @@ def convert_rv_to_dist(node, obs):
     return dist_type(rv.name, shape=shape, observed=obs, **dist_params)
 
 
+@dispatch(tt.TensorVariable, object)
+def logp(var, obs):
+
+    node = var.owner
+
+    if hasattr(node, "fgraph") and hasattr(node.fgraph, "shape_feature"):
+        shape = list(node.fgraph.shape_feature.shape_tuple(var))
+    else:
+        shape = list(var.shape)
+
+    for i, s in enumerate(shape):
+        try:
+            shape[i] = tt.get_scalar_constant_value(s)
+        except tt.NotScalarConstantError:
+            shape[i] = s.tag.test_value
+
+    logp_fn = _logp_fn(node.op, node, shape)
+    return logp_fn(obs)
+
+
+@dispatch(RandomVariable, Apply, object)
+def _logp_fn(op, node, shape=None):
+    dist_type, dist_params = _convert_rv_to_dist(op, node)
+    if shape is not None:
+        dist_params["shape"] = shape
+    res = dist_type.dist(**dist_params)
+    # Add extra information to the PyMC3 `Distribution` object
+    res.dist_params = dist_params
+    res.ndim_supp = op.ndim_supp
+    # TODO: Need to maintain the order of these so that they correspond with
+    # the `Distribution`'s parameters
+    res.ndims_params = op.ndims_params
+    return res.logp
+
+
+@_logp_fn.register(Scan, Apply, object)
+def _logp_fn_Scan(op, scan_node, shape=None):
+
+    scan_inner_inputs = scan_node.op.inputs
+    scan_inner_outputs = scan_node.op.outputs
+
+    def create_obs_var(i, x):
+        obs = x.type()
+        obs.name = f"{x.name or x.owner.op.name}_obs_{i}"
+        if hasattr(x.tag, "test_value"):
+            obs.tag.test_value = x.tag.test_value
+        return obs
+
+    rv_outs = [
+        (i, x, create_obs_var(i, x))
+        for i, x in enumerate(scan_inner_outputs)
+        if x.owner and isinstance(x.owner.op, RandomVariable)
+    ]
+    rv_inner_out_idx, rv_out_vars, rv_out_obs = zip(*rv_outs)
+    # rv_outer_out_idx = [scan_node.op.get_oinp_iinp_iout_oout_mappings()['outer_out_from_inner_out'][i] for i in rv_inner_out_idx]
+    # rv_outer_outputs = [scan_node.outputs[i] for i in rv_outer_out_idx]
+
+    logp_inner_outputs = [clone(logp(rv, obs)) for i, rv, obs in rv_outs]
+    assert all(o in tt.gof.graph.inputs(logp_inner_outputs) for o in rv_out_obs)
+
+    logp_inner_outputs_inputs = tt.gof.graph.inputs(logp_inner_outputs)
+    rv_relevant_inner_input_idx, rv_relevant_inner_inputs = zip(
+        *[(n, i) for n, i in enumerate(scan_inner_inputs) if i in logp_inner_outputs_inputs]
+    )
+    logp_inner_inputs = list(rv_out_obs) + list(rv_relevant_inner_inputs)
+
+    # We need to create outer-inputs that represent arrays of observations
+    # for each random variable.
+    # To do that, we're going to use each random variable's outer-output term,
+    # since they necessarily have the same shape and type as the observation
+    # arrays.
+
+    # Just like we did for the inner-inputs, we need to get only the outer-inputs
+    # that are relevant to the new logp graphs.
+    # We can do that by removing the irrelevant outer-inputs using the known relevant inner-inputs.
+    removed_inner_inputs = set(range(len(scan_inner_inputs))) - set(rv_relevant_inner_input_idx)
+    old_in_out_mappings = scan_node.op.get_oinp_iinp_iout_oout_mappings()
+    rv_removed_outer_input_idx = [
+        old_in_out_mappings["outer_inp_from_inner_inp"][i] for i in removed_inner_inputs
+    ]
+    rv_removed_outer_inputs = [scan_node.inputs[i] for i in rv_removed_outer_input_idx]
+
+    rv_relevant_outer_inputs = [r for r in scan_node.inputs if r not in rv_removed_outer_inputs]
+
+    # Now, we can create a new op with our new inner-graph inputs and outputs.
+    # Also, since our inner graph has new placeholder terms representing
+    # an observed value for each random variable, we need to update the
+    # "info" `dict`.
+    logp_info = scan_node.op.info.copy()
+    logp_info["tap_array"] = []
+    logp_info["n_seqs"] += len(rv_out_obs)
+    logp_info["n_mit_mot"] = 0
+    logp_info["n_mit_mot_outs"] = 0
+    logp_info["mit_mot_out_slices"] = []
+    logp_info["n_mit_sot"] = 0
+    logp_info["n_sit_sot"] = 0
+    logp_info["n_shared_outs"] = 0
+    logp_info["n_nit_sot"] += len(rv_out_obs) - 1
+    logp_info["name"] = None
+    logp_info["strict"] = True
+
+    # The `obs` arguments are the tensor variables corresponding to each
+    # random variable's array of observations.
+    def logp_fn(*obs):
+        logp_obs_outer_inputs = list(obs)  # [r.clone() for r in rv_outer_outputs]
+        logp_outer_inputs = (
+            [rv_relevant_outer_inputs[0]] + logp_obs_outer_inputs + rv_relevant_outer_inputs[1:]
+        )
+        logp_op = Scan(logp_inner_inputs, logp_inner_outputs, logp_info)
+        scan_logp = logp_op(*logp_outer_inputs)
+        return scan_logp
+
+    # logp_fn = OpFromGraph(logp_obs_outer_inputs, [scan_logp])
+    return logp_fn
+
+
 @dispatch(pm.Uniform, object)
 def convert_dist_to_rv(dist, rng):
     size = dist.shape.astype(int)[UniformRV.ndim_supp :]
@@ -467,7 +585,7 @@ def model_graph(pymc_model, output_vars=None, rand_state=None, attach_memo=True)
     return model_fg
 
 
-def graph_model(graph, *model_args, **model_kwargs):
+def graph_model(graph, *model_args, generate_names=False, **model_kwargs):
     """Create a PyMC3 model from a Theano graph with `RandomVariable` nodes."""
 
     model = pm.Model(*model_args, **model_kwargs)
@@ -478,13 +596,14 @@ def graph_model(graph, *model_args, **model_kwargs):
     nodes = [n for n in fgraph.toposort() if isinstance(n.op, RandomVariable)]
     rv_replacements = {}
 
+    node_id = 0
+
     for node in nodes:
 
         obs = get_rv_observation(node)
 
         if obs is not None:
             obs = obs.inputs[0]
-            obs = tt_get_values(obs)
 
         old_rv_var = node.default_output()
 
@@ -500,6 +619,12 @@
             if op != node
         )
 
+        if generate_names and rv_var.name is None:
+            node_name = "{}_{}".format(node.op.name, node_id)
+            # warn("Name {} generated for node {}.".format(node_name, node))
+            node_id += 1
+            rv_var.name = node_name
+
         with model:
             rv = convert_rv_to_dist(node, obs)
 
diff --git a/symbolic_pymc/theano/random_variables.py b/symbolic_pymc/theano/random_variables.py
index 0e87f8b..dc1a719 100644
--- a/symbolic_pymc/theano/random_variables.py
+++ b/symbolic_pymc/theano/random_variables.py
@@ -447,11 +447,8 @@ def make_node(self, val, rv=None):
             The distribution from which `val` is assumed to be a sample value.
""" val = tt.as_tensor_variable(val) - if rv: - if rv.owner and not isinstance(rv.owner.op, RandomVariable): - raise ValueError(f"`rv` must be a RandomVariable type: {rv}") - - if rv.type.convert_variable(val) is None: + if rv is not None: + if not hasattr(rv, "type") or rv.type.convert_variable(val) is None: raise ValueError( ("`rv` and `val` do not have compatible types:" f" rv={rv}, val={val}") ) diff --git a/tests/theano/test_pymc3.py b/tests/theano/test_pymc3.py index d1546f7..a9ac339 100644 --- a/tests/theano/test_pymc3.py +++ b/tests/theano/test_pymc3.py @@ -13,10 +13,10 @@ # from theano.configparser import change_flags from theano.gof.graph import inputs as tt_inputs -from symbolic_pymc.theano.random_variables import MvNormalRV, Observed, observed +from symbolic_pymc.theano.random_variables import NormalRV, MvNormalRV, Observed, observed from symbolic_pymc.theano.ops import RandomVariable from symbolic_pymc.theano.opt import FunctionGraph -from symbolic_pymc.theano.pymc3 import model_graph, graph_model +from symbolic_pymc.theano.pymc3 import model_graph, graph_model, logp from symbolic_pymc.theano.utils import canonicalize from symbolic_pymc.theano.meta import mt @@ -60,6 +60,12 @@ def test_pymc3_convert_dists(): new_pymc_rv_names = {n.name for n in pymc_model.observed_RVs} pymc_rv_names == new_pymc_rv_names + with pytest.raises(TypeError): + graph_model(NormalRV(0, 1), generate_names=False) + + res = graph_model(NormalRV(0, 1), generate_names=True) + assert res.vars[0].name == "normal_0" + def test_pymc3_normal_model(): """Conduct a more in-depth test of PyMC3/Theano conversions for a specific model.""" @@ -246,3 +252,50 @@ def test_pymc3_broadcastable(): Z_rv_meta = canonicalize(Z_rv_obs_.reify(), return_graph=False) assert mt(Z_rv_tt) == mt(Z_rv_meta) + + +def test_logp(): + test_rv = NormalRV(0, tt.arange(1, 3)) + test_logp = logp(test_rv, 0) + + assert np.all(test_logp.eval() == pm.Normal.dist(0, np.arange(1, 3)).logp(0).eval()) + + fgraph = FunctionGraph(tt_inputs([test_rv]), [test_rv], features=[tt.opt.ShapeFeature()]) + test_rv.owner.fgraph = fgraph + test_logp = logp(test_rv, 0) + + assert np.all(test_logp.eval() == pm.Normal.dist(0, np.arange(1, 3)).logp(0).eval()) + + rng_state = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234))) + rng_tt = theano.shared(rng_state, name="rng", borrow=True) + rng_tt.tag.is_rng = True + rng_tt.default_update = rng_tt + + # TODO: Scan of univariate normals. + N_tt = tt.iscalar("N") + N_tt.tag.test_value = 10 + + mus_tt = tt.arange(N_tt) + mus_tt.tag.test_value + + sigmas_tt = tt.ones((N_tt,)) + sigmas_tt.tag.test_value + + def scan_fn(mu_t, sigma_t, rng): + # mix = np.stack([NormalRV(mu_t, sigma_t, rng=rng), GammaRV(mu_t**2 / sigma_t**2, mu_t / sigma_t)]) + return NormalRV(mu_t, sigma_t, rng=rng) + + scan_rv, _ = theano.scan( + fn=scan_fn, + sequences=[mus_tt, sigmas_tt], + non_sequences=[rng_tt], + outputs_info=[{},], + strict=True, + name="scan_rv", + ) + + scan_logp = logp(scan_rv, tt.zeros((N_tt,))) + res = scan_logp.eval({N_tt: 10}) + exp_res = pm.Normal.dist(np.arange(10), np.ones(10)).logp(np.zeros(10)).eval() + + assert np.array_equal(res, exp_res)