Skip to content

Commit 31b4a37

Browse files
authored
Refactor logpt calls to aeppl (#5166)
* Simplify tmp_rvs_to_values generation in logpt and remove duplicated code across model logpt properties. This commit also introduces a varlogp_nojact, varlogp_nojact and potentiallogpt properties to Model objects
1 parent f24a1df commit 31b4a37

File tree

2 files changed

+65
-97
lines changed

2 files changed

+65
-97
lines changed

pymc/distributions/logprob.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -195,33 +195,26 @@ def logpt(
195195
getattr(_var.tag, "total_size", None), rv_value_var.shape, rv_value_var.ndim
196196
)
197197

198-
# Unlike aeppl, PyMC's logpt is expected to plug in the values variables to corresponding
199-
# RVs automatically unless the values are explicity set to None. Hence we iterate through
200-
# the graph to find RVs and construct a new RVs to values dictionary.
198+
# Aeppl needs all rv-values pairs, not just that of the requested var.
199+
# Hence we iterate through the graph to collect them.
201200
tmp_rvs_to_values = rv_values.copy()
202201
transform_map = {}
203202
for node in io_toposort(graph_inputs(var), var):
204-
if isinstance(node.op, RandomVariable):
205-
curr_var = node.out
203+
try:
204+
curr_vars = [node.default_output()]
205+
except ValueError:
206+
curr_vars = node.outputs
207+
for curr_var in curr_vars:
206208
rv_value_var = getattr(
207-
curr_var.tag, "observations", getattr(curr_var.tag, "value_var", curr_var)
209+
curr_var.tag, "observations", getattr(curr_var.tag, "value_var", None)
208210
)
211+
if rv_value_var is None:
212+
continue
209213
rv_value = rv_values.get(curr_var, rv_value_var)
210214
tmp_rvs_to_values[curr_var] = rv_value
211215
# Along with value variables we also check for transforms if any.
212216
if hasattr(rv_value_var.tag, "transform") and transformed:
213217
transform_map[rv_value] = rv_value_var.tag.transform
214-
# The condition below is a hackish way of excluding the value variable for the
215-
# RV being indexed in case of Advanced Indexing of RVs. It gets added by the
216-
# logic above but aeppl does not expect us to include it in the dictionary of
217-
# {RV:values} given to it.
218-
if isinstance(node.op, subtensor_types):
219-
curr_var = node.out
220-
if (
221-
curr_var in tmp_rvs_to_values.keys()
222-
and curr_var.owner.inputs[0] in tmp_rvs_to_values.keys()
223-
):
224-
tmp_rvs_to_values.pop(curr_var.owner.inputs[0])
225218

226219
transform_opt = TransformValuesOpt(transform_map)
227220
temp_logp_var_dict = factorized_joint_logprob(

pymc/model.py

Lines changed: 55 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -714,54 +714,24 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
714714
raise ValueError(f"Can only compute the gradient of continuous types: {var}")
715715

716716
if tempered:
717-
with self:
718-
# Convert random variables into their log-likelihood inputs and
719-
# apply their transforms, if any
720-
potentials, _ = rvs_to_value_vars(self.potentials, apply_transforms=True)
721-
722-
free_RVs_logp = at.sum(
723-
[at.sum(logpt(var, self.rvs_to_values.get(var, None))) for var in self.free_RVs]
724-
+ list(potentials)
725-
)
726-
observed_RVs_logp = at.sum(
727-
[at.sum(logpt(obs, obs.tag.observations)) for obs in self.observed_RVs]
728-
)
729-
730-
costs = [free_RVs_logp, observed_RVs_logp]
717+
# TODO: Should this differ from self.datalogpt,
718+
# where the potential terms are added to the observations?
719+
costs = [self.varlogpt + self.potentiallogpt, self.observedlogpt]
731720
else:
732721
costs = [self.logpt]
733722

734723
input_vars = {i for i in graph_inputs(costs) if not isinstance(i, Constant)}
735724
extra_vars = [self.rvs_to_values.get(var, var) for var in self.free_RVs]
725+
ip = self.recompute_initial_point(0)
736726
extra_vars_and_values = {
737-
var: self.initial_point[var.name]
738-
for var in extra_vars
739-
if var in input_vars and var not in grad_vars
727+
var: ip[var.name] for var in extra_vars if var in input_vars and var not in grad_vars
740728
}
741729
return ValueGradFunction(costs, grad_vars, extra_vars_and_values, **kwargs)
742730

743731
@property
744732
def logpt(self):
745733
"""Aesara scalar of log-probability of the model"""
746-
747-
rv_values = {}
748-
for var in self.free_RVs:
749-
rv_values[var] = self.rvs_to_values.get(var, None)
750-
rv_factors = logpt(self.free_RVs, rv_values)
751-
752-
obs_values = {}
753-
for obs in self.observed_RVs:
754-
obs_values[obs] = obs.tag.observations
755-
obs_factors = logpt(self.observed_RVs, obs_values)
756-
757-
# Convert random variables into their log-likelihood inputs and
758-
# apply their transforms, if any
759-
potentials, _ = rvs_to_value_vars(self.potentials, apply_transforms=True)
760-
logp_var = at.sum([at.sum(factor) for factor in potentials])
761-
if rv_factors is not None:
762-
logp_var += rv_factors
763-
if obs_factors is not None:
764-
logp_var += obs_factors
734+
logp_var = self.varlogpt + self.datalogpt
765735

766736
if self.name:
767737
logp_var.name = f"__logp_{self.name}"
@@ -777,60 +747,65 @@ def logp_nojact(self):
777747
Note that if there is no transformed variable in the model, logp_nojact
778748
will be the same as logpt as there is no need for Jacobian correction.
779749
"""
780-
with self:
781-
rv_values = {}
782-
for var in self.free_RVs:
783-
rv_values[var] = getattr(var.tag, "value_var", None)
784-
rv_factors = logpt(self.free_RVs, rv_values, jacobian=False)
785-
786-
obs_values = {}
787-
for obs in self.observed_RVs:
788-
obs_values[obs] = obs.tag.observations
789-
obs_factors = logpt(self.observed_RVs, obs_values, jacobian=False)
790-
791-
# Convert random variables into their log-likelihood inputs and
792-
# apply their transforms, if any
793-
potentials, _ = rvs_to_value_vars(self.potentials, apply_transforms=True)
794-
logp_var = at.sum([at.sum(factor) for factor in potentials])
795-
796-
if rv_factors is not None:
797-
logp_var += rv_factors
798-
if obs_factors is not None:
799-
logp_var += obs_factors
800-
801-
if self.name:
802-
logp_var.name = f"__logp_nojac_{self.name}"
803-
else:
804-
logp_var.name = "__logp_nojac"
805-
return logp_var
750+
logp_var = self.varlogp_nojact + self.datalogpt
751+
752+
if self.name:
753+
logp_var.name = f"__logp_nojac_{self.name}"
754+
else:
755+
logp_var.name = "__logp_nojac"
756+
return logp_var
757+
758+
@property
759+
def datalogpt(self):
760+
"""Aesara scalar of log-probability of the observed variables and
761+
potential terms"""
762+
return self.observedlogpt + self.potentiallogpt
806763

807764
@property
808765
def varlogpt(self):
809766
"""Aesara scalar of log-probability of the unobserved random variables
810767
(excluding deterministic)."""
811-
with self:
812-
rv_values = {}
813-
for var in self.free_RVs:
814-
rv_values[var] = getattr(var.tag, "value_var", None)
768+
rv_values = {}
769+
for var in self.free_RVs:
770+
rv_values[var] = self.rvs_to_values[var]
771+
if rv_values:
815772
return logpt(self.free_RVs, rv_values)
773+
else:
774+
return 0
816775

817776
@property
818-
def datalogpt(self):
819-
with self:
820-
obs_values = {}
821-
for obs in self.observed_RVs:
822-
obs_values[obs] = obs.tag.observations
823-
obs_factors = logpt(self.observed_RVs, obs_values)
824-
825-
# Convert random variables into their log-likelihood inputs and
826-
# apply their transforms, if any
827-
potentials, _ = rvs_to_value_vars(self.potentials, apply_transforms=True)
828-
logp_var = at.sum([at.sum(factor) for factor in potentials])
777+
def varlogp_nojact(self):
778+
"""Aesara scalar of log-probability of the unobserved random variables
779+
(excluding deterministic) without jacobian term."""
780+
rv_values = {}
781+
for var in self.free_RVs:
782+
rv_values[var] = self.rvs_to_values[var]
783+
if rv_values:
784+
return logpt(self.free_RVs, rv_values, jacobian=False)
785+
else:
786+
return 0
829787

830-
if obs_factors is not None:
831-
logp_var += obs_factors
788+
@property
789+
def observedlogpt(self):
790+
"""Aesara scalar of log-probability of the observed variables"""
791+
obs_values = {}
792+
for obs in self.observed_RVs:
793+
obs_values[obs] = obs.tag.observations
794+
if obs_values:
795+
return logpt(self.observed_RVs, obs_values)
796+
else:
797+
return 0
832798

833-
return logp_var
799+
@property
800+
def potentiallogpt(self):
801+
"""Aesara scalar of log-probability of the Potential terms"""
802+
# Convert random variables in Potential expression into their log-likelihood
803+
# inputs and apply their transforms, if any
804+
potentials, _ = rvs_to_value_vars(self.potentials, apply_transforms=True)
805+
if potentials:
806+
return at.sum([at.sum(factor) for factor in potentials])
807+
else:
808+
return 0
834809

835810
@property
836811
def vars(self):

0 commit comments

Comments
 (0)