diff --git a/pymc3/distributions/logp.py b/pymc3/distributions/logp.py
index e265cbb937..5229ea2d07 100644
--- a/pymc3/distributions/logp.py
+++ b/pymc3/distributions/logp.py
@@ -25,6 +25,8 @@
 from aesara.graph.fg import FunctionGraph
 from aesara.graph.op import Op, compute_test_value
 from aesara.graph.type import CType
+from aesara.scalar.basic import Add, Mul
+from aesara.tensor.elemwise import Elemwise
 from aesara.tensor.random.op import RandomVariable
 from aesara.tensor.random.opt import local_subtensor_rv_lift
 from aesara.tensor.subtensor import (
@@ -37,7 +39,13 @@
 )
 from aesara.tensor.var import TensorVariable
 
-from pymc3.aesaraf import extract_rv_and_value_vars, floatX, rvs_to_value_vars
+from pymc3.aesaraf import (
+    change_rv_size,
+    extract_rv_and_value_vars,
+    floatX,
+    rvs_to_value_vars,
+    walk_model,
+)
 
 
 @singledispatch
@@ -258,8 +266,101 @@ def _logp(
-    The default assumes that the log-likelihood of a term is a zero.
+    The default raises ``NotImplementedError`` for ops without a registered logp.
 
     """
-    value_var = rvs_to_values.get(var, var)
-    return at.zeros_like(value_var)
+    # value_var = rvs_to_values.get(var, var)
+    # return at.zeros_like(value_var)
+    raise NotImplementedError(f"Logp cannot be computed for op {op}")
+
+
+@_logp.register(Elemwise)
+def elemwise_logp(op, *args, **kwargs):
+    return _logp(op.scalar_op, *args, **kwargs)
+
+
+# TODO: Implement DimShuffle logp?
+# @_logp.register(DimShuffle)
+# def logp_dimshuffle(op, var, *args, **kwargs):
+#     if var.owner and len(var.owner.inputs) == 1:
+#         inp = var.owner.inputs[0]
+#         if inp.owner and hasattr(inp.owner, 'op'):
+#             return _logp(inp.owner.op, inp, *args, **kwargs)
+#     raise NotImplementedError
+
+
+@_logp.register(Add)
+@_logp.register(Mul)
+def linear_logp(
+    op,
+    var,
+    rvs_to_values,
+    *linear_inputs,
+    transformed=True,
+    sum=False,
+    **kwargs,
+):
+
+    if len(linear_inputs) != 2:
+        raise ValueError(f"Expected 2 inputs but got: {len(linear_inputs)}")
+
+    # Find base_rv and constant inputs
+    base_rv = []
+    constant = []
+    for inp in linear_inputs:
+        res_ancestors = list(walk_model((inp,), walk_past_rvs=True))
+        # unregistered variables do not contain a value_var tag
+        res_unregistered_ancestors = [
+            v
+            for v in res_ancestors
+            if v.owner
+            and isinstance(v.owner.op, RandomVariable)
+            and not getattr(v.tag, "value_var", False)
+        ]
+        if res_unregistered_ancestors:
+            base_rv.append(inp)
+        else:
+            constant.append(inp)
+
+    if len(base_rv) != 1:
+        raise NotImplementedError(
+            f"Logp of linear transform requires one branch with an unregistered RandomVariable but got {len(base_rv)}"
+        )
+
+    base_rv = base_rv[0]
+    constant = constant[0]
+    var_value = rvs_to_values.get(var, var)
+
+    # Get logp of base_rv with transformed input
+    if isinstance(op, Add):
+        base_value = var_value - constant
+    else:
+        base_value = var_value / constant
+
+    # Change base rv shape if needed
+    if isinstance(base_rv.owner.op, RandomVariable):
+        ndim_supp = base_rv.owner.op.ndim_supp
+        if ndim_supp > 0:
+            new_size = base_value.shape[:-ndim_supp]
+        else:
+            new_size = base_value.shape
+        base_rv = change_rv_size(base_rv, new_size)
+
+    var_logp = logpt(base_rv, {base_rv: base_value}, transformed=transformed, sum=False, **kwargs)
+
+    # Apply product jacobian correction for continuous rvs
+    if isinstance(op, Mul) and "float" in base_rv.dtype:
+        var_logp -= at.log(at.abs_(constant))
+
+    # Replace rvs in graph
+    (var_logp,), _ = rvs_to_value_vars(
+        (var_logp,),
+        apply_transforms=transformed,
+        initial_replacements=None,
+    )
+
+    if sum:
+        var_logp = at.sum(var_logp)
+
+    var_logp.name = f"__logp_{var.name}"
+    return var_logp
 
 
 def convert_indices(indices, entry):
@@ -318,8 +419,7 @@ def subtensor_logp(op, var, rvs_to_values, indexed_rv_var, *indices, **kwargs):
     # subset of variables per the index.
     var_copy = var.owner.clone().default_output()
     fgraph = FunctionGraph(
-        [i for i in graph_inputs((indexed_rv_var,)) if not isinstance(i, Constant)],
-        [var_copy],
+        outputs=[var_copy],
         clone=False,
     )
 
diff --git a/pymc3/tests/test_logp.py b/pymc3/tests/test_logp.py
index aea9db1fdc..08b541f2ef 100644
--- a/pymc3/tests/test_logp.py
+++ b/pymc3/tests/test_logp.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from contextlib import ExitStack as does_not_raise
+
 import aesara
 import aesara.tensor as at
 import numpy as np
@@ -31,7 +33,7 @@
 )
 
 from pymc3.aesaraf import floatX, walk_model
-from pymc3.distributions.continuous import Normal, Uniform
+from pymc3.distributions.continuous import Exponential, Normal, Uniform
 from pymc3.distributions.discrete import Bernoulli
 from pymc3.distributions.logp import logpt
 from pymc3.model import Model
@@ -69,6 +71,142 @@ def test_logpt_basic():
     assert a_value_var in res_ancestors
 
 
+def test_logpt_add():
+    """
+    Make sure we can compute a log-likelihood for ``loc + Y`` where ``Y`` is an unregistered
+    random variable and ``loc`` is a tensor variable or a registered random variable
+    """
+    with Model() as m:
+        loc = Exponential("loc", 10)
+        x = Normal.dist(0, 1) + loc
+        m.register_rv(x, "x")
+
+    loc_value_var = m.rvs_to_values[loc]
+    x_value_var = m.rvs_to_values[x]
+
+    x_logp = logpt(x, m.rvs_to_values[x])
+
+    res_ancestors = list(walk_model((x_logp,), walk_past_rvs=True))
+    res_rv_ancestors = [
+        v for v in res_ancestors if v.owner and isinstance(v.owner.op, RandomVariable)
+    ]
+
+    # There shouldn't be any `RandomVariable`s in the resulting graph
+    assert len(res_rv_ancestors) == 0
+    assert loc_value_var in res_ancestors
+    assert x_value_var in res_ancestors
+
+    # Test logp is correct
+    f_logp = aesara.function([x_value_var, loc_value_var], x_logp)
+    np.testing.assert_almost_equal(f_logp(50, np.log(50)), sp.norm(50, 1).logpdf(50))
+    np.testing.assert_almost_equal(f_logp(50, np.log(10)), sp.norm(10, 1).logpdf(50), decimal=5)
+
+
+def test_logpt_mul():
+    """
+    Make sure we can compute a log-likelihood for ``scale * Y`` where ``Y`` is an unregistered
+    random variable and ``scale`` is a tensor variable or a registered random variable
+    """
+    with Model() as m:
+        scale = Exponential("scale", 10)
+        x = Exponential.dist(1) * scale
+        m.register_rv(x, "x")
+
+    scale_value_var = m.rvs_to_values[scale]
+    x_value_var = m.rvs_to_values[x]
+
+    x_logp = logpt(x, m.rvs_to_values[x])
+
+    res_ancestors = list(walk_model((x_logp,), walk_past_rvs=True))
+    res_rv_ancestors = [
+        v for v in res_ancestors if v.owner and isinstance(v.owner.op, RandomVariable)
+    ]
+
+    # There shouldn't be any `RandomVariable`s in the resulting graph
+    assert len(res_rv_ancestors) == 0
+    assert scale_value_var in res_ancestors
+    assert x_value_var in res_ancestors
+
+    # Test logp is correct
+    f_logp = aesara.function([x_value_var, scale_value_var], x_logp)
+    np.testing.assert_almost_equal(f_logp(0, np.log(5)), sp.expon(scale=5).logpdf(0))
+    np.testing.assert_almost_equal(f_logp(2, np.log(2)), sp.expon(scale=2).logpdf(2))
+
+
+def test_logpt_mul_add():
+    """
+    Make sure we can compute a log-likelihood for ``loc + scale * Y`` where ``Y`` is an unregistered
+    random variable and ``loc`` and ``scale`` are tensor variables or registered random variables
+    """
+    with Model() as m:
+        loc = Exponential("loc", 10)
+        scale = Exponential("scale", 10)
+        x = loc + scale * Normal.dist(0, 1)
+        m.register_rv(x, "x")
+
+    loc_value_var = m.rvs_to_values[loc]
+    scale_value_var = m.rvs_to_values[scale]
+    x_value_var = m.rvs_to_values[x]
+
+    x_logp = logpt(x, m.rvs_to_values[x])
+
+    res_ancestors = list(walk_model((x_logp,), walk_past_rvs=True))
+    res_rv_ancestors = [
+        v for v in res_ancestors if v.owner and isinstance(v.owner.op, RandomVariable)
+    ]
+
+    # There shouldn't be any `RandomVariable`s in the resulting graph
+    assert len(res_rv_ancestors) == 0
+    assert loc_value_var in res_ancestors
+    assert scale_value_var in res_ancestors
+    assert x_value_var in res_ancestors
+
+    # Test logp is correct
+    f_logp = aesara.function([x_value_var, loc_value_var, scale_value_var], x_logp)
+    np.testing.assert_almost_equal(f_logp(-1, np.log(0), np.log(2)), sp.norm(0, 2).logpdf(-1))
+    np.testing.assert_almost_equal(
+        f_logp(95, np.log(100), np.log(15)), sp.norm(100, 15).logpdf(95), decimal=6
+    )
+
+
+@pytest.mark.parametrize("op", [at.add, at.mul])
+def test_logpt_not_implemented(op):
+    """Test that logpt for add and mul fails if inputs are 0 or 2 unregistered rvs"""
+
+    with Model() as m:
+        variable1 = at.as_tensor_variable(1, "variable1")
+        variable2 = at.scalar("variable2")
+        unregistered1 = Normal.dist(0, 1)
+        unregistered2 = Normal.dist(0, 1)
+        registered1 = Normal("registered1", 0, 1)
+        registered2 = Normal("registered2", 0, 1)
+
+        x_fail1 = op(variable1, variable2)
+        x_fail2 = op(unregistered1, unregistered2)
+        x_fail3 = op(registered1, variable1)
+        x_fail4 = op(registered1, registered2)
+
+        x_pass1 = op(variable1, unregistered2)
+        x_pass2 = op(unregistered1, variable2)
+        x_pass3 = op(registered1, unregistered1)
+
+        m.register_rv(x_fail1, "x_fail1")
+        m.register_rv(x_fail2, "x_fail2")
+        m.register_rv(x_fail3, "x_fail3")
+        m.register_rv(x_fail4, "x_fail4")
+        m.register_rv(x_pass1, "x_pass1")
+        m.register_rv(x_pass2, "x_pass2")
+        m.register_rv(x_pass3, "x_pass3")
+
+    for rv, value_var in m.rvs_to_values.items():
+        if "fail" in rv.name:
+            with pytest.raises(NotImplementedError):
+                logpt(rv, value_var)
+        else:
+            with does_not_raise():
+                logpt(rv, value_var)
+
+
 @pytest.mark.parametrize(
     "indices, size",
     [
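
A minimal usage sketch of the new `Add`/`Mul` logp dispatch, mirroring `test_logpt_mul_add` from the patch above. It assumes the v4 development API exercised in the tests (`Model.register_rv`, `logpt`) and that the `Exponential` value variables live on the log scale due to their default transform:

import aesara
import numpy as np
import scipy.stats as sp

from pymc3.distributions.continuous import Exponential, Normal
from pymc3.distributions.logp import logpt
from pymc3.model import Model

with Model() as m:
    loc = Exponential("loc", 10)
    scale = Exponential("scale", 10)
    # `Normal.dist(0, 1)` is unregistered, so `linear_logp` inverts the ops:
    # logp(x) = normal_logp((x - loc) / scale) - log|scale|
    x = loc + scale * Normal.dist(0, 1)
    m.register_rv(x, "x")

x_logp = logpt(x, m.rvs_to_values[x])
f_logp = aesara.function(
    [m.rvs_to_values[x], m.rvs_to_values[loc], m.rvs_to_values[scale]], x_logp
)

# Equivalent to evaluating a Normal(100, 15) density at 95
np.testing.assert_almost_equal(
    f_logp(95, np.log(100), np.log(15)), sp.norm(100, 15).logpdf(95), decimal=6
)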