symbolic_pymc/theano/pymc3.py: 127 additions & 2 deletions
@@ -17,6 +17,8 @@

from theano.gof.op import get_test_value
from theano.gof.graph import Apply, inputs as tt_inputs
from theano.scan_module.scan_op import Scan
from theano.scan_module.scan_utils import clone

from .random_variables import (
observed,
@@ -96,6 +98,122 @@ def convert_rv_to_dist(node, obs):
return dist_type(rv.name, shape=shape, observed=obs, **dist_params)


@dispatch(tt.TensorVariable, object)
def logp(var, obs):
    """Build a Theano graph for the log-likelihood of `var` evaluated at `obs`."""
node = var.owner

if hasattr(node, "fgraph") and hasattr(node.fgraph, "shape_feature"):
shape = list(node.fgraph.shape_feature.shape_tuple(var))
else:
shape = list(var.shape)

for i, s in enumerate(shape):
try:
shape[i] = tt.get_scalar_constant_value(s)
except tt.NotScalarConstantError:
shape[i] = s.tag.test_value

logp_fn = _logp_fn(node.op, node, shape)
return logp_fn(obs)
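
# A minimal usage sketch (mirroring `test_logp` in tests/theano/test_pymc3.py):
#
#     test_rv = NormalRV(0, tt.arange(1, 3))
#     test_logp = logp(test_rv, 0)
#     # gives the same values as pm.Normal.dist(0, np.arange(1, 3)).logp(0)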


@dispatch(RandomVariable, Apply, object)
def _logp_fn(op, node, shape=None):
    """Return a log-likelihood function for a `RandomVariable`, via its PyMC3 `Distribution` counterpart."""
dist_type, dist_params = _convert_rv_to_dist(op, node)
if shape is not None:
dist_params["shape"] = shape
res = dist_type.dist(**dist_params)
# Add extra information to the PyMC3 `Distribution` object
res.dist_params = dist_params
res.ndim_supp = op.ndim_supp
# TODO: Need to maintain the order of these so that they correspond with
# the `Distribution`'s parameters
res.ndims_params = op.ndims_params
return res.logp
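
# `_logp_fn` is dispatched on the op type; the `Scan` registration below lets
# `logp` handle random variables created inside `theano.scan` loops as well.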


@_logp_fn.register(Scan, Apply, object)
def _logp_fn_Scan(op, scan_node, shape=None):
    """Return a log-likelihood function for a `Scan` whose inner graph produces `RandomVariable` outputs."""
scan_inner_inputs = scan_node.op.inputs
scan_inner_outputs = scan_node.op.outputs

def create_obs_var(i, x):
obs = x.type()
obs.name = f"{x.name or x.owner.op.name}_obs_{i}"
if hasattr(x.tag, "test_value"):
obs.tag.test_value = x.tag.test_value
return obs

rv_outs = [
(i, x, create_obs_var(i, x))
for i, x in enumerate(scan_inner_outputs)
if x.owner and isinstance(x.owner.op, RandomVariable)
]
rv_inner_out_idx, rv_out_vars, rv_out_obs = zip(*rv_outs)
# rv_outer_out_idx = [scan_node.op.get_oinp_iinp_iout_oout_mappings()['outer_out_from_inner_out'][i] for i in rv_inner_out_idx]
# rv_outer_outputs = [scan_node.outputs[i] for i in rv_outer_out_idx]

logp_inner_outputs = [clone(logp(rv, obs)) for i, rv, obs in rv_outs]
assert all(o in tt.gof.graph.inputs(logp_inner_outputs) for o in rv_out_obs)

logp_inner_outputs_inputs = tt.gof.graph.inputs(logp_inner_outputs)
rv_relevant_inner_input_idx, rv_relevant_inner_inputs = zip(
*[(n, i) for n, i in enumerate(scan_inner_inputs) if i in logp_inner_outputs_inputs]
)
logp_inner_inputs = list(rv_out_obs) + list(rv_relevant_inner_inputs)
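
    # Note the ordering: the new observation inputs come first (they will act
    # as sequences), followed by the original inner-inputs that the logp
    # graphs still reference.  `logp_fn` below orders the outer-inputs the
    # same way, after the leading `n_steps` argument.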

    # We need to create outer-inputs that represent an array of observations
    # for each random variable.
    # To do that, we can use each random variable's outer-output term, since
    # those necessarily have the same shape and type as the observation
    # arrays.

    # Just like we did for the inner-inputs, we need to keep only the
    # outer-inputs that are relevant to the new logp graphs.  We can do that
    # by removing the irrelevant outer-inputs, using the known relevant
    # inner-inputs.
removed_inner_inputs = set(range(len(scan_inner_inputs))) - set(rv_relevant_inner_input_idx)
old_in_out_mappings = scan_node.op.get_oinp_iinp_iout_oout_mappings()
rv_removed_outer_input_idx = [
old_in_out_mappings["outer_inp_from_inner_inp"][i] for i in removed_inner_inputs
]
rv_removed_outer_inputs = [scan_node.inputs[i] for i in rv_removed_outer_input_idx]

rv_relevant_outer_inputs = [r for r in scan_node.inputs if r not in rv_removed_outer_inputs]

# Now, we can create a new op with our new inner-graph inputs and outputs.
# Also, since our inner graph has new placeholder terms representing
# an observed value for each random variable, we need to update the
# "info" `dict`.
logp_info = scan_node.op.info.copy()
logp_info["tap_array"] = []
logp_info["n_seqs"] += len(rv_out_obs)
logp_info["n_mit_mot"] = 0
logp_info["n_mit_mot_outs"] = 0
logp_info["mit_mot_out_slices"] = []
logp_info["n_mit_sot"] = 0
logp_info["n_sit_sot"] = 0
logp_info["n_shared_outs"] = 0
logp_info["n_nit_sot"] += len(rv_out_obs) - 1
logp_info["name"] = None
logp_info["strict"] = True

    # The arguments to `logp_fn` are the tensor variables corresponding to
    # each random variable's array of observations.
def logp_fn(*obs):
logp_obs_outer_inputs = list(obs) # [r.clone() for r in rv_outer_outputs]
logp_outer_inputs = (
[rv_relevant_outer_inputs[0]] + logp_obs_outer_inputs + rv_relevant_outer_inputs[1:]
)
logp_op = Scan(logp_inner_inputs, logp_inner_outputs, logp_info)
scan_logp = logp_op(*logp_outer_inputs)
return scan_logp

# logp_fn = OpFromGraph(logp_obs_outer_inputs, [scan_logp])
return logp_fn
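
# A usage sketch for the `Scan` case (condensed from `test_logp` in
# tests/theano/test_pymc3.py):
#
#     scan_rv, _ = theano.scan(
#         fn=lambda mu_t, sigma_t, rng: NormalRV(mu_t, sigma_t, rng=rng),
#         sequences=[mus_tt, sigmas_tt],
#         non_sequences=[rng_tt],
#         outputs_info=[{}],
#         strict=True,
#     )
#     scan_logp = logp(scan_rv, tt.zeros((N_tt,)))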


@dispatch(pm.Uniform, object)
def convert_dist_to_rv(dist, rng):
size = dist.shape.astype(int)[UniformRV.ndim_supp :]
Expand Down Expand Up @@ -467,7 +585,7 @@ def model_graph(pymc_model, output_vars=None, rand_state=None, attach_memo=True)
return model_fg


-def graph_model(graph, *model_args, **model_kwargs):
+def graph_model(graph, *model_args, generate_names=False, **model_kwargs):
"""Create a PyMC3 model from a Theano graph with `RandomVariable` nodes."""
model = pm.Model(*model_args, **model_kwargs)

@@ -478,13 +596,14 @@ def graph_model(graph, *model_args, **model_kwargs):
nodes = [n for n in fgraph.toposort() if isinstance(n.op, RandomVariable)]
rv_replacements = {}

node_id = 0

for node in nodes:

obs = get_rv_observation(node)

if obs is not None:
obs = obs.inputs[0]

obs = tt_get_values(obs)

old_rv_var = node.default_output()
@@ -500,6 +619,12 @@
if op != node
)

if generate_names and rv_var.name is None:
node_name = "{}_{}".format(node.op.name, node_id)
# warn("Name {} generated for node {}.".format(node, node_name))
node_id += 1
rv_var.name = node_name

with model:
rv = convert_rv_to_dist(node, obs)

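The `generate_names` behavior is exercised by the additions to
`test_pymc3_convert_dists` below; condensed:

    graph_model(NormalRV(0, 1), generate_names=False)  # raises TypeError (the variable is unnamed)
    res = graph_model(NormalRV(0, 1), generate_names=True)
    assert res.vars[0].name == "normal_0"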
symbolic_pymc/theano/random_variables.py: 2 additions & 5 deletions
@@ -447,11 +447,8 @@ def make_node(self, val, rv=None):
The distribution from which `val` is assumed to be a sample value.
"""
val = tt.as_tensor_variable(val)
-        if rv:
-            if rv.owner and not isinstance(rv.owner.op, RandomVariable):
-                raise ValueError(f"`rv` must be a RandomVariable type: {rv}")
-
-            if rv.type.convert_variable(val) is None:
+        if rv is not None:
+            if not hasattr(rv, "type") or rv.type.convert_variable(val) is None:
raise ValueError(
("`rv` and `val` do not have compatible types:" f" rv={rv}, val={val}")
)
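A sketch of what the reworked check accepts and rejects (hypothetical calls,
using the `observed` constructor from this module):

    rv = NormalRV(0, 1)             # a scalar random variable
    observed(tt.vector("v"), rv)    # ValueError: a vector is incompatible with rv's type
    observed(tt.vector("v"), 1.0)   # ValueError: 1.0 has no Theano `type`
    observed(tt.vector("v"))        # fine; `rv` is optional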
tests/theano/test_pymc3.py: 55 additions & 2 deletions
@@ -13,10 +13,10 @@
# from theano.configparser import change_flags
from theano.gof.graph import inputs as tt_inputs

-from symbolic_pymc.theano.random_variables import MvNormalRV, Observed, observed
+from symbolic_pymc.theano.random_variables import NormalRV, MvNormalRV, Observed, observed
from symbolic_pymc.theano.ops import RandomVariable
from symbolic_pymc.theano.opt import FunctionGraph
-from symbolic_pymc.theano.pymc3 import model_graph, graph_model
+from symbolic_pymc.theano.pymc3 import model_graph, graph_model, logp
from symbolic_pymc.theano.utils import canonicalize
from symbolic_pymc.theano.meta import mt

@@ -60,6 +60,12 @@ def test_pymc3_convert_dists():
new_pymc_rv_names = {n.name for n in pymc_model.observed_RVs}
    assert pymc_rv_names == new_pymc_rv_names

with pytest.raises(TypeError):
graph_model(NormalRV(0, 1), generate_names=False)

res = graph_model(NormalRV(0, 1), generate_names=True)
assert res.vars[0].name == "normal_0"


def test_pymc3_normal_model():
"""Conduct a more in-depth test of PyMC3/Theano conversions for a specific model."""
@@ -246,3 +252,50 @@ def test_pymc3_broadcastable():
Z_rv_meta = canonicalize(Z_rv_obs_.reify(), return_graph=False)

assert mt(Z_rv_tt) == mt(Z_rv_meta)


def test_logp():
test_rv = NormalRV(0, tt.arange(1, 3))
test_logp = logp(test_rv, 0)

assert np.all(test_logp.eval() == pm.Normal.dist(0, np.arange(1, 3)).logp(0).eval())

fgraph = FunctionGraph(tt_inputs([test_rv]), [test_rv], features=[tt.opt.ShapeFeature()])
test_rv.owner.fgraph = fgraph
test_logp = logp(test_rv, 0)

assert np.all(test_logp.eval() == pm.Normal.dist(0, np.arange(1, 3)).logp(0).eval())

rng_state = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234)))
rng_tt = theano.shared(rng_state, name="rng", borrow=True)
rng_tt.tag.is_rng = True
rng_tt.default_update = rng_tt

# TODO: Scan of univariate normals.
N_tt = tt.iscalar("N")
N_tt.tag.test_value = 10

    mus_tt = tt.arange(N_tt)
    # Accessing `test_value` raises if Theano could not compute one.
    mus_tt.tag.test_value

    sigmas_tt = tt.ones((N_tt,))
    sigmas_tt.tag.test_value

def scan_fn(mu_t, sigma_t, rng):
# mix = np.stack([NormalRV(mu_t, sigma_t, rng=rng), GammaRV(mu_t**2 / sigma_t**2, mu_t / sigma_t)])
return NormalRV(mu_t, sigma_t, rng=rng)

scan_rv, _ = theano.scan(
fn=scan_fn,
sequences=[mus_tt, sigmas_tt],
non_sequences=[rng_tt],
outputs_info=[{},],
strict=True,
name="scan_rv",
)

scan_logp = logp(scan_rv, tt.zeros((N_tt,)))
res = scan_logp.eval({N_tt: 10})
exp_res = pm.Normal.dist(np.arange(10), np.ones(10)).logp(np.zeros(10)).eval()

assert np.array_equal(res, exp_res)