Skip to content

Refactor logpt calls to aeppl #5166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 10 additions & 17 deletions pymc/distributions/logprob.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,33 +195,26 @@ def logpt(
getattr(_var.tag, "total_size", None), rv_value_var.shape, rv_value_var.ndim
)

# Unlike aeppl, PyMC's logpt is expected to plug in the values variables to corresponding
# RVs automatically unless the values are explicity set to None. Hence we iterate through
# the graph to find RVs and construct a new RVs to values dictionary.
# Aeppl needs all rv-values pairs, not just that of the requested var.
# Hence we iterate through the graph to collect them.
tmp_rvs_to_values = rv_values.copy()
transform_map = {}
for node in io_toposort(graph_inputs(var), var):
if isinstance(node.op, RandomVariable):
curr_var = node.out
try:
curr_vars = [node.default_output()]
except ValueError:
curr_vars = node.outputs
for curr_var in curr_vars:
rv_value_var = getattr(
curr_var.tag, "observations", getattr(curr_var.tag, "value_var", curr_var)
curr_var.tag, "observations", getattr(curr_var.tag, "value_var", None)
)
if rv_value_var is None:
continue
rv_value = rv_values.get(curr_var, rv_value_var)
tmp_rvs_to_values[curr_var] = rv_value
# Along with value variables we also check for transforms if any.
if hasattr(rv_value_var.tag, "transform") and transformed:
transform_map[rv_value] = rv_value_var.tag.transform
# The condition below is a hackish way of excluding the value variable for the
# RV being indexed in case of Advanced Indexing of RVs. It gets added by the
# logic above but aeppl does not expect us to include it in the dictionary of
# {RV:values} given to it.
if isinstance(node.op, subtensor_types):
curr_var = node.out
if (
curr_var in tmp_rvs_to_values.keys()
and curr_var.owner.inputs[0] in tmp_rvs_to_values.keys()
):
tmp_rvs_to_values.pop(curr_var.owner.inputs[0])

transform_opt = TransformValuesOpt(transform_map)
temp_logp_var_dict = factorized_joint_logprob(
Expand Down
135 changes: 55 additions & 80 deletions pymc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,54 +714,24 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
raise ValueError(f"Can only compute the gradient of continuous types: {var}")

if tempered:
with self:
# Convert random variables into their log-likelihood inputs and
# apply their transforms, if any
potentials, _ = rvs_to_value_vars(self.potentials, apply_transforms=True)

free_RVs_logp = at.sum(
[at.sum(logpt(var, self.rvs_to_values.get(var, None))) for var in self.free_RVs]
+ list(potentials)
)
observed_RVs_logp = at.sum(
[at.sum(logpt(obs, obs.tag.observations)) for obs in self.observed_RVs]
)

costs = [free_RVs_logp, observed_RVs_logp]
# TODO: Should this differ from self.datalogpt,
# where the potential terms are added to the observations?
costs = [self.varlogpt + self.potentiallogpt, self.observedlogpt]
else:
costs = [self.logpt]

input_vars = {i for i in graph_inputs(costs) if not isinstance(i, Constant)}
extra_vars = [self.rvs_to_values.get(var, var) for var in self.free_RVs]
ip = self.recompute_initial_point(0)
extra_vars_and_values = {
var: self.initial_point[var.name]
for var in extra_vars
if var in input_vars and var not in grad_vars
var: ip[var.name] for var in extra_vars if var in input_vars and var not in grad_vars
}
return ValueGradFunction(costs, grad_vars, extra_vars_and_values, **kwargs)

@property
def logpt(self):
"""Aesara scalar of log-probability of the model"""

rv_values = {}
for var in self.free_RVs:
rv_values[var] = self.rvs_to_values.get(var, None)
rv_factors = logpt(self.free_RVs, rv_values)

obs_values = {}
for obs in self.observed_RVs:
obs_values[obs] = obs.tag.observations
obs_factors = logpt(self.observed_RVs, obs_values)

# Convert random variables into their log-likelihood inputs and
# apply their transforms, if any
potentials, _ = rvs_to_value_vars(self.potentials, apply_transforms=True)
logp_var = at.sum([at.sum(factor) for factor in potentials])
if rv_factors is not None:
logp_var += rv_factors
if obs_factors is not None:
logp_var += obs_factors
logp_var = self.varlogpt + self.datalogpt

if self.name:
logp_var.name = f"__logp_{self.name}"
Expand All @@ -777,60 +747,65 @@ def logp_nojact(self):
Note that if there is no transformed variable in the model, logp_nojact
will be the same as logpt as there is no need for Jacobian correction.
"""
with self:
rv_values = {}
for var in self.free_RVs:
rv_values[var] = getattr(var.tag, "value_var", None)
rv_factors = logpt(self.free_RVs, rv_values, jacobian=False)

obs_values = {}
for obs in self.observed_RVs:
obs_values[obs] = obs.tag.observations
obs_factors = logpt(self.observed_RVs, obs_values, jacobian=False)

# Convert random variables into their log-likelihood inputs and
# apply their transforms, if any
potentials, _ = rvs_to_value_vars(self.potentials, apply_transforms=True)
logp_var = at.sum([at.sum(factor) for factor in potentials])

if rv_factors is not None:
logp_var += rv_factors
if obs_factors is not None:
logp_var += obs_factors

if self.name:
logp_var.name = f"__logp_nojac_{self.name}"
else:
logp_var.name = "__logp_nojac"
return logp_var
logp_var = self.varlogp_nojact + self.datalogpt

if self.name:
logp_var.name = f"__logp_nojac_{self.name}"
else:
logp_var.name = "__logp_nojac"
return logp_var

@property
def datalogpt(self):
    """Aesara scalar of log-probability of the observed variables and
    potential terms"""
    observed = self.observedlogpt
    potential = self.potentiallogpt
    return observed + potential

@property
def varlogpt(self):
    """Aesara scalar of log-probability of the unobserved random variables
    (excluding deterministic).

    Returns 0 when the model has no free random variables, so callers can
    always add this term into a total log-probability.
    """
    # Map each free RV to its registered value variable; aeppl-backed logpt
    # needs the explicit {rv: value} pairs.
    # NOTE: stale pre-refactor lines that read the value variable from
    # `var.tag.value_var` inside a `with self:` block were removed — the
    # model's rvs_to_values mapping is the single source of truth here.
    rv_values = {}
    for var in self.free_RVs:
        rv_values[var] = self.rvs_to_values[var]
    if rv_values:
        return logpt(self.free_RVs, rv_values)
    else:
        return 0

@property
def datalogpt(self):
with self:
obs_values = {}
for obs in self.observed_RVs:
obs_values[obs] = obs.tag.observations
obs_factors = logpt(self.observed_RVs, obs_values)

# Convert random variables into their log-likelihood inputs and
# apply their transforms, if any
potentials, _ = rvs_to_value_vars(self.potentials, apply_transforms=True)
logp_var = at.sum([at.sum(factor) for factor in potentials])
@property  # restored: accessed as `self.varlogp_nojact` (no call) by logp_nojact,
# matching the sibling varlogpt/observedlogpt/potentiallogpt properties
def varlogp_nojact(self):
    """Aesara scalar of log-probability of the unobserved random variables
    (excluding deterministic) without jacobian term.

    Returns 0 when the model has no free random variables.
    """
    # Map each free RV to its registered value variable for logpt;
    # jacobian=False skips the transform's jacobian correction term.
    rv_values = {}
    for var in self.free_RVs:
        rv_values[var] = self.rvs_to_values[var]
    if rv_values:
        return logpt(self.free_RVs, rv_values, jacobian=False)
    else:
        return 0

if obs_factors is not None:
logp_var += obs_factors
@property
def observedlogpt(self):
    """Aesara scalar of log-probability of the observed variables"""
    # Pair every observed RV with the data stored on its tag.
    observations = {rv: rv.tag.observations for rv in self.observed_RVs}
    if not observations:
        return 0
    return logpt(self.observed_RVs, observations)

return logp_var
@property
def potentiallogpt(self):
    """Aesara scalar of log-probability of the Potential terms"""
    # Convert random variables in Potential expression into their
    # log-likelihood inputs and apply their transforms, if any
    potentials, _ = rvs_to_value_vars(self.potentials, apply_transforms=True)
    if not potentials:
        return 0
    return at.sum([at.sum(factor) for factor in potentials])

@property
def vars(self):
Expand Down