Implement logp for add and mul ops involving random variables #4653
Changes from all commits: e8f9602, 4cb0d1b, 69f5caa, e8d5aee, 05fe251
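
For orientation (this sketch is not part of the PR text): the change teaches the logp dispatcher to handle Elemwise Add/Mul nodes whose inputs mix a single unregistered RandomVariable with constants. Roughly, graphs like the following are the target; the variable names below are illustrative only:

```python
import aesara.tensor as at
from aesara.tensor.random.basic import normal

# An unregistered standard-normal RandomVariable (it carries no value_var tag)
x_rv = normal(0.0, 1.0)

# Linear transforms of that RV: Elemwise nodes whose scalar op is Mul / Add,
# which the new linear_logp registration dispatches on
y_scaled = 2.0 * x_rv   # Elemwise{Mul}(2.0, x_rv)
y_shifted = x_rv + 5.0  # Elemwise{Add}(x_rv, 5.0)

# Intended behaviour (sketch): logpt(y_scaled, {y_scaled: v}) evaluates the
# normal's logp at v / 2.0 and subtracts log|2.0| as the Jacobian correction;
# logpt(y_shifted, {y_shifted: v}) evaluates it at v - 5.0 with no correction.
```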
@@ -25,6 +25,8 @@
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op, compute_test_value
from aesara.graph.type import CType
+from aesara.scalar.basic import Add, Mul
+from aesara.tensor.elemwise import Elemwise
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.opt import local_subtensor_rv_lift
from aesara.tensor.subtensor import (
@@ -37,7 +39,13 @@
)
from aesara.tensor.var import TensorVariable

-from pymc3.aesaraf import extract_rv_and_value_vars, floatX, rvs_to_value_vars
+from pymc3.aesaraf import (
+    change_rv_size,
+    extract_rv_and_value_vars,
+    floatX,
+    rvs_to_value_vars,
+    walk_model,
+)


@singledispatch
@@ -258,8 +266,101 @@ def _logp(
    The default assumes that the log-likelihood of a term is a zero.

    """
-    value_var = rvs_to_values.get(var, var)
-    return at.zeros_like(value_var)
+    # value_var = rvs_to_values.get(var, var)
+    # return at.zeros_like(value_var)
+    raise NotImplementedError(f"Logp cannot be computed for op {op}")
+
+
+@_logp.register(Elemwise)
+def elemwise_logp(op, *args, **kwargs):
+    return _logp(op.scalar_op, *args, **kwargs)
+
+
+# TODO: Implement DimShuffle logp?
+# @_logp.register(DimShuffle)
+# def logp_dimshuffle(op, var, *args, **kwargs):
+#     if var.owner and len(var.owner.inputs) == 1:
+#         inp = var.owner.inputs[0]
+#         if inp.owner and hasattr(inp.owner, 'op'):
+#             return _logp(inp.owner.op, inp, *args, **kwargs)
+#     raise NotImplementedError
+
+
+@_logp.register(Add)
+@_logp.register(Mul)
+def linear_logp(
+    op,
+    var,
+    rvs_to_values,
+    *linear_inputs,
+    transformed=True,
+    sum=False,
+    **kwargs,
+):
+
+    if len(linear_inputs) != 2:
+        raise ValueError(f"Expected 2 inputs but got: {len(linear_inputs)}")

Comment on lines +301 to +302:
  - This is an implementation limitation and not a misspecification of any kind, so, if we're not going to support more than two arguments for […]
  - Aren't […]
  - The scalar […]

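A side note on the thread above (added here for reference, not part of the PR discussion): the scalar Add and Mul Ops in Aesara are n-ary, and graph rewrites can merge nested sums/products into a single node, so an Elemwise{Add} with more than two inputs is a perfectly valid graph. The sketch below, assuming only standard Aesara, illustrates why the two-input check is a limitation rather than a misspecification:

```python
import aesara.tensor as at

x = at.scalar("x")
y = at.scalar("y")
z = at.scalar("z")

# at.add is variadic: this builds a single Elemwise{Add} node with three inputs
w = at.add(x, y, z)

print(type(w.owner.op.scalar_op).__name__)  # Add
print(len(w.owner.inputs))                  # 3
```
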
+
+    # Find base_rv and constant inputs
+    base_rv = []
+    constant = []
+    for inp in linear_inputs:
+        res_ancestors = list(walk_model((inp,), walk_past_rvs=True))
+        # unregistered variables do not contain a value_var tag
+        res_unregistered_ancestors = [
+            v
+            for v in res_ancestors
+            if v.owner
+            and isinstance(v.owner.op, RandomVariable)
+            and not getattr(v.tag, "value_var", False)
+        ]
+        if res_unregistered_ancestors:
+            base_rv.append(inp)
+        else:
+            constant.append(inp)
+
+    if len(base_rv) != 1:
+        raise NotImplementedError(
+            f"Logp of linear transform requires one branch with an unregistered RandomVariable but got {len(base_rv)}"
+        )
+
+    base_rv = base_rv[0]
+    constant = constant[0]
+    var_value = rvs_to_values.get(var, var)
+
+    # Get logp of base_rv with transformed input
+    if isinstance(op, Add):
+        base_value = var_value - constant
+    else:
+        base_value = var_value / constant
+
+    # Change base rv shape if needed
+    if isinstance(base_rv.owner.op, RandomVariable):
+        ndim_supp = base_rv.owner.op.ndim_supp
+        if ndim_supp > 0:
+            new_size = base_value.shape[:-ndim_supp]
+        else:
+            new_size = base_value.shape
+        base_rv = change_rv_size(base_rv, new_size)
+
+    var_logp = logpt(base_rv, {base_rv: base_value}, transformed=transformed, sum=False, **kwargs)
+
+    # Apply product jacobian correction for continuous rvs
+    if isinstance(op, Mul) and "float" in base_rv.dtype:
+        var_logp -= at.log(at.abs_(constant))
+
+    # Replace rvs in graph
+    (var_logp,), _ = rvs_to_value_vars(
+        (var_logp,),
+        apply_transforms=transformed,
+        initial_replacements=None,
+    )
+
+    if sum:
+        var_logp = at.sum(var_logp)
+
+    var_logp.name = f"__logp_{var.name}"
+    return var_logp


def convert_indices(indices, entry):
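
(An aside on the resizing step inside linear_logp above, not from the PR: ndim_supp counts the RV's trailing support dimensions, so the new size passed to change_rv_size is the transformed value's shape with those trailing dimensions stripped. A plain-Python illustration of just that shape arithmetic:)

```python
# Shape arithmetic only; mirrors the new_size computation in linear_logp
value_shape = (4, 7, 3)  # e.g. two batch dims (4, 7) and one support dim of length 3
ndim_supp = 1            # e.g. a vector-valued (multivariate) base RV

new_size = value_shape[:-ndim_supp] if ndim_supp > 0 else value_shape
print(new_size)  # (4, 7) -> what change_rv_size would receive as the new size
```
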
@@ -318,8 +419,7 @@ def subtensor_logp(op, var, rvs_to_values, indexed_rv_var, *indices, **kwargs):
    # subset of variables per the index.
    var_copy = var.owner.clone().default_output()
    fgraph = FunctionGraph(
-        [i for i in graph_inputs((indexed_rv_var,)) if not isinstance(i, Constant)],
-        [var_copy],
+        outputs=[var_copy],
        clone=False,
    )

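(On the FunctionGraph change just above: a hedged sketch assuming an Aesara version where FunctionGraph can be constructed from outputs alone and will collect the non-constant inputs itself, which is what makes the explicit graph_inputs filtering redundant:)

```python
import aesara.tensor as at
from aesara.graph.fg import FunctionGraph

x = at.vector("x")
y = at.exp(x) + 1.0

# With only `outputs` given, FunctionGraph walks the graph and gathers the
# non-constant inputs on its own (here just `x`)
fg = FunctionGraph(outputs=[y], clone=False)
print(fg.inputs)  # expected: [x]
```
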
Review comment:
  - This is temporary, just to catch errors with DimShuffle ops. It causes the test_logpt_subtensor to fail.
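
For completeness, the change-of-variables rule that linear_logp relies on can be checked numerically, independently of Aesara: a shift y = x + a only translates the density, while a scaling y = c * x requires the -log|c| Jacobian term that the Mul branch subtracts. A standalone SciPy sketch (not part of the PR):

```python
import numpy as np
from scipy import stats

a, c = 1.5, 2.0
y = np.random.default_rng(123).normal(size=5)

base = stats.norm(0.0, 1.0)

# Addition: logp_Y(y) = logp_X(y - a), no correction needed
assert np.allclose(stats.norm(loc=a, scale=1.0).logpdf(y), base.logpdf(y - a))

# Multiplication: logp_Y(y) = logp_X(y / c) - log|c|
assert np.allclose(
    stats.norm(loc=0.0, scale=abs(c)).logpdf(y),
    base.logpdf(y / c) - np.log(abs(c)),
)
```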