diff --git a/symbolic_pymc/theano/opt.py b/symbolic_pymc/theano/opt.py
index 00c545d..c14fdbd 100644
--- a/symbolic_pymc/theano/opt.py
+++ b/symbolic_pymc/theano/opt.py
@@ -1,11 +1,19 @@
 import types
 
+import numpy as np
+
 import theano
 import theano.tensor as tt
 
+from copy import copy
 from functools import wraps
+from unittest.mock import patch
+from collections import namedtuple, OrderedDict
 
-from theano.gof.opt import LocalOptimizer
+from theano.gof.opt import LocalOptimizer, local_optimizer
+from theano.gof.graph import inputs as tt_inputs
+from theano.scan_module.scan_op import Scan
+from theano.scan_module.scan_utils import scan_args, clone as tt_clone
 
 from unification import var, variables
@@ -14,6 +22,7 @@ from etuples.core import ExpressionTuple
 
 from .meta import MetaSymbol
+from .ops import RandomVariable
 
 
 def eval_and_reify_meta(x):
@@ -33,6 +42,13 @@ def eval_and_reify_meta(x):
     return res
 
 
+def safe_index(lst, x):
+    try:
+        return lst.index(x)
+    except ValueError:
+        return None
+
+
 class FunctionGraph(theano.gof.fg.FunctionGraph):
     """A version of `FunctionGraph` that knows not to merge non-deterministic `Op`s.
 
@@ -63,6 +79,17 @@ def __init__(
             inputs = [self.memo[i] for i in inputs]
             outputs = [self.memo[o] for o in outputs]
+        else:
+            # We make it possible to use a non-cloned set of inputs and outputs
+            # by cloning only the cached constants and replacing them in the
+            # inputs and output graphs.
+            cached_constants = [x for x in inputs if getattr(x, "cached", False)]
+            copied_constants = tt_clone(cached_constants, share_inputs=False)
+            replacements = list(zip(cached_constants, copied_constants))
+            inputs = list(set(inputs) - set(cached_constants)) + list(copied_constants)
+            outputs = tt_clone(outputs, share_inputs=True, replace=replacements)
+
+            assert not any(getattr(v, "cached", False) for v in inputs + outputs)
 
         super().__init__(inputs, outputs, features=features, clone=False, update_mapping=None)
@@ -214,8 +241,502 @@ def transform(self, node):
             else:
                 raise ValueError(
                     "Unsupported FunctionGraph replacement variable type: {chosen_res}"
-                )
+                )  # pragma: no cover
 
             return new_node
         else:
             return False
+
+
+FieldInfo = namedtuple("FieldInfo", ("name", "agg_name", "index", "inner_index", "agg_index"))
+
+
+class ScanArgs(scan_args):
+    """An improved version of `theano.scan_module.scan_utils.scan_args`."""
+
+    default_filter = lambda x: x.startswith("inner_") or x.startswith("outer_")
+    nested_list_fields = ("inner_in_mit_mot", "inner_in_mit_sot", "inner_out_mit_mot")
+
+    def __init__(self, *args, **kwargs):
+        # Prevent unnecessary and counter-productive cloning.
+        # If you want to clone the inner graph, do it before you call this!
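+        # `scan_args.__init__` normally calls `reconstruct_graph`, which
+        # clones the inner graph; the patch below makes it return the given
+        # inputs/outputs unchanged, so this object's fields reference the
+        # `Scan`'s original inner variables.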
+        with patch(
+            "theano.scan_module.scan_utils.reconstruct_graph",
+            side_effect=lambda x, y, z=None: [x, y],
+        ):
+            super().__init__(*args, **kwargs)
+
+    @staticmethod
+    def from_node(node):
+        if not isinstance(node.op, Scan):
+            raise TypeError("{} is not a Scan node".format(node))
+        return ScanArgs(node.inputs, node.outputs, node.op.inputs, node.op.outputs, node.op.info)
+
+    @classmethod
+    def create_empty(cls):
+        info = OrderedDict(
+            [
+                ("n_seqs", 0),
+                ("n_mit_mot", 0),
+                ("n_mit_sot", 0),
+                ("tap_array", []),
+                ("n_sit_sot", 0),
+                ("n_nit_sot", 0),
+                ("n_shared_outs", 0),
+                ("n_mit_mot_outs", 0),
+                ("mit_mot_out_slices", []),
+                ("truncate_gradient", -1),
+                ("name", None),
+                ("mode", None),
+                ("destroy_map", OrderedDict()),
+                ("gpua", False),
+                ("as_while", False),
+                ("profile", False),
+                ("allow_gc", False),
+            ]
+        )
+        res = cls([1], [], [], [], info)
+        res.n_steps = None
+        return res
+
+    @property
+    def n_nit_sot(self):
+        # This is just a hack that allows us to use `Scan.get_oinp_iinp_iout_oout_mappings`
+        return self.info["n_nit_sot"]
+
+    @property
+    def inputs(self):
+        # This is just a hack that allows us to use `Scan.get_oinp_iinp_iout_oout_mappings`
+        return self.inner_inputs
+
+    @property
+    def n_mit_mot(self):
+        # This is just a hack that allows us to use `Scan.get_oinp_iinp_iout_oout_mappings`
+        return self.info["n_mit_mot"]
+
+    @property
+    def var_mappings(self):
+        return Scan.get_oinp_iinp_iout_oout_mappings(self)
+
+    @property
+    def field_names(self):
+        res = ["mit_mot_out_slices", "mit_mot_in_slices", "mit_sot_in_slices"]
+        res.extend(
+            [
+                attr
+                for attr in self.__dict__
+                if attr.startswith("inner_in")
+                or attr.startswith("inner_out")
+                or attr.startswith("outer_in")
+                or attr.startswith("outer_out")
+                or attr == "n_steps"
+            ]
+        )
+        return res
+
+    def get_alt_field(self, var_info, alt_prefix):
+        """Get the alternate input/output field for a given element of `ScanArgs`.
+
+        For example, if `var_info` is in `ScanArgs.outer_out_sit_sot`, then
+        `get_alt_field(var_info, "inner_out")` returns the element corresponding
+        to `var_info` in `ScanArgs.inner_out_sit_sot`.
+
+        Parameters
+        ----------
+        var_info: TensorVariable or FieldInfo
+            The element for which we want the alternate
+        alt_prefix: str
+            The string prefix for the alternate field type.  It can be one of
+            the following: "inner_out", "inner_in", "outer_in", and "outer_out".
+
+        Outputs
+        -------
+        TensorVariable
+            Returns the alternate variable.
+
+        """
+        if not isinstance(var_info, FieldInfo):
+            var_info = self.find_among_fields(var_info)
+
+        alt_type = var_info.name[(var_info.name.index("_", 6) + 1) :]
+        alt_var = getattr(self, "{}_{}".format(alt_prefix, alt_type))[var_info.index]
+        return alt_var
+
+    def find_among_fields(self, i, field_filter=default_filter):
+        """Find the type and indices of the field containing a given element.
+
+        NOTE: This only returns the *first* field containing the given element.
+
+        Parameters
+        ----------
+        i: theano.gof.graph.Variable
+            The element to find among this object's fields.
+        field_filter: function
+            A function passed to `filter` that determines which fields to
+            consider.  It must take a string field name and return a truthy
+            value.
+
+        Returns
+        -------
+        A `FieldInfo` namedtuple containing the field name string, the name of
+        the corresponding aggregate list, the first index, the second index
+        (for nested lists), and the "major" index (i.e. the index within the
+        aggregate lists like `self.inner_inputs`, `self.outer_outputs`, etc.),
+        or `None` when no match is found.
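+
+        For example, given a `ScanArgs` built with `ScanArgs.from_node` (a
+        minimal sketch; the `Scan` is assumed to have at least two sequences):
+
+            >>> info = scan_args.find_among_fields(scan_args.inner_in_seqs[1])
+            >>> info.name, info.index, info.inner_index
+            ('inner_in_seqs', 1, None)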
+ + """ + + field_names = filter(field_filter, self.field_names) + + for field_name in field_names: + lst = getattr(self, field_name) + + field_prefix = field_name[:8] + if field_prefix.endswith("in"): + agg_field_name = "{}puts".format(field_prefix) + else: + agg_field_name = "{}tputs".format(field_prefix) + + agg_list = getattr(self, agg_field_name) + + if field_name in self.nested_list_fields: + for n, sub_lst in enumerate(lst): + idx = safe_index(sub_lst, i) + if idx is not None: + agg_idx = safe_index(agg_list, i) + return FieldInfo(field_name, agg_field_name, n, idx, agg_idx) + else: + idx = safe_index(lst, i) + if idx is not None: + agg_idx = safe_index(agg_list, i) + return FieldInfo(field_name, agg_field_name, idx, None, agg_idx) + + return None + + def _remove_from_fields(self, i, field_filter=default_filter): + + field_info = self.find_among_fields(i, field_filter=field_filter) + + if field_info is None: + return None + + if field_info.inner_index is not None: + getattr(self, field_info.name)[field_info.index].remove(i) + else: + getattr(self, field_info.name).remove(i) + + return field_info + + def get_dependent_nodes(self, i, seen=None): + + if seen is None: + seen = {i} + else: + seen.add(i) + + var_mappings = self.var_mappings + + field_info = self.find_among_fields(i) + + if field_info is None: + raise ValueError("{} not found among fields.".format(i)) + + # Find the `var_mappings` key suffix that matches the field/set of + # arguments containing our source node + if field_info.name[:8].endswith("_in"): + map_key_suffix = "{}p".format(field_info.name[:8]) + else: + map_key_suffix = field_info.name[:9] + + dependent_nodes = set() + for k, v in var_mappings.items(): + + if not k.endswith(map_key_suffix): + continue + + dependent_idx = v[field_info.agg_index] + dependent_idx = dependent_idx if isinstance(dependent_idx, list) else [dependent_idx] + + # Get the `ScanArgs` field name for the aggregate list property + # corresponding to these dependent argument types (i.e. either + # "outer_inputs", "inner_inputs", "inner_outputs", or + # "outer_outputs"). + # To do this, we need to parse the "shared" prefix of the + # current `var_mappings` key and append the missing parts so that + # it either forms `"*_inputs"` or `"*_outputs"`. + to_agg_field_prefix = k[:9] + if to_agg_field_prefix.endswith("p"): + to_agg_field_name = "{}uts".format(to_agg_field_prefix) + else: + to_agg_field_name = "{}puts".format(to_agg_field_prefix) + + to_agg_field = getattr(self, to_agg_field_name) + + for d_id in dependent_idx: + if d_id < 0: + continue + + dependent_var = to_agg_field[d_id] + + if dependent_var not in seen: + dependent_nodes.add(dependent_var) + + if field_info.name.startswith("inner_in"): + # If starting from an inner-input, then we need to find any + # inner-outputs that depend on it. 
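+            # E.g. when removing the inner-input `Gamma`, an inner-output
+            # `S_t` computed from it must be picked up as a dependent, too.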
+ for out_n in self.inner_outputs: + if i in tt_inputs([out_n]): + if out_n not in seen: + dependent_nodes.add(out_n) + + for n in tuple(dependent_nodes): + if n in seen: + continue + sub_dependent_nodes = self.get_dependent_nodes(n, seen=seen) + dependent_nodes |= sub_dependent_nodes + seen |= sub_dependent_nodes + + return dependent_nodes + + def remove_from_fields(self, i, rm_dependents=True): + + if rm_dependents: + vars_to_remove = self.get_dependent_nodes(i) | {i} + else: + vars_to_remove = {i} + + rm_info = [] + for v in vars_to_remove: + dependent_rm_info = self._remove_from_fields(v) + rm_info.append((v, dependent_rm_info)) + + return rm_info + + def __str__(self): + inner_arg_strs = [ + "\t{}={}".format(p, getattr(self, p)) + for p in self.field_names + if p.startswith("outer_in") or p == "n_steps" + ] + inner_arg_strs += [ + "\t{}={}".format(p, getattr(self, p)) + for p in self.field_names + if p.startswith("inner_in") + ] + inner_arg_strs += [ + "\tmit_mot_in_slices={}".format(self.mit_mot_in_slices), + "\tmit_sot_in_slices={}".format(self.mit_sot_in_slices), + ] + inner_arg_strs += [ + "\t{}={}".format(p, getattr(self, p)) + for p in self.field_names + if p.startswith("inner_out") + ] + inner_arg_strs += [ + "\tmit_mot_out_slices={}".format(self.mit_mot_out_slices), + ] + inner_arg_strs += [ + "\t{}={}".format(p, getattr(self, p)) + for p in self.field_names + if p.startswith("outer_out") + ] + res = "ScanArgs(\n{})".format(",\n".join(inner_arg_strs)) + return res + + def __repr__(self): + return self.__str__() + + def __eq__(self, other): + if not isinstance(other, type(self)): + return NotImplemented + + for field_name in self.field_names: + if not hasattr(other, field_name) or getattr(self, field_name) != getattr( + other, field_name + ): + return False + + return True + + +@local_optimizer([Scan]) +def push_out_rvs_from_scan(node): + """Push `RandomVariable`s out of `Scan` nodes. + + When `RandomVariable`s are created within the inner-graph of a `Scan` and + are not output to the outer-graph, we "push" them out of the inner-graph. + This helps us produce an outer-graph in which all the relevant `RandomVariable`s + are accessible (e.g. for constructing a log-likelihood graph). + """ + scan_args = ScanArgs(node.inputs, node.outputs, node.op.inputs, node.op.outputs, node.op.info) + + # Find the un-output `RandomVariable`s created in the inner-graph + clients = {} + local_fgraph_topo = theano.gof.graph.io_toposort( + scan_args.inner_inputs, scan_args.inner_outputs, clients=clients + ) + unpushed_inner_rvs = [] + for n in local_fgraph_topo: + if isinstance(n.op, RandomVariable): + unpushed_inner_rvs.extend([c for c in clients[n] if c not in scan_args.inner_outputs]) + + if len(unpushed_inner_rvs) == 0: + return False + + # Add the new outputs to the inner and outer graphs + scan_args.inner_out_nit_sot.extend(unpushed_inner_rvs) + + assert len(scan_args.outer_in_nit_sot) > 0, "No outer-graph inputs are nit-sots!" + + # Just like `theano.scan`, we simply copy/repeat the existing nit-sot + # outer-graph input value, which represents the actual size of the output + # tensors. Apparently, the value needs to be duplicated for all nit-sots. + # FYI: This is what increments the nit-sot values in `scan_args.info`, as + # well. + # TODO: Can we just use `scan_args.n_steps`? 
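+    # E.g. with one existing nit-sot input `v` and two newly pushed-out inner
+    # RVs, `outer_in_nit_sot` goes from `[v]` to `[v, v, v]`.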
+    scan_args.outer_in_nit_sot.extend(scan_args.outer_in_nit_sot[0:1] * len(unpushed_inner_rvs))
+
+    op = Scan(scan_args.inner_inputs, scan_args.inner_outputs, scan_args.info)
+    outputs = list(op(*scan_args.outer_inputs))
+
+    # Return only the replacements for the original `node.outputs`
+    new_inner_out_idx = [scan_args.inner_outputs.index(i) for i in unpushed_inner_rvs]
+    _ = [outputs.pop(op.var_mappings["outer_out_from_inner_out"][i]) for i in new_inner_out_idx]
+
+    return dict(zip(node.outputs, outputs))
+
+
+def convert_outer_out_to_in(input_scan_args, var, inner_out_fn=None, output_scan_args=None):
+    """Convert outer-graph outputs into outer-graph inputs.
+
+    Parameters
+    ----------
+    input_scan_args: ScanArgs
+        The source scan arguments.
+    var: TensorVariable
+        The outer-graph output variable that is to be converted into an
+        outer-graph input.
+    inner_out_fn: function (Optional)
+        A function with the signature `(input_scan_args, old_inner_out_var,
+        new_inner_in_var, output_scan_args)` that produces a new inner-graph
+        output.  This can be used to transform the `var`'s
+        corresponding inner-graph output, for example.
+    output_scan_args: ScanArgs (Optional)
+        If this argument is non-`None`, the conversion is applied to the given
+        `ScanArgs` and the old `var` output is removed.
+
+    Outputs
+    -------
+    (ScanArgs, TensorVariable)
+        A tuple containing a `ScanArgs` object for a `Scan` in which `var` has been
+        converted to an outer-graph input, and a variable that is a clone of `var`
+        and serves as the new outer-graph input term.
+
+    """
+    replacing = False
+    if output_scan_args is None:
+        output_scan_args = ScanArgs.create_empty()
+    elif output_scan_args == input_scan_args:
+        replacing = True
+        # We will not change the input `ScanArgs` in-place
+        if output_scan_args is input_scan_args:
+            output_scan_args = copy(input_scan_args)
+
+    var_info = input_scan_args.find_among_fields(
+        var, field_filter=lambda x: x.startswith("outer_out")
+    )
+
+    old_inner_out_var = input_scan_args.get_alt_field(var_info, "inner_out")
+
+    if replacing:
+        output_scan_args.remove_from_fields(old_inner_out_var, rm_dependents=False)
+        # Remove the old outer-output variable.
+        # Not sure if this really matters, since we don't use the outer-outputs
+        # when building a new `Scan`, but doing it keeps the `ScanArgs` object
+        # consistent.
+        output_scan_args.remove_from_fields(var, rm_dependents=False)
+
+    # Couldn't one do the same with `var_info`?
+    inner_out_info = input_scan_args.find_among_fields(
+        old_inner_out_var, field_filter=lambda x: x.startswith("inner_out")
+    )
+
+    # Use the index for the specific inner-graph sub-collection to which this
+    # variable belongs (e.g. index `1` among the inner-graph sit-sot terms)
+    var_idx = inner_out_info.index
+
+    # The old inner-output variable becomes a new inner-input
+    inner_in_var = old_inner_out_var.clone()
+
+    # We need to clone any existing inner-output variables in the `ScanArgs`
+    # object that we're mutating and replace references to `old_inner_out_var`
+    # with `inner_in_var`.  If we don't, then any other inner-outputs that
+    # reference the inner-output that we're replacing will be inconsistent.
+    # Instead, we want those other inner-outputs to reference the new
+    # inner-input replacement variable.
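+    # E.g. if another inner-output's graph contains the replaced inner-output
+    # `Y_t`, the clone of that inner-output must reference `Y_t`'s new
+    # inner-input stand-in instead.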
+ for io_var in list(output_scan_args.inner_outputs): + io_var_info = output_scan_args.find_among_fields( + io_var, field_filter=lambda x: x.startswith("inner_out") + ) + io_sub_list = getattr(output_scan_args, io_var_info.name) + + (new_io_var,) = tt_clone([io_var], replace={old_inner_out_var: inner_in_var}) + + io_sub_list[io_var_info.index] = new_io_var + + # If we're replacing a [m|s]it-sot, then we need to add a new nit-sot + add_nit_sot = False + inner_in_seqs = [inner_in_var] + if inner_out_info.name.endswith("mit_sot"): + inner_in_seqs = input_scan_args.inner_in_mit_sot[var_idx] + inner_in_seqs + if replacing: + output_scan_args.inner_in_mit_sot.pop(var_idx) + output_scan_args.outer_in_mit_sot.pop(var_idx) + add_nit_sot = True + elif inner_out_info.name.endswith("sit_sot"): + inner_in_seqs = [input_scan_args.inner_in_sit_sot[var_idx]] + inner_in_seqs + if replacing: + output_scan_args.inner_in_sit_sot.pop(var_idx) + output_scan_args.outer_in_sit_sot.pop(var_idx) + add_nit_sot = True + + taps = [0] + if inner_out_info.name.endswith("mit_sot"): + taps = input_scan_args.mit_sot_in_slices[var_idx] + taps + if replacing: + output_scan_args.mit_sot_in_slices.pop(var_idx) + elif inner_out_info.name.endswith("sit_sot"): + taps = [-1] + taps + + taps, inner_in_seqs = zip(*sorted(zip(taps, inner_in_seqs), key=lambda x: x[0])) + + inner_in_seqs = list(reversed(inner_in_seqs)) + output_scan_args.inner_in_seqs += inner_in_seqs + + taps = np.asarray(taps) + slice_seqs = zip(-taps, [n if n < 0 else None for n in reversed(taps)]) + + # We could clone `var`, but reusing it will make things easier down the + # line (e.g. avoid the need to remap cloned variables) + # new_input_var = var.clone() + new_outer_input_var = var.clone() + if new_outer_input_var.name: + new_outer_input_var.name = new_outer_input_var.name.lower() + + var_slices = [new_outer_input_var[b:e] for b, e in slice_seqs] + n_steps = tt.min([tt.shape(n)[0] for n in var_slices]) + + if output_scan_args.n_steps is None or replacing: + output_scan_args.n_steps = n_steps + + output_scan_args.outer_in_seqs += [v[:n_steps] for v in var_slices] + + if not replacing or add_nit_sot: + output_scan_args.outer_in_nit_sot += [n_steps] + + if inner_out_fn: + output_scan_args.inner_out_nit_sot += [ + inner_out_fn(input_scan_args, old_inner_out_var, inner_in_var, output_scan_args) + ] + + return output_scan_args, new_outer_input_var diff --git a/symbolic_pymc/theano/pymc3.py b/symbolic_pymc/theano/pymc3.py index b7d0c60..cfe7dfb 100644 --- a/symbolic_pymc/theano/pymc3.py +++ b/symbolic_pymc/theano/pymc3.py @@ -15,10 +15,13 @@ from multipledispatch import dispatch from unification.utils import transitive_get as walk +from theano.gof import Query +from theano.compile import optdb from theano.gof.op import get_test_value +from theano.gof.opt import SeqOptimizer, EquilibriumOptimizer from theano.gof.graph import Apply, inputs as tt_inputs from theano.scan_module.scan_op import Scan -from theano.scan_module.scan_utils import clone +from theano.scan_module.scan_utils import clone as tt_clone from .random_variables import ( observed, @@ -59,9 +62,15 @@ NegBinomialRV, NegBinomialRVType, ) -from .opt import FunctionGraph from .ops import RandomVariable -from .utils import replace_input_nodes, get_rv_observation +from .utils import ( + replace_input_nodes, + get_rv_observation, + optimize_graph, + get_random_outer_outputs, + construct_scan, +) +from .opt import FunctionGraph, push_out_rvs_from_scan, convert_outer_out_to_in, ScanArgs logger = 
logging.getLogger("symbolic_pymc") @@ -76,49 +85,6 @@ def tt_get_values(obj): raise TypeError(f"Unhandled observation type: {type(obj)}") -@dispatch(Apply, object) -def convert_rv_to_dist(node, obs): - if not isinstance(node.op, RandomVariable): - raise TypeError(f"{node} is not of type `RandomVariable`") - - rv = node.default_output() - - if hasattr(node, "fgraph") and hasattr(node.fgraph, "shape_feature"): - shape = list(node.fgraph.shape_feature.shape_tuple(rv)) - else: - shape = list(rv.shape) - - for i, s in enumerate(shape): - try: - shape[i] = tt.get_scalar_constant_value(s) - except tt.NotScalarConstantError: - shape[i] = s.tag.test_value - - dist_type, dist_params = _convert_rv_to_dist(node.op, node) - return dist_type(rv.name, shape=shape, observed=obs, **dist_params) - - -@dispatch(tt.TensorVariable, object) -def logp(var, obs): - - node = var.owner - - if hasattr(node, "fgraph") and hasattr(node.fgraph, "shape_feature"): - shape = list(node.fgraph.shape_feature.shape_tuple(var)) - else: - shape = list(var.shape) - - for i, s in enumerate(shape): - try: - shape[i] = tt.get_scalar_constant_value(s) - except tt.NotScalarConstantError: - shape[i] = s.tag.test_value - - logp_fn = _logp_fn(node.op, node, shape) - return logp_fn(obs) - - -@dispatch(RandomVariable, Apply, object) def _logp_fn(op, node, shape=None): dist_type, dist_params = _convert_rv_to_dist(op, node) if shape is not None: @@ -133,85 +99,129 @@ def _logp_fn(op, node, shape=None): return res.logp -@_logp_fn.register(Scan, Apply, object) -def _logp_fn_Scan(op, scan_node, shape=None): +def create_inner_out_logp(input_scan_args, old_inner_out_var, new_inner_in_var, output_scan_args): + """Create a log-likelihood inner-output for a `Scan`.""" + + # shape = list(old_inner_out_var.owner.fgraph.shape_feature.shape_tuple(old_inner_out_var)) + shape = None + logp_fn = _logp_fn(old_inner_out_var.owner.op, old_inner_out_var.owner, shape) + logp = logp_fn(new_inner_in_var) + if new_inner_in_var.name: + logp.name = "logp({})".format(new_inner_in_var.name) + return logp - scan_inner_inputs = scan_node.op.inputs - scan_inner_outputs = scan_node.op.outputs - def create_obs_var(i, x): - obs = x.type() - obs.name = f"{x.name or x.owner.op.name}_obs_{i}" - if hasattr(x.tag, "test_value"): - obs.tag.test_value = x.tag.test_value - return obs +def logp(*output_vars): + """Compute the log-likelihood for a graph. - rv_outs = [ - (i, x, create_obs_var(i, x)) - for i, x in enumerate(scan_inner_outputs) - if x.owner and isinstance(x.owner.op, RandomVariable) - ] - rv_inner_out_idx, rv_out_vars, rv_out_obs = zip(*rv_outs) - # rv_outer_out_idx = [scan_node.op.get_oinp_iinp_iout_oout_mappings()['outer_out_from_inner_out'][i] for i in rv_inner_out_idx] - # rv_outer_outputs = [scan_node.outputs[i] for i in rv_outer_out_idx] + Parameters + ---------- + *output_vars: Tuple[TensorVariable] + The output of a graph containing `RandomVariable`s. - logp_inner_outputs = [clone(logp(rv, obs)) for i, rv, obs in rv_outs] - assert all(o in tt.gof.graph.inputs(logp_inner_outputs) for o in rv_out_obs) + Results + ------- + Dict[TensorVariable, TensorVariable] + A map from `RandomVariable`s to their log-likelihood graphs. 
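+
+    For example (a minimal sketch; `NormalRV` is from
+    `symbolic_pymc.theano.random_variables`):
+
+        >>> Y_rv = NormalRV(0.0, 1.0, name="Y")
+        >>> new_input, Y_logp = logp(Y_rv)[Y_rv]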
- logp_inner_outputs_inputs = tt.gof.graph.inputs(logp_inner_outputs) - rv_relevant_inner_input_idx, rv_relevant_inner_inputs = zip( - *[(n, i) for n, i in enumerate(scan_inner_inputs) if i in logp_inner_outputs_inputs] + """ + # model_inputs = [i for i in tt_inputs(output_vars) if not isinstance(i, tt.Constant)] + model_inputs = tt_inputs(output_vars) + model_fgraph = FunctionGraph( + model_inputs, + output_vars, + clone=True, + # XXX: `ShapeFeature` introduces cached constants + # features=[tt.opt.ShapeFeature()] ) - logp_inner_inputs = list(rv_out_obs) + list(rv_relevant_inner_inputs) - - # We need to create outer-inputs that represent arrays of observations - # for each random variable. - # To do that, we're going to use each random variable's outer-output term, - # since they necessarily have the same shape and type as the observations - # arrays. - - # Just like we did for the inner-inputs, we need to get only the outer-inputs - # that are relevant to the new logp graphs. - # We can do that by removing the irrelevant outer-inputs using the known relevant inner-inputs - removed_inner_inputs = set(range(len(scan_inner_inputs))) - set(rv_relevant_inner_input_idx) - old_in_out_mappings = scan_node.op.get_oinp_iinp_iout_oout_mappings() - rv_removed_outer_input_idx = [ - old_in_out_mappings["outer_inp_from_inner_inp"][i] for i in removed_inner_inputs - ] - rv_removed_outer_inputs = [scan_node.inputs[i] for i in rv_removed_outer_input_idx] - - rv_relevant_outer_inputs = [r for r in scan_node.inputs if r not in rv_removed_outer_inputs] - - # Now, we can create a new op with our new inner-graph inputs and outputs. - # Also, since our inner graph has new placeholder terms representing - # an observed value for each random variable, we need to update the - # "info" `dict`. - logp_info = scan_node.op.info.copy() - logp_info["tap_array"] = [] - logp_info["n_seqs"] += len(rv_out_obs) - logp_info["n_mit_mot"] = 0 - logp_info["n_mit_mot_outs"] = 0 - logp_info["mit_mot_out_slices"] = [] - logp_info["n_mit_sot"] = 0 - logp_info["n_sit_sot"] = 0 - logp_info["n_shared_outs"] = 0 - logp_info["n_nit_sot"] += len(rv_out_obs) - 1 - logp_info["name"] = None - logp_info["strict"] = True - - # These are the tensor variables corresponding to each random variable's - # array of observations. - def logp_fn(*obs): - logp_obs_outer_inputs = list(obs) # [r.clone() for r in rv_outer_outputs] - logp_outer_inputs = ( - [rv_relevant_outer_inputs[0]] + logp_obs_outer_inputs + rv_relevant_outer_inputs[1:] - ) - logp_op = Scan(logp_inner_inputs, logp_inner_outputs, logp_info) - scan_logp = logp_op(*logp_outer_inputs) - return scan_logp - # logp_fn = OpFromGraph(logp_obs_outer_inputs, [scan_logp]) - return logp_fn + canonicalize_opt = optdb.query(Query(include=["canonicalize"])) + push_out_opt = EquilibriumOptimizer([push_out_rvs_from_scan], max_use_ratio=10) + optimizations = SeqOptimizer(canonicalize_opt.copy()) + optimizations.append(push_out_opt) + opt_fgraph = optimize_graph(model_fgraph, optimizations, in_place=True) + + replacements = {} + rv_to_logp_io = {} + for node in opt_fgraph.toposort(): + # TODO: This `RandomVariable` "parsing" should be generalized and used + # in more places (e.g. 
what if the outer-outputs are `Subtensor`s)
+        if isinstance(node.op, RandomVariable):
+            var = node.default_output()
+            # shape = list(node.fgraph.shape_feature.shape_tuple(new_var))
+            shape = None
+            new_input_var = var.clone()
+            if new_input_var.name:
+                new_input_var.name = new_input_var.name.lower()
+            replacements[var] = new_input_var
+            rv_to_logp_io[var] = (new_input_var, _logp_fn(node.op, var.owner, shape)(new_input_var))
+
+        if isinstance(node.op, tt.Subtensor) and node.inputs[0].owner:
+            # The output of `theano.scan` is sometimes a sliced tensor (in
+            # order to get rid of initial values introduced by the `Scan`)
+            node = node.inputs[0].owner
+
+        if isinstance(node.op, Scan):
+            scan_args = ScanArgs.from_node(node)
+            rv_outer_outs = get_random_outer_outputs(scan_args)
+
+            for var_idx, var, io_var in rv_outer_outs:
+                scan_args, new_oi_var = convert_outer_out_to_in(
+                    scan_args, var, inner_out_fn=create_inner_out_logp, output_scan_args=scan_args
+                )
+                replacements[var] = new_oi_var
+
+            logp_scan_out = construct_scan(scan_args)
+
+            for var_idx, var, io_var in rv_outer_outs:
+                rv_to_logp_io[var] = (replacements[var], logp_scan_out[var_idx])
+
+    # We need to use the new log-likelihood input variables that were generated
+    # for each `RandomVariable` node.  They need to replace the corresponding
+    # original variables within each log-likelihood graph.
+    rv_vars, inputs_logp_outputs = zip(*rv_to_logp_io.items())
+    new_inputs, logp_outputs = zip(*inputs_logp_outputs)
+
+    rev_memo = {v: k for k, v in model_fgraph.memo.items()}
+
+    # Replace the new cloned variables with the original ones, but only if
+    # they're not any of the `RandomVariable` terms we've converted to
+    # log-likelihoods.
+    replacements.update(
+        {
+            k: v
+            for k, v in rev_memo.items()
+            if isinstance(k, tt.Variable) and v not in new_inputs and k not in replacements
+        }
+    )
+
+    new_logp_outputs = tt_clone(logp_outputs, replace=replacements)
+
+    rv_to_logp_io = {rev_memo[k]: v for k, v in zip(rv_vars, zip(new_inputs, new_logp_outputs))}
+
+    return rv_to_logp_io
+
+
+@dispatch(Apply, object)
+def convert_rv_to_dist(node, obs):
+    if not isinstance(node.op, RandomVariable):
+        raise TypeError(f"{node} is not of type `RandomVariable`")  # pragma: no cover
+
+    rv = node.default_output()
+
+    if hasattr(node, "fgraph") and hasattr(node.fgraph, "shape_feature"):
+        shape = list(node.fgraph.shape_feature.shape_tuple(rv))
+    else:
+        shape = list(rv.shape)
+
+    for i, s in enumerate(shape):
+        try:
+            shape[i] = tt.get_scalar_constant_value(s)
+        except tt.NotScalarConstantError:
+            shape[i] = get_test_value(s)
+
+    dist_type, dist_params = _convert_rv_to_dist(node.op, node)
+    return dist_type(rv.name, shape=shape, observed=obs, **dist_params)
 
 
 @dispatch(pm.Uniform, object)
@@ -230,27 +240,27 @@ def _convert_rv_to_dist(op, rv):
 @convert_dist_to_rv.register(pm.Normal, object)
 def convert_dist_to_rv_Normal(dist, rng):
     size = dist.shape.astype(int)[NormalRV.ndim_supp :]
-    res = NormalRV(dist.mu, dist.sd, size=size, rng=rng)
+    res = NormalRV(dist.mu, dist.sigma, size=size, rng=rng)
     return res
 
 
 @_convert_rv_to_dist.register(NormalRVType, Apply)
 def _convert_rv_to_dist_Normal(op, rv):
-    params = {"mu": rv.inputs[0], "sd": rv.inputs[1]}
+    params = {"mu": rv.inputs[0], "sigma": rv.inputs[1]}
     return pm.Normal, params
 
 
 @convert_dist_to_rv.register(pm.HalfNormal, object)
 def convert_dist_to_rv_HalfNormal(dist, rng):
     size = dist.shape.astype(int)[HalfNormalRV.ndim_supp :]
-    res = HalfNormalRV(np.array(0.0, dtype=dist.dtype), dist.sd, size=size, rng=rng)
+    res = HalfNormalRV(np.array(0.0, dtype=dist.dtype), dist.sigma, size=size, rng=rng)
     return res
 
 
 @_convert_rv_to_dist.register(HalfNormalRVType, Apply)
 def _convert_rv_to_dist_HalfNormal(op, rv):
     assert not np.any(tt_get_values(rv.inputs[0]))
-    params = {"sd": rv.inputs[1]}
+    params = {"sigma": rv.inputs[1]}
     return pm.HalfNormal, params
diff --git a/symbolic_pymc/theano/utils.py b/symbolic_pymc/theano/utils.py
index 4fd15b8..be50301 100644
--- a/symbolic_pymc/theano/utils.py
+++ b/symbolic_pymc/theano/utils.py
@@ -1,8 +1,9 @@
 import theano.tensor as tt
 
 from theano.gof import FunctionGraph as tt_FunctionGraph, Query
-from theano.gof.graph import inputs as tt_inputs, clone_get_equiv, io_toposort
+from theano.gof.graph import inputs as tt_inputs, clone_get_equiv, io_toposort, ancestors
 from theano.compile import optdb
+from theano.scan_module.scan_op import Scan
 
 from .meta import mt
 from .opt import FunctionGraph
@@ -121,6 +122,7 @@ def optimize_graph(x, optimization, return_graph=None, in_place=False):
         res = x_graph_opt
     else:
         res = x_graph_opt.outputs
+        x_graph_opt.disown()
         if len(res) == 1:
             (res,) = res
     return res
@@ -145,3 +147,71 @@ def get_rv_observation(node):
         if isinstance(o.op, Observed):
             return o
     return None
+
+
+def is_random_variable(var):
+    """Check if a Theano variable is the output of a `RandomVariable`, possibly by way of a `Scan` or a `Subtensor` slice.
+
+    Output
+    ------
+    Tuple[TensorVariable, TensorVariable]
+        Returns a tuple containing the outer-graph output variable of the
+        `RandomVariable` or `Scan` `Op` and the `RandomVariable` output
+        variable itself (for a `Scan`, the corresponding inner-graph output);
+        otherwise, `None`.
+
+    """
+    node = var.owner
+
+    if not var.owner:
+        return None
+
+    # Consider `Subtensor` `Op`s that slice a `Scan`.  This is the type of
+    # output sometimes returned by `theano.scan` when taps/lags are used.
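+    # E.g. for a sit-sot output, `theano.scan` returns `scan_out[1:]`, i.e. a
+    # `Subtensor` of the `Scan`'s real output.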
+ if isinstance(node.op, tt.Subtensor) and node.inputs[0].owner: + var = node.inputs[0] + node = var.owner + + if isinstance(node.op, RandomVariable): + return (var, var) + + if isinstance(node.op, Scan): + op = node.op + inner_out_var_idx = op.var_mappings["outer_out_from_inner_out"][node.outputs.index(var)] + inner_out_var = op.outputs[inner_out_var_idx] + + if isinstance(inner_out_var.owner.op, RandomVariable): + return (var, inner_out_var) + + return None + + +def vars_to_rvs(var): + """Compute paths from `TensorVariable`s to their underlying `RandomVariable` outputs.""" + return { + a: v if v[0] is not a else (v[1]) + for a, v in [(a, is_random_variable(a)) for a in ancestors([var])] + if v is not None + } + + +def get_random_outer_outputs(scan_args): + """Get the `RandomVariable` outputs of a `Scan` (well, it's `ScanArgs`).""" + rv_vars = [] + for n, oo in enumerate(scan_args.outer_outputs): + oo_info = scan_args.find_among_fields(oo) + io_type = oo_info.name[(oo_info.name.index("_", 6) + 1) :] + inner_out_type = "inner_out_{}".format(io_type) + io_var = getattr(scan_args, inner_out_type)[oo_info.index] + if io_var.owner and isinstance(io_var.owner.op, RandomVariable): + rv_vars.append((n, oo, io_var)) + return rv_vars + + +def construct_scan(scan_args): + scan_op = Scan(scan_args.inner_inputs, scan_args.inner_outputs, scan_args.info) + scan_out = scan_op(*scan_args.outer_inputs) + + if not isinstance(scan_out, list): + scan_out = [scan_out] + + return scan_out diff --git a/tests/theano/test_opt.py b/tests/theano/test_opt.py index e2d6173..38fe17f 100644 --- a/tests/theano/test_opt.py +++ b/tests/theano/test_opt.py @@ -1,5 +1,9 @@ +import pytest +import numpy as np +import theano import theano.tensor as tt +from copy import copy from unification import var from kanren import eq @@ -9,18 +13,26 @@ from theano.gof.opt import EquilibriumOptimizer from theano.gof.graph import inputs as tt_inputs +from theano.scan_module.scan_op import Scan from symbolic_pymc.theano.meta import mt -from symbolic_pymc.theano.opt import KanrenRelationSub, FunctionGraph -from symbolic_pymc.theano.utils import optimize_graph +from symbolic_pymc.theano.opt import ( + KanrenRelationSub, + FunctionGraph, + push_out_rvs_from_scan, + ScanArgs, + convert_outer_out_to_in, +) +from symbolic_pymc.theano.utils import optimize_graph, get_random_outer_outputs, construct_scan +from symbolic_pymc.theano.random_variables import CategoricalRV, DirichletRV, NormalRV +from tests.theano.utils import create_test_hmm + +@theano.change_flags(compute_test_value="ignore", cxx="", mode="FAST_COMPILE") def test_kanren_opt(): """Make sure we can run miniKanren "optimizations" over a graph until a fixed-point/normal-form is reached. 
""" - tt.config.cxx = "" - tt.config.compute_test_value = "ignore" - x_tt = tt.vector("x") c_tt = tt.vector("c") d_tt = tt.vector("c") @@ -58,3 +70,824 @@ def distributes(in_lv, out_lv): assert fgraph_opt.owner.inputs[1].owner.op == tt.add assert isinstance(fgraph_opt.owner.inputs[1].owner.inputs[0].owner.op, tt.Dot) assert isinstance(fgraph_opt.owner.inputs[1].owner.inputs[1].owner.op, tt.Dot) + + +@theano.change_flags(compute_test_value="warn", cxx="", mode="FAST_COMPILE") +def test_push_out_rvs(): + + rng_state = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234))) + rng_tt = theano.shared(rng_state, name="rng", borrow=True) + rng_tt.tag.is_rng = True + rng_tt.default_update = rng_tt + + N_tt = tt.iscalar("N") + N_tt.tag.test_value = 10 + M_tt = tt.iscalar("M") + M_tt.tag.test_value = 2 + + mus_tt = tt.matrix("mus_t") + mus_tt.tag.test_value = np.stack([np.arange(0.0, 10), np.arange(0.0, -10, -1)], axis=-1).astype( + theano.config.floatX + ) + + sigmas_tt = tt.ones((N_tt,)) + Gamma_rv = DirichletRV(tt.ones((M_tt, M_tt)), rng=rng_tt, name="Gamma") + + # The optimizer should do nothing to this term, because it's not a `Scan` + fgraph = FunctionGraph(tt_inputs([Gamma_rv]), [Gamma_rv]) + pushoutrvs_opt = EquilibriumOptimizer([push_out_rvs_from_scan], max_use_ratio=10) + Gamma_opt_rv = optimize_graph(fgraph, pushoutrvs_opt, return_graph=False) + # The `FunctionGraph` will, however, clone the graph objects, so we can't + # simply check that `gamma_opt_rv == Gamma_rv` + assert all(type(a) == type(b) for a, b in zip(tt_inputs([Gamma_rv]), tt_inputs([Gamma_opt_rv]))) + assert theano.scan_module.scan_utils.equal_computations( + [Gamma_opt_rv], [Gamma_rv], tt_inputs([Gamma_opt_rv]), tt_inputs([Gamma_rv]) + ) + + # In this case, `Y_t` depends on `S_t` and `S_t` is not output. Our + # push-out optimization should create a new `Scan` that also outputs each + # `S_t`. 
+ def scan_fn(mus_t, sigma_t, Gamma_t, rng): + S_t = CategoricalRV(Gamma_t[0], rng=rng, name="S_t") + Y_t = NormalRV(mus_t[S_t], sigma_t, rng=rng, name="Y_t") + return Y_t + + Y_rv, _ = theano.scan( + fn=scan_fn, + sequences=[mus_tt, sigmas_tt], + non_sequences=[Gamma_rv, rng_tt], + outputs_info=[{}], + strict=True, + name="scan_rv", + ) + Y_rv.name = "Y_rv" + + orig_scan_op = Y_rv.owner.op + assert len(Y_rv.owner.outputs) == 2 + assert isinstance(orig_scan_op, Scan) + assert len(orig_scan_op.outputs) == 2 + assert orig_scan_op.outputs[0].owner.op == NormalRV + assert isinstance(orig_scan_op.outputs[1].type, tt.raw_random.RandomStateType) + + fgraph = FunctionGraph(tt_inputs([Y_rv]), [Y_rv], clone=True) + fgraph_opt = optimize_graph(fgraph, pushoutrvs_opt, return_graph=True) + + # There should now be a new output for all the `S_t` + new_scan = fgraph_opt.outputs[0].owner + assert len(new_scan.outputs) == 3 + assert isinstance(new_scan.op, Scan) + assert new_scan.op.outputs[0].owner.op == NormalRV + assert new_scan.op.outputs[1].owner.op == CategoricalRV + assert isinstance(new_scan.op.outputs[2].type, tt.raw_random.RandomStateType) + + +@theano.change_flags(compute_test_value="warn", cxx="", mode="FAST_COMPILE") +def test_ScanArgs_basics(): + + # Make sure we can create an empty `ScanArgs` + scan_args = ScanArgs.create_empty() + assert scan_args.n_steps is None + for name in scan_args.field_names: + if name == "n_steps": + continue + assert len(getattr(scan_args, name)) == 0 + + with pytest.raises(TypeError): + ScanArgs.from_node(tt.ones(2).owner) + + hmm_model_env = create_test_hmm() + scan_args = hmm_model_env["scan_args"] + scan_op = hmm_model_env["scan_op"] + + # Make sure we can get alternate variables + test_v = scan_args.outer_out_sit_sot[0] + alt_test_v = scan_args.get_alt_field(test_v, "inner_out") + assert alt_test_v == scan_args.inner_out_sit_sot[0] + + # Check the `__repr__` and `__str__` + scan_args_repr = repr(scan_args) + # Just make sure it doesn't err-out + assert scan_args_repr.startswith("ScanArgs") + + # Check the properties that allow us to use + # `Scan.get_oinp_iinp_iout_oout_mappings` as-is to implement + # `ScanArgs.var_mappings` + assert scan_args.n_nit_sot == scan_op.n_nit_sot + assert scan_args.n_mit_mot == scan_op.n_mit_mot + # The `scan_args` base class always clones the inner-graph; + # here we make sure it doesn't (and that all the inputs are the same) + assert scan_args.inputs == scan_op.inputs + scan_op_info = dict(scan_op.info) + # The `ScanInfo` dictionary has the wrong order and an extra entry + del scan_op_info["strict"] + assert dict(scan_args.info) == scan_op_info + assert scan_args.var_mappings == scan_op.var_mappings + + # Check that `ScanArgs.find_among_fields` works + test_v = scan_op.inner_seqs(scan_op.inputs)[1] + field_info = scan_args.find_among_fields(test_v) + assert field_info.name == "inner_in_seqs" + assert field_info.index == 1 + assert field_info.inner_index is None + assert scan_args.inner_inputs[field_info.agg_index] == test_v + + test_l = scan_op.inner_non_seqs(scan_op.inputs) + # We didn't index this argument, so it's a `list` (i.e. 
bad input) + field_info = scan_args.find_among_fields(test_l) + assert field_info is None + + test_v = test_l[0] + field_info = scan_args.find_among_fields(test_v) + assert field_info.name == "inner_in_non_seqs" + assert field_info.index == 0 + assert field_info.inner_index is None + assert scan_args.inner_inputs[field_info.agg_index] == test_v + + scan_args_copy = copy(scan_args) + assert scan_args_copy is not scan_args + assert scan_args_copy == scan_args + + assert scan_args_copy != test_v + scan_args_copy.outer_in_seqs.pop() + assert scan_args_copy != scan_args + + +@theano.change_flags(compute_test_value="warn", cxx="", mode="FAST_COMPILE") +def test_ScanArgs_basics_mit_sot(): + + rng_state = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234))) + rng_tt = theano.shared(rng_state, name="rng", borrow=True) + rng_tt.tag.is_rng = True + rng_tt.default_update = rng_tt + + N_tt = tt.iscalar("N") + N_tt.tag.test_value = 10 + M_tt = tt.iscalar("M") + M_tt.tag.test_value = 2 + + mus_tt = tt.matrix("mus") + mus_tt.tag.test_value = np.stack([np.arange(0.0, 10), np.arange(0.0, -10, -1)], axis=-1).astype( + theano.config.floatX + ) + + sigmas_tt = tt.ones((N_tt,)) + sigmas_tt.name = "sigmas" + + pi_0_rv = DirichletRV(tt.ones((M_tt,)), rng=rng_tt, name="pi_0") + Gamma_rv = DirichletRV(tt.ones((M_tt, M_tt)), rng=rng_tt, name="Gamma") + + S_0_rv = CategoricalRV(pi_0_rv, rng=rng_tt, name="S_0") + + def scan_fn(mus_t, sigma_t, S_tm2, S_tm1, Gamma_t, rng): + S_t = CategoricalRV(Gamma_t[S_tm2], rng=rng, name="S_t") + Y_t = NormalRV(mus_t[S_tm1], sigma_t, rng=rng, name="Y_t") + return S_t, Y_t + + (S_rv, Y_rv), scan_updates = theano.scan( + fn=scan_fn, + sequences=[mus_tt, sigmas_tt], + non_sequences=[Gamma_rv, rng_tt], + outputs_info=[{"initial": tt.stack([S_0_rv, S_0_rv]), "taps": [-2, -1]}, {}], + strict=True, + name="scan_rv", + ) + # Adding names should make output easier to read + Y_rv.name = "Y_rv" + # This `S_rv` outer-output is actually a `Subtensor` of the "real" output + S_rv = S_rv.owner.inputs[0] + S_rv.name = "S_rv" + rng_updates = scan_updates[rng_tt] + rng_updates.name = "rng_updates" + mus_in = Y_rv.owner.inputs[1] + mus_in.name = "mus_in" + sigmas_in = Y_rv.owner.inputs[2] + sigmas_in.name = "sigmas_in" + + scan_args = ScanArgs.from_node(Y_rv.owner) + + test_v = scan_args.inner_in_mit_sot[0][1] + field_info = scan_args.find_among_fields(test_v) + + assert field_info.name == "inner_in_mit_sot" + assert field_info.index == 0 + assert field_info.inner_index == 1 + assert field_info.agg_index == 3 + + rm_info = scan_args._remove_from_fields(tt.ones(2)) + assert rm_info is None + + rm_info = scan_args._remove_from_fields(test_v) + + assert rm_info.name == "inner_in_mit_sot" + assert rm_info.index == 0 + assert rm_info.inner_index == 1 + assert rm_info.agg_index == 3 + + +@theano.change_flags(compute_test_value="warn", cxx="", mode="FAST_COMPILE") +def test_ScanArgs_remove_inner_input(): + + hmm_model_env = create_test_hmm() + scan_args = hmm_model_env["scan_args"] + scan_op = hmm_model_env["scan_op"] + Y_t = hmm_model_env["Y_t"] + Y_rv = hmm_model_env["Y_rv"] + sigmas_in = hmm_model_env["sigmas_in"] + sigmas_t = hmm_model_env["sigmas_t"] + Gamma_rv = hmm_model_env["Gamma_rv"] + Gamma_in = hmm_model_env["Gamma_in"] + S_rv = hmm_model_env["S_rv"] + S_in = hmm_model_env["S_in"] + S_t = hmm_model_env["S_t"] + rng_tt = hmm_model_env["rng_tt"] + rng_in = hmm_model_env["rng_in"] + rng_updates = hmm_model_env["rng_updates"] + + # Check `ScanArgs.remove_from_fields` by removing 
`sigmas[t]` (i.e. the + # inner-graph input) + scan_args_copy = copy(scan_args) + test_v = sigmas_t + + rm_info = scan_args_copy.remove_from_fields(test_v, rm_dependents=False) + removed_nodes, _ = zip(*rm_info) + + assert sigmas_t in removed_nodes + assert sigmas_t not in scan_args_copy.inner_in_seqs + assert Y_t not in removed_nodes + assert len(scan_args_copy.inner_out_nit_sot) == 1 + + scan_args_copy = copy(scan_args) + test_v = sigmas_t + + # This removal includes dependents + rm_info = scan_args_copy.remove_from_fields(test_v, rm_dependents=True) + removed_nodes, _ = zip(*rm_info) + + # `sigmas[t]` (i.e. inner-graph input) should be gone + assert sigmas_t in removed_nodes + assert sigmas_t not in scan_args_copy.inner_in_seqs + # `Y_t` (i.e. inner-graph output) should be gone + assert Y_t in removed_nodes + assert len(scan_args_copy.inner_out_nit_sot) == 0 + # `Y_rv` (i.e. outer-graph output) should be gone + assert Y_rv in removed_nodes + assert Y_rv not in scan_args_copy.outer_outputs + assert len(scan_args_copy.outer_out_nit_sot) == 0 + # `sigmas_in` (i.e. outer-graph input) should be gone + assert sigmas_in in removed_nodes + assert test_v not in scan_args_copy.inner_in_seqs + + # These shouldn't have been removed + assert S_t in scan_args_copy.inner_out_sit_sot + assert S_in in scan_args_copy.outer_out_sit_sot + assert Gamma_in in scan_args_copy.inner_in_non_seqs + assert Gamma_rv in scan_args_copy.outer_in_non_seqs + assert rng_tt in scan_args_copy.outer_in_shared + assert rng_in in scan_args_copy.inner_out_shared + assert rng_updates in scan_args.outer_out_shared + + # The other `Y_rv`-related inputs currently aren't removed, even though + # they're no longer needed. + # TODO: Would be nice if we did this, too + # assert len(scan_args_copy.outer_in_seqs) == 0 + # TODO: Would be nice if we did this, too + # assert len(scan_args_copy.inner_in_seqs) == 0 + + # We shouldn't be able to remove the removed node + with pytest.raises(ValueError): + rm_info = scan_args_copy.remove_from_fields(test_v, rm_dependents=True) + + +@theano.change_flags(compute_test_value="warn", cxx="", mode="FAST_COMPILE") +def test_ScanArgs_remove_outer_input(): + + hmm_model_env = create_test_hmm() + scan_args = hmm_model_env["scan_args"] + scan_op = hmm_model_env["scan_op"] + Y_t = hmm_model_env["Y_t"] + Y_rv = hmm_model_env["Y_rv"] + sigmas_in = hmm_model_env["sigmas_in"] + sigmas_t = hmm_model_env["sigmas_t"] + Gamma_rv = hmm_model_env["Gamma_rv"] + Gamma_in = hmm_model_env["Gamma_in"] + S_rv = hmm_model_env["S_rv"] + S_in = hmm_model_env["S_in"] + S_t = hmm_model_env["S_t"] + rng_tt = hmm_model_env["rng_tt"] + rng_in = hmm_model_env["rng_in"] + rng_updates = hmm_model_env["rng_updates"] + + # Remove `sigmas` (i.e. the outer-input) + scan_args_copy = copy(scan_args) + test_v = sigmas_in + rm_info = scan_args_copy.remove_from_fields(test_v, rm_dependents=True) + removed_nodes, _ = zip(*rm_info) + + # `sigmas_in` (i.e. outer-graph input) should be gone + assert scan_args.outer_in_seqs[-1] in removed_nodes + assert test_v not in scan_args_copy.inner_in_seqs + + # `sigmas[t]` should be gone + assert sigmas_t in removed_nodes + assert sigmas_t not in scan_args_copy.inner_in_seqs + + # `Y_t` (i.e. inner-graph output) should be gone + assert Y_t in removed_nodes + assert len(scan_args_copy.inner_out_nit_sot) == 0 + + # `Y_rv` (i.e. 
outer-graph output) should be gone + assert Y_rv not in scan_args_copy.outer_outputs + assert len(scan_args_copy.outer_out_nit_sot) == 0 + + assert S_t in scan_args_copy.inner_out_sit_sot + assert S_in in scan_args_copy.outer_out_sit_sot + assert Gamma_in in scan_args_copy.inner_in_non_seqs + assert Gamma_rv in scan_args_copy.outer_in_non_seqs + assert rng_tt in scan_args_copy.outer_in_shared + assert rng_in in scan_args_copy.inner_out_shared + assert rng_updates in scan_args.outer_out_shared + + +@theano.change_flags(compute_test_value="warn", cxx="", mode="FAST_COMPILE") +def test_ScanArgs_remove_inner_output(): + + hmm_model_env = create_test_hmm() + scan_args = hmm_model_env["scan_args"] + scan_op = hmm_model_env["scan_op"] + Y_t = hmm_model_env["Y_t"] + Y_rv = hmm_model_env["Y_rv"] + sigmas_in = hmm_model_env["sigmas_in"] + sigmas_t = hmm_model_env["sigmas_t"] + Gamma_rv = hmm_model_env["Gamma_rv"] + Gamma_in = hmm_model_env["Gamma_in"] + S_rv = hmm_model_env["S_rv"] + S_in = hmm_model_env["S_in"] + S_t = hmm_model_env["S_t"] + rng_tt = hmm_model_env["rng_tt"] + rng_in = hmm_model_env["rng_in"] + rng_updates = hmm_model_env["rng_updates"] + + # Remove `Y_t` (i.e. the inner-output) + scan_args_copy = copy(scan_args) + test_v = Y_t + rm_info = scan_args_copy.remove_from_fields(test_v, rm_dependents=True) + removed_nodes, _ = zip(*rm_info) + + # `Y_t` (i.e. inner-graph output) should be gone + assert Y_t in removed_nodes + assert len(scan_args_copy.inner_out_nit_sot) == 0 + + # `Y_rv` (i.e. outer-graph output) should be gone + assert Y_rv not in scan_args_copy.outer_outputs + assert len(scan_args_copy.outer_out_nit_sot) == 0 + + assert S_t in scan_args_copy.inner_out_sit_sot + assert S_in in scan_args_copy.outer_out_sit_sot + assert Gamma_in in scan_args_copy.inner_in_non_seqs + assert Gamma_rv in scan_args_copy.outer_in_non_seqs + assert rng_tt in scan_args_copy.outer_in_shared + assert rng_in in scan_args_copy.inner_out_shared + assert rng_updates in scan_args.outer_out_shared + + +@theano.change_flags(compute_test_value="warn", cxx="", mode="FAST_COMPILE") +def test_ScanArgs_remove_outer_output(): + + hmm_model_env = create_test_hmm() + scan_args = hmm_model_env["scan_args"] + scan_op = hmm_model_env["scan_op"] + Y_t = hmm_model_env["Y_t"] + Y_rv = hmm_model_env["Y_rv"] + sigmas_in = hmm_model_env["sigmas_in"] + sigmas_t = hmm_model_env["sigmas_t"] + Gamma_rv = hmm_model_env["Gamma_rv"] + Gamma_in = hmm_model_env["Gamma_in"] + S_in = hmm_model_env["S_in"] + S_t = hmm_model_env["S_t"] + rng_tt = hmm_model_env["rng_tt"] + rng_in = hmm_model_env["rng_in"] + rng_updates = hmm_model_env["rng_updates"] + + # Remove `Y_rv` (i.e. a nit-sot outer-output) + scan_args_copy = copy(scan_args) + test_v = Y_rv + rm_info = scan_args_copy.remove_from_fields(test_v, rm_dependents=True) + removed_nodes, _ = zip(*rm_info) + + # `Y_t` (i.e. inner-graph output) should be gone + assert Y_t in removed_nodes + assert len(scan_args_copy.inner_out_nit_sot) == 0 + + # `Y_rv` (i.e. 
outer-graph output) should be gone + assert Y_rv not in scan_args_copy.outer_outputs + assert len(scan_args_copy.outer_out_nit_sot) == 0 + + assert S_t in scan_args_copy.inner_out_sit_sot + assert S_in in scan_args_copy.outer_out_sit_sot + assert Gamma_in in scan_args_copy.inner_in_non_seqs + assert Gamma_rv in scan_args_copy.outer_in_non_seqs + assert rng_tt in scan_args_copy.outer_in_shared + assert rng_in in scan_args_copy.inner_out_shared + assert rng_updates in scan_args.outer_out_shared + + +@theano.change_flags(compute_test_value="warn", cxx="", mode="FAST_COMPILE") +def test_ScanArgs_remove_nonseq_outer_input(): + + hmm_model_env = create_test_hmm() + scan_args = hmm_model_env["scan_args"] + scan_op = hmm_model_env["scan_op"] + Y_t = hmm_model_env["Y_t"] + Y_rv = hmm_model_env["Y_rv"] + mus_in = hmm_model_env["mus_in"] + mus_t = hmm_model_env["mus_t"] + sigmas_in = hmm_model_env["sigmas_in"] + sigmas_t = hmm_model_env["sigmas_t"] + Gamma_rv = hmm_model_env["Gamma_rv"] + Gamma_in = hmm_model_env["Gamma_in"] + S_in = hmm_model_env["S_in"] + S_t = hmm_model_env["S_t"] + rng_tt = hmm_model_env["rng_tt"] + rng_in = hmm_model_env["rng_in"] + rng_updates = hmm_model_env["rng_updates"] + + # Remove `Gamma` (i.e. a non-sequence outer-input) + scan_args_copy = copy(scan_args) + test_v = Gamma_rv + rm_info = scan_args_copy.remove_from_fields(test_v, rm_dependents=True) + removed_nodes, _ = zip(*rm_info) + + assert Gamma_rv in removed_nodes + assert Gamma_in in removed_nodes + assert S_in in removed_nodes + assert S_t in removed_nodes + assert Y_t in removed_nodes + assert Y_rv in removed_nodes + + assert mus_in in scan_args_copy.outer_in_seqs + assert sigmas_in in scan_args_copy.outer_in_seqs + assert mus_t in scan_args_copy.inner_in_seqs + assert sigmas_t in scan_args_copy.inner_in_seqs + assert rng_tt in scan_args_copy.outer_in_shared + assert rng_in in scan_args_copy.inner_out_shared + assert rng_updates in scan_args.outer_out_shared + + +@theano.change_flags(compute_test_value="warn", cxx="", mode="FAST_COMPILE") +def test_ScanArgs_remove_nonseq_inner_input(): + + hmm_model_env = create_test_hmm() + scan_args = hmm_model_env["scan_args"] + scan_op = hmm_model_env["scan_op"] + Y_t = hmm_model_env["Y_t"] + Y_rv = hmm_model_env["Y_rv"] + mus_in = hmm_model_env["mus_in"] + mus_t = hmm_model_env["mus_t"] + sigmas_in = hmm_model_env["sigmas_in"] + sigmas_t = hmm_model_env["sigmas_t"] + Gamma_rv = hmm_model_env["Gamma_rv"] + Gamma_in = hmm_model_env["Gamma_in"] + S_in = hmm_model_env["S_in"] + S_t = hmm_model_env["S_t"] + rng_tt = hmm_model_env["rng_tt"] + rng_in = hmm_model_env["rng_in"] + rng_updates = hmm_model_env["rng_updates"] + + # Remove `Gamma` (i.e. 
a non-sequence inner-input) + scan_args_copy = copy(scan_args) + test_v = Gamma_in + rm_info = scan_args_copy.remove_from_fields(test_v, rm_dependents=True) + removed_nodes, _ = zip(*rm_info) + + assert Gamma_in in removed_nodes + assert Gamma_rv in removed_nodes + assert S_in in removed_nodes + assert S_t in removed_nodes + + assert mus_in in scan_args_copy.outer_in_seqs + assert sigmas_in in scan_args_copy.outer_in_seqs + assert mus_t in scan_args_copy.inner_in_seqs + assert sigmas_t in scan_args_copy.inner_in_seqs + assert rng_tt in scan_args_copy.outer_in_shared + assert rng_in in scan_args_copy.inner_out_shared + assert rng_updates in scan_args.outer_out_shared + + +@theano.change_flags(compute_test_value="warn", cxx="", mode="FAST_COMPILE") +def test_ScanArgs_remove_shared_inner_output(): + + hmm_model_env = create_test_hmm() + scan_args = hmm_model_env["scan_args"] + scan_op = hmm_model_env["scan_op"] + Y_t = hmm_model_env["Y_t"] + Y_rv = hmm_model_env["Y_rv"] + mus_in = hmm_model_env["mus_in"] + mus_t = hmm_model_env["mus_t"] + sigmas_in = hmm_model_env["sigmas_in"] + sigmas_t = hmm_model_env["sigmas_t"] + Gamma_rv = hmm_model_env["Gamma_rv"] + Gamma_in = hmm_model_env["Gamma_in"] + S_in = hmm_model_env["S_in"] + S_t = hmm_model_env["S_t"] + rng_tt = hmm_model_env["rng_tt"] + rng_in = hmm_model_env["rng_in"] + rng_updates = hmm_model_env["rng_updates"] + + # Remove `rng` (i.e. a shared inner-output) + scan_args_copy = copy(scan_args) + test_v = rng_updates + rm_info = scan_args_copy.remove_from_fields(test_v, rm_dependents=True) + removed_nodes, _ = zip(*rm_info) + + assert rng_tt in removed_nodes + assert rng_in in removed_nodes + assert rng_updates in removed_nodes + assert Y_rv in removed_nodes + assert S_in in removed_nodes + + assert sigmas_in in scan_args_copy.outer_in_seqs + assert sigmas_t in scan_args_copy.inner_in_seqs + assert mus_in in scan_args_copy.outer_in_seqs + assert mus_t in scan_args_copy.inner_in_seqs + + +def create_inner_out_logp(input_scan_args, old_inner_out_var, new_inner_in_var, output_scan_args): + """Create a log-likelihood inner-output. + + This is intended to be use with `get_random_outer_outputs`. + + """ + from symbolic_pymc.theano.pymc3 import _logp_fn + + logp_fn = _logp_fn(old_inner_out_var.owner.op, old_inner_out_var.owner, None) + logp = logp_fn(new_inner_in_var) + if new_inner_in_var.name: + logp.name = "logp({})".format(new_inner_in_var.name) + return logp + + +@theano.change_flags(compute_test_value="warn", cxx="", mode="FAST_COMPILE") +def test_convert_outer_out_to_in(): + + rng_state = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234))) + rng_tt = theano.shared(rng_state, name="rng", borrow=True) + rng_tt.tag.is_rng = True + rng_tt.default_update = rng_tt + + # + # We create a `Scan` representing a time-series model with normally + # distributed responses that are dependent on lagged values of both the + # response `RandomVariable` and a lagged "deterministic" that also depends + # on the lagged response values. 
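+    # That is, mu_t = mu_{t-1} + y_{t-1} + 1 and Y_t ~ N(mu_t, 1), so the
+    # `Scan` has two sit-sot outputs, one of which is a `RandomVariable`.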
+ # + def input_step_fn(mu_tm1, y_tm1, rng): + mu_tm1.name = "mu_tm1" + y_tm1.name = "y_tm1" + mu = mu_tm1 + y_tm1 + 1 + mu.name = "mu_t" + return mu, NormalRV(mu, 1.0, rng=rng, name="Y_t") + + (mu_tt, Y_rv), _ = theano.scan( + fn=input_step_fn, + outputs_info=[ + {"initial": tt.as_tensor_variable(np.r_[0.0]), "taps": [-1]}, + {"initial": tt.as_tensor_variable(np.r_[0.0]), "taps": [-1]}, + ], + non_sequences=[rng_tt], + n_steps=10, + ) + + mu_tt.name = "mu_tt" + mu_tt.owner.inputs[0].name = "mu_all" + Y_rv.name = "Y_rv" + Y_rv.owner.inputs[0].name = "Y_all" + + input_scan_args = ScanArgs.from_node(Y_rv.owner.inputs[0].owner) + + # + # Sample from the model and create another `Scan` that computes the + # log-likelihood of the model at the sampled point. + # + Y_obs = tt.as_tensor_variable(Y_rv.eval()) + Y_obs.name = "Y_obs" + + def output_step_fn(y_t, y_tm1, mu_tm1): + import pymc3 as pm + + mu_tm1.name = "mu_tm1" + y_tm1.name = "y_tm1" + mu = mu_tm1 + y_tm1 + 1 + mu.name = "mu_t" + logp = pm.Normal.dist(mu, 1.0).logp(y_t) + logp.name = "logp" + return mu, logp + + (mu_tt, Y_logp), _ = theano.scan( + fn=output_step_fn, + sequences=[{"input": Y_obs, "taps": [0, -1]}], + outputs_info=[{"initial": tt.as_tensor_variable(np.r_[0.0]), "taps": [-1]}, {}], + ) + + Y_logp.name = "Y_logp" + mu_tt.name = "mu_tt" + + # output_scan_args = ScanArgs.from_node(Y_logp.owner) + + # + # Get the model output variable that corresponds to the response + # `RandomVariable` + # + var_idx, var, io_var = get_random_outer_outputs(input_scan_args)[0] + + # + # Convert the original model `Scan` into another `Scan` that's equivalent + # to the log-likelihood `Scan` given above. + # In other words, automatically construct the log-likelihood `Scan` based + # on the model `Scan`. + # + test_scan_args, new_oi_var = convert_outer_out_to_in( + input_scan_args, var, inner_out_fn=create_inner_out_logp, output_scan_args=input_scan_args + ) + + scan_out = construct_scan(test_scan_args) + + # + # Evaluate the manually and automatically constructed log-likelihoods and + # compare. + # + res = scan_out[var_idx].eval({new_oi_var: Y_obs.value}) + exp_res = Y_logp.eval() + + assert np.array_equal(res, exp_res) + + +@theano.change_flags(compute_test_value="warn", cxx="", mode="FAST_COMPILE") +def test_convert_outer_out_to_in_mit_sot(): + + rng_state = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234))) + rng_tt = theano.shared(rng_state, name="rng", borrow=True) + rng_tt.tag.is_rng = True + rng_tt.default_update = rng_tt + + # + # This is a very simple model with only one output, but multiple + # taps/lags. 
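+    # That is, Y_t ~ N(y_{t-1} + y_{t-2}, 1): a single mit-sot output with
+    # taps [-1, -2].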
+
+
+@theano.change_flags(compute_test_value="warn", cxx="", mode="FAST_COMPILE")
+def test_convert_outer_out_to_in_mit_sot():
+
+    rng_state = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234)))
+    rng_tt = theano.shared(rng_state, name="rng", borrow=True)
+    rng_tt.tag.is_rng = True
+    rng_tt.default_update = rng_tt
+
+    #
+    # This is a very simple model with only one output, but multiple
+    # taps/lags.
+    #
+    def input_step_fn(y_tm1, y_tm2, rng):
+        y_tm1.name = "y_tm1"
+        y_tm2.name = "y_tm2"
+        return NormalRV(y_tm1 + y_tm2, 1.0, rng=rng, name="Y_t")
+
+    Y_rv, _ = theano.scan(
+        fn=input_step_fn,
+        outputs_info=[{"initial": tt.as_tensor_variable(np.r_[-1.0, 0.0]), "taps": [-1, -2]}],
+        non_sequences=[rng_tt],
+        n_steps=10,
+    )
+
+    Y_rv.name = "Y_rv"
+    Y_rv.owner.inputs[0].name = "Y_all"
+
+    Y_obs = tt.as_tensor_variable(Y_rv.eval())
+    Y_obs.name = "Y_obs"
+
+    input_scan_args = ScanArgs.from_node(Y_rv.owner.inputs[0].owner)
+
+    #
+    # The corresponding log-likelihood
+    #
+    def output_step_fn(y_t, y_tm1, y_tm2):
+        import pymc3 as pm
+
+        y_t.name = "y_t"
+        y_tm1.name = "y_tm1"
+        y_tm2.name = "y_tm2"
+        logp = pm.Normal.dist(y_tm1 + y_tm2, 1.0).logp(y_t)
+        logp.name = "logp(y_t)"
+        return logp
+
+    Y_logp, _ = theano.scan(
+        fn=output_step_fn, sequences=[{"input": Y_obs, "taps": [0, -1, -2]}], outputs_info=[{}]
+    )
+
+    # output_scan_args = ScanArgs.from_node(Y_logp.owner)
+
+    #
+    # Get the model output variable that corresponds to the response
+    # `RandomVariable`
+    #
+    var_idx, var, io_var = get_random_outer_outputs(input_scan_args)[0]
+
+    #
+    # Convert the original model `Scan` into another `Scan` that's equivalent
+    # to the log-likelihood `Scan` given above.
+    # In other words, automatically construct the log-likelihood `Scan` based
+    # on the model `Scan`.
+    #
+    # In this case, we perform the conversion on a "blank" `ScanArgs`.
+    #
+    test_scan_args, new_oi_var = convert_outer_out_to_in(
+        input_scan_args, var, inner_out_fn=create_inner_out_logp, output_scan_args=None
+    )
+
+    scan_out = construct_scan(test_scan_args)
+
+    #
+    # Evaluate the manually and automatically constructed log-likelihoods and
+    # compare.
+    #
+    res = scan_out[var_idx].eval({new_oi_var: Y_obs.value})
+    exp_res = Y_logp.eval()
+
+    assert np.array_equal(res, exp_res)
+
+    #
+    # Now, we rerun the test, but use the "replace" features of
+    # `convert_outer_out_to_in`.
+    #
+    test_scan_args, new_oi_var = convert_outer_out_to_in(
+        input_scan_args, var, inner_out_fn=create_inner_out_logp, output_scan_args=input_scan_args
+    )
+
+    scan_out = construct_scan(test_scan_args)
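+
+    # NOTE: Unlike the first conversion above, which started from a "blank"
+    # `ScanArgs`, this pass hands `convert_outer_out_to_in` the input
+    # `ScanArgs` itself, so the log-likelihood output replaces the sampled
+    # output within the existing `Scan` arguments instead of building new
+    # ones from scratch.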
+
+    #
+    # Evaluate the manually and automatically constructed log-likelihoods and
+    # compare.
+    #
+    res = scan_out[var_idx].eval({new_oi_var: Y_obs.value})
+    exp_res = Y_logp.eval()
+
+    assert np.array_equal(res, exp_res)
+
+
+@theano.change_flags(compute_test_value="warn", cxx="", mode="FAST_COMPILE")
+def test_convert_outer_out_to_in_hmm():
+    hmm_model_env = create_test_hmm()
+    input_scan_args = hmm_model_env["scan_args"]
+    M_tt = hmm_model_env["M_tt"]
+    N_tt = hmm_model_env["N_tt"]
+    mus_tt = hmm_model_env["mus_tt"]
+    sigmas_tt = hmm_model_env["sigmas_tt"]
+    Y_rv = hmm_model_env["Y_rv"]
+    S_0_rv = hmm_model_env["S_0_rv"]
+    Gamma_rv = hmm_model_env["Gamma_rv"]
+    rng_tt = hmm_model_env["rng_tt"]
+    rng_init_state = hmm_model_env["rng_init_state"]
+
+    test_point = {
+        M_tt: 2,
+        N_tt: 10,
+        mus_tt: mus_tt.tag.test_value,
+    }
+    Y_obs = tt.as_tensor_variable(Y_rv.eval(test_point))
+    Y_obs.name = "Y_obs"
+
+    def logp_scan_fn(y_t, mus_t, sigma_t, S_tm1, Gamma_t, rng):
+        import pymc3 as pm
+
+        gamma_t = Gamma_t[S_tm1]
+        gamma_t.name = "gamma_t"
+        S_t = CategoricalRV(gamma_t, rng=rng, name="S_t")
+        mu_t = mus_t[S_t]
+        mu_t.name = "mu[S_t]"
+        Y_logp_t = pm.Normal.dist(mu_t, sigma_t).logp(y_t)
+        Y_logp_t.name = "logp(y_t)"
+        return S_t, Y_logp_t
+
+    (S_rv, Y_logp), scan_updates = theano.scan(
+        fn=logp_scan_fn,
+        sequences=[Y_obs, mus_tt, sigmas_tt],
+        non_sequences=[Gamma_rv, rng_tt],
+        outputs_info=[{"initial": S_0_rv, "taps": [-1]}, {}],
+        strict=True,
+        name="scan_rv",
+    )
+    Y_logp.name = "Y_logp"
+
+    var_idx, var, io_var = get_random_outer_outputs(input_scan_args)[1]
+
+    test_scan_args, new_oi_var = convert_outer_out_to_in(
+        input_scan_args, var, inner_out_fn=create_inner_out_logp, output_scan_args=input_scan_args
+    )
+
+    scan_out = construct_scan(test_scan_args)
+    test_Y_logp = scan_out[var_idx]
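+
+    # `test_Y_logp` should now mirror the manually constructed `Y_logp`
+    # above, up to the state of the RNG used for the (still random) `S_t`
+    # draws; hence the RNG resets below.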
+
+    #
+    # Evaluate the manually and automatically constructed log-likelihoods and
+    # compare.
+    #
+    new_test_point = dict(test_point)
+    new_test_point[new_oi_var] = Y_obs.value
+
+    # We need to reset the RNG each time, because `S_t` is still a
+    # `RandomVariable`
+    rng_tt.get_value(borrow=True).set_state(rng_init_state)
+    with theano.change_flags(on_unused_input="warn"):
+        res = test_Y_logp.eval(new_test_point)
+
+    rng_tt.get_value(borrow=True).set_state(rng_init_state)
+    exp_res = Y_logp.eval(test_point)
+
+    assert np.array_equal(res, exp_res)
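+
+# NOTE: A hypothetical illustration of why the RNG resets above are needed:
+# `S_t` is still a `RandomVariable`, so each `eval` draws fresh `S_t` samples
+# and advances the shared RNG state.  Without the resets (sketch, not
+# executed):
+#
+#     res_1 = test_Y_logp.eval(new_test_point)
+#     res_2 = test_Y_logp.eval(new_test_point)
+#     # `res_1` and `res_2` would almost surely differ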
diff --git a/tests/theano/test_pymc3.py b/tests/theano/test_pymc3.py
index a9ac339..6bb8cea 100644
--- a/tests/theano/test_pymc3.py
+++ b/tests/theano/test_pymc3.py
@@ -16,15 +16,16 @@
 from symbolic_pymc.theano.random_variables import NormalRV, MvNormalRV, Observed, observed
 from symbolic_pymc.theano.ops import RandomVariable
 from symbolic_pymc.theano.opt import FunctionGraph
-from symbolic_pymc.theano.pymc3 import model_graph, graph_model, logp
-from symbolic_pymc.theano.utils import canonicalize
+from symbolic_pymc.theano.pymc3 import model_graph, graph_model, logp, convert_rv_to_dist
+from symbolic_pymc.theano.utils import canonicalize, vars_to_rvs
 from symbolic_pymc.theano.meta import mt
 
+from tests.theano.utils import create_test_hmm
+
 
+@theano.change_flags(compute_test_value="ignore", cxx="")
 def test_pymc3_convert_dists():
     """Just a basic check that all PyMC3 RVs will convert to and from Theano RVs."""
-    tt.config.compute_test_value = "ignore"
-    theano.config.cxx = ""
 
     with pm.Model() as model:
         norm_rv = pm.Normal("norm_rv", 0.0, 1.0, observed=1.0)
@@ -67,9 +68,9 @@
     assert res.vars[0].name == "normal_0"
 
 
+@theano.change_flags(compute_test_value="ignore")
 def test_pymc3_normal_model():
     """Conduct a more in-depth test of PyMC3/Theano conversions for a specific model."""
-    tt.config.compute_test_value = "ignore"
 
     mu_X = tt.dscalar("mu_X")
     sd_X = tt.dscalar("sd_X")
@@ -80,10 +81,10 @@
     # We need something that uses transforms...
     with pm.Model() as model:
-        X_rv = pm.Normal("X_rv", mu_X, sd=sd_X)
+        X_rv = pm.Normal("X_rv", mu_X, sigma=sd_X)
         S_rv = pm.HalfCauchy("S_rv", beta=np.array(0.5, dtype=tt.config.floatX))
-        Y_rv = pm.Normal("Y_rv", X_rv * S_rv, sd=S_rv)
-        Z_rv = pm.Normal("Z_rv", X_rv + Y_rv, sd=sd_X, observed=10.0)
+        Y_rv = pm.Normal("Y_rv", X_rv * S_rv, sigma=S_rv)
+        Z_rv = pm.Normal("Z_rv", X_rv + Y_rv, sigma=sd_X, observed=10.0)
 
     fgraph = model_graph(model, output_vars=[Z_rv])
 
@@ -146,9 +147,23 @@
     assert all(v == 1 for v in Z_vars_count.values())
 
 
+@theano.change_flags(compute_test_value="ignore")
+def test_convert_rv_to_dist_shape():
+
+    # Make sure we use the `ShapeFeature` to get the shape info
+    X_rv = NormalRV(np.r_[1, 2], 2.0, name="X_rv")
+    fgraph = FunctionGraph(tt_inputs([X_rv]), [X_rv], features=[tt.opt.ShapeFeature()])
+
+    with pm.Model():
+        res = convert_rv_to_dist(fgraph.outputs[0].owner, None)
+
+    assert isinstance(res.distribution, pm.Normal)
+    assert np.array_equal(res.distribution.shape, np.r_[2])
+
+
+@theano.change_flags(compute_test_value="ignore")
 def test_normals_to_model():
     """Test conversion to a PyMC3 model."""
-    tt.config.compute_test_value = "ignore"
 
     a_tt = tt.vector("a")
     R_tt = tt.matrix("R")
@@ -211,9 +226,9 @@ def _check_model(model):
     model = graph_model(Y_obs)
 
 
+@theano.change_flags(compute_test_value="ignore")
 def test_pymc3_broadcastable():
     """Test PyMC3 to Theano conversion amid array broadcasting."""
-    tt.config.compute_test_value = "ignore"
 
     mu_X = tt.vector("mu_X")
     sd_X = tt.vector("sd_X")
@@ -225,9 +240,9 @@
     sd_Y.tag.test_value = np.array([0.5], dtype=tt.config.floatX)
 
     with pm.Model() as model:
-        X_rv = pm.Normal("X_rv", mu_X, sd=sd_X, shape=(1,))
-        Y_rv = pm.Normal("Y_rv", mu_Y, sd=sd_Y, shape=(1,))
-        Z_rv = pm.Normal("Z_rv", X_rv + Y_rv, sd=sd_X + sd_Y, shape=(1,), observed=[10.0])
+        X_rv = pm.Normal("X_rv", mu_X, sigma=sd_X, shape=(1,))
+        Y_rv = pm.Normal("Y_rv", mu_Y, sigma=sd_Y, shape=(1,))
+        Z_rv = pm.Normal("Z_rv", X_rv + Y_rv, sigma=sd_X + sd_Y, shape=(1,), observed=[10.0])
 
     with pytest.warns(UserWarning):
         fgraph = model_graph(model)
@@ -254,48 +269,95 @@
     assert mt(Z_rv_tt) == mt(Z_rv_meta)
 
 
+@theano.change_flags(compute_test_value="warn", cxx="")
 def test_logp():
-    test_rv = NormalRV(0, tt.arange(1, 3))
-    test_logp = logp(test_rv, 0)
-    assert np.all(test_logp.eval() == pm.Normal.dist(0, np.arange(1, 3)).logp(0).eval())
+    hmm_model_env = create_test_hmm()
+    M_tt = hmm_model_env["M_tt"]
+    N_tt = hmm_model_env["N_tt"]
+    mus_tt = hmm_model_env["mus_tt"]
+    sigmas_tt = hmm_model_env["sigmas_tt"]
+    Y_rv = hmm_model_env["Y_rv"]
+    S_rv = hmm_model_env["S_rv"]
+    S_in = hmm_model_env["S_in"]
+    Gamma_rv = hmm_model_env["Gamma_rv"]
+    rng_tt = hmm_model_env["rng_tt"]
+
+    Y_obs = Y_rv.clone()
+    Y_obs.name = "Y_obs"
+    # `S_in` includes `S_0_rv` (and `pi_0_rv`), unlike `S_rv`
+    S_obs = S_in.clone()
+    S_obs.name = "S_obs"
+    Gamma_obs = Gamma_rv.clone()
+    Gamma_obs.name = "Gamma_obs"
+
+    test_point = {
+        mus_tt: mus_tt.tag.test_value,
+        N_tt: N_tt.tag.test_value,
+        Gamma_obs: Gamma_rv.tag.test_value,
+        Y_obs: Y_rv.tag.test_value,
+        S_obs: S_in.tag.test_value,
+    }
+
+    def logp_scan_fn(s_t, s_tm1, y_t, mus_t, sigma_t, Gamma_t):
+        gamma_t = Gamma_t[s_tm1]
+        log_s_t = pm.Categorical.dist(gamma_t).logp(s_t)
+        mu_t = mus_t[s_t]
+        log_y_t = pm.Normal.dist(mu_t, sigma_t).logp(y_t)
+        gamma_t.name = "gamma_t"
+        log_y_t.name = "logp(y_t)"
+        log_s_t.name = "logp(s_t)"
+        mu_t.name = "mu[S_t]"
+        return log_s_t, log_y_t
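+
+    # The reference log-likelihood `Scan`: reading `S_obs` with taps
+    # `[0, -1]` gives each step both the current state `s_t` and the previous
+    # state `s_tm1` needed for the transition probability `Gamma_t[s_tm1]`.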
+
+    (true_S_logp, true_Y_logp), scan_updates = theano.scan(
+        fn=logp_scan_fn,
+        sequences=[{"input": S_obs, "taps": [0, -1]}, Y_obs, mus_tt, sigmas_tt],
+        non_sequences=[Gamma_obs],
+        outputs_info=[{}, {}],
+        strict=True,
+        name="scan_rv",
+    )
 
-    fgraph = FunctionGraph(tt_inputs([test_rv]), [test_rv], features=[tt.opt.ShapeFeature()])
-    test_rv.owner.fgraph = fgraph
-    test_logp = logp(test_rv, 0)
+    # Make sure there are no `RandomVariable` nodes in our expected/true
+    # log-likelihood graphs.
+    assert not vars_to_rvs(true_S_logp)
+    assert not vars_to_rvs(true_Y_logp)
 
-    assert np.all(test_logp.eval() == pm.Normal.dist(0, np.arange(1, 3)).logp(0).eval())
+    true_S_logp_val = true_S_logp.eval(test_point)
+    true_Y_logp_val = true_Y_logp.eval(test_point)
 
-    rng_state = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234)))
-    rng_tt = theano.shared(rng_state, name="rng", borrow=True)
-    rng_tt.tag.is_rng = True
-    rng_tt.default_update = rng_tt
+    #
+    # Now, compute the log-likelihoods
+    #
+    logps = logp(Y_rv)
 
-    # TODO: Scan of univariate normals.
-    N_tt = tt.iscalar("N")
-    N_tt.tag.test_value = 10
+    S_logp = logps[S_in][1]
+    Y_logp = logps[Y_rv][1]
 
-    mus_tt = tt.arange(N_tt)
-    mus_tt.tag.test_value
+    # from theano.printing import debugprint as tt_dprint
 
-    sigmas_tt = tt.ones((N_tt,))
-    sigmas_tt.tag.test_value
+    # There shouldn't be any `RandomVariable`s here either
+    assert not vars_to_rvs(S_logp[1])
+    assert not vars_to_rvs(Y_logp[1])
 
-    def scan_fn(mu_t, sigma_t, rng):
-        # mix = np.stack([NormalRV(mu_t, sigma_t, rng=rng), GammaRV(mu_t**2 / sigma_t**2, mu_t / sigma_t)])
-        return NormalRV(mu_t, sigma_t, rng=rng)
+    assert N_tt in tt_inputs([S_logp])
+    assert mus_tt in tt_inputs([S_logp])
+    assert logps[S_in][0] in tt_inputs([S_logp])
+    assert logps[Y_rv][0] in tt_inputs([S_logp])
+    assert logps[Gamma_rv][0] in tt_inputs([S_logp])
 
-    scan_rv, _ = theano.scan(
-        fn=scan_fn,
-        sequences=[mus_tt, sigmas_tt],
-        non_sequences=[rng_tt],
-        outputs_info=[{},],
-        strict=True,
-        name="scan_rv",
-    )
+    new_test_point = {
+        mus_tt: mus_tt.tag.test_value,
+        N_tt: N_tt.tag.test_value,
+        logps[Gamma_rv][0]: Gamma_rv.tag.test_value,
+        logps[Y_rv][0]: Y_rv.tag.test_value,
+        logps[S_in][0]: S_in.tag.test_value,
+    }
 
-    scan_logp = logp(scan_rv, tt.zeros((N_tt,)))
-    res = scan_logp.eval({N_tt: 10})
-    exp_res = pm.Normal.dist(np.arange(10), np.ones(10)).logp(np.zeros(10)).eval()
+    with theano.change_flags(on_unused_input="warn"):
+        S_logp_val = S_logp.eval(new_test_point)
+        Y_logp_val = Y_logp.eval(new_test_point)
 
-    assert np.array_equal(res, exp_res)
+    assert np.array_equal(true_S_logp_val, S_logp_val)
+    assert np.array_equal(Y_logp_val, true_Y_logp_val)
diff --git a/tests/theano/test_utils.py b/tests/theano/test_utils.py
new file mode 100644
index 0000000..cf54db9
--- /dev/null
+++ b/tests/theano/test_utils.py
@@ -0,0 +1,21 @@
+import theano
+
+from symbolic_pymc.theano.utils import is_random_variable
+from symbolic_pymc.theano.random_variables import NormalRV
+
+
+@theano.change_flags(compute_test_value="ignore", cxx="")
+def test_is_random_variable():
+
+    X_rv = NormalRV(0, 1)
+    res = is_random_variable(X_rv)
+    assert res == (X_rv, X_rv)
+
+    def scan_fn():
+        Y_t = NormalRV(0, 1, name="Y_t")
+        return Y_t
+
+    Y_rv, scan_updates = theano.scan(fn=scan_fn, outputs_info=[{}], n_steps=10)
+
+    res = is_random_variable(Y_rv)
+    assert res == (Y_rv, Y_rv.owner.op.outputs[0])
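+
+    # NOTE: For `Scan` outputs, `is_random_variable` appears to return the
+    # *inner-graph* output (`Y_rv.owner.op.outputs[0]`, i.e. the
+    # `RandomVariable` term inside the `Scan`) as its second element, rather
+    # than the outer output itself.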
diff --git a/tests/theano/utils.py b/tests/theano/utils.py
new file mode 100644
index 0000000..19e4131
--- /dev/null
+++ b/tests/theano/utils.py
@@ -0,0 +1,71 @@
+import numpy as np
+import theano
+import theano.tensor as tt
+
+from symbolic_pymc.theano.opt import ScanArgs
+from symbolic_pymc.theano.random_variables import CategoricalRV, DirichletRV, NormalRV
+
+
+def create_test_hmm():
+    rng_state = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234)))
+    rng_init_state = rng_state.get_state()
+    rng_tt = theano.shared(rng_state, name="rng", borrow=True)
+    rng_tt.tag.is_rng = True
+    rng_tt.default_update = rng_tt
+
+    N_tt = tt.iscalar("N")
+    N_tt.tag.test_value = 10
+    M_tt = tt.iscalar("M")
+    M_tt.tag.test_value = 2
+
+    mus_tt = tt.matrix("mus")
+    mus_tt.tag.test_value = np.stack([np.arange(0.0, 10), np.arange(0.0, -10, -1)], axis=-1).astype(
+        theano.config.floatX
+    )
+
+    sigmas_tt = tt.ones((N_tt,))
+    sigmas_tt.name = "sigmas"
+
+    pi_0_rv = DirichletRV(tt.ones((M_tt,)), rng=rng_tt, name="pi_0")
+    Gamma_rv = DirichletRV(tt.ones((M_tt, M_tt)), rng=rng_tt, name="Gamma")
+
+    S_0_rv = CategoricalRV(pi_0_rv, rng=rng_tt, name="S_0")
+
+    def scan_fn(mus_t, sigma_t, S_tm1, Gamma_t, rng):
+        S_t = CategoricalRV(Gamma_t[S_tm1], rng=rng, name="S_t")
+        Y_t = NormalRV(mus_t[S_t], sigma_t, rng=rng, name="Y_t")
+        return S_t, Y_t
+
+    (S_rv, Y_rv), scan_updates = theano.scan(
+        fn=scan_fn,
+        sequences=[mus_tt, sigmas_tt],
+        non_sequences=[Gamma_rv, rng_tt],
+        outputs_info=[{"initial": S_0_rv, "taps": [-1]}, {}],
+        strict=True,
+        name="scan_rv",
+    )
+    Y_rv.name = "Y_rv"
+
+    scan_op = Y_rv.owner.op
+    scan_args = ScanArgs.from_node(Y_rv.owner)
+
+    Gamma_in = scan_args.inner_in_non_seqs[0]
+    Y_t = scan_args.inner_out_nit_sot[0]
+    mus_t = scan_args.inner_in_seqs[0]
+    sigmas_t = scan_args.inner_in_seqs[1]
+    S_t = scan_args.inner_out_sit_sot[0]
+    rng_in = scan_args.inner_out_shared[0]
+
+    rng_updates = scan_updates[rng_tt]
+    rng_updates.name = "rng_updates"
+    mus_in = Y_rv.owner.inputs[1]
+    mus_in.name = "mus_in"
+    sigmas_in = Y_rv.owner.inputs[2]
+    sigmas_in.name = "sigmas_in"
+
+    # The returned `S_rv` is really `S_in[1:]`, so we have to extract the
+    # actual `Scan` output, `S_in`, from its subtensor.
+    S_in = S_rv.owner.inputs[0]
+    S_in.name = "S_in"
+
+    return locals()
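+
+
+# A minimal, hypothetical usage sketch, mirroring how the tests consume this
+# helper.  Because `create_test_hmm` returns `locals()`, every intermediate
+# term is available by name:
+#
+#     env = create_test_hmm()
+#     scan_args = env["scan_args"]
+#     Y_rv = env["Y_rv"]
+#     rng_tt = env["rng_tt"]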