diff --git a/pymc/distributions/logprob.py b/pymc/distributions/logprob.py
index e8611c591e..66353f8200 100644
--- a/pymc/distributions/logprob.py
+++ b/pymc/distributions/logprob.py
@@ -195,33 +195,26 @@ def logpt(
             getattr(_var.tag, "total_size", None), rv_value_var.shape, rv_value_var.ndim
         )
 
-    # Unlike aeppl, PyMC's logpt is expected to plug in the values variables to corresponding
-    # RVs automatically unless the values are explicity set to None. Hence we iterate through
-    # the graph to find RVs and construct a new RVs to values dictionary.
+    # Aeppl needs all rv-values pairs, not just that of the requested var.
+    # Hence we iterate through the graph to collect them.
     tmp_rvs_to_values = rv_values.copy()
     transform_map = {}
     for node in io_toposort(graph_inputs(var), var):
-        if isinstance(node.op, RandomVariable):
-            curr_var = node.out
+        try:
+            curr_vars = [node.default_output()]
+        except ValueError:
+            curr_vars = node.outputs
+        for curr_var in curr_vars:
             rv_value_var = getattr(
-                curr_var.tag, "observations", getattr(curr_var.tag, "value_var", curr_var)
+                curr_var.tag, "observations", getattr(curr_var.tag, "value_var", None)
             )
+            if rv_value_var is None:
+                continue
             rv_value = rv_values.get(curr_var, rv_value_var)
             tmp_rvs_to_values[curr_var] = rv_value
             # Along with value variables we also check for transforms if any.
             if hasattr(rv_value_var.tag, "transform") and transformed:
                 transform_map[rv_value] = rv_value_var.tag.transform
-        # The condition below is a hackish way of excluding the value variable for the
-        # RV being indexed in case of Advanced Indexing of RVs. It gets added by the
-        # logic above but aeppl does not expect us to include it in the dictionary of
-        # {RV:values} given to it.
-        if isinstance(node.op, subtensor_types):
-            curr_var = node.out
-            if (
-                curr_var in tmp_rvs_to_values.keys()
-                and curr_var.owner.inputs[0] in tmp_rvs_to_values.keys()
-            ):
-                tmp_rvs_to_values.pop(curr_var.owner.inputs[0])
 
     transform_opt = TransformValuesOpt(transform_map)
     temp_logp_var_dict = factorized_joint_logprob(
diff --git a/pymc/model.py b/pymc/model.py
index 39d44db09e..3d9541cf39 100644
--- a/pymc/model.py
+++ b/pymc/model.py
@@ -714,54 +714,24 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
                 raise ValueError(f"Can only compute the gradient of continuous types: {var}")
 
         if tempered:
-            with self:
-                # Convert random variables into their log-likelihood inputs and
-                # apply their transforms, if any
-                potentials, _ = rvs_to_value_vars(self.potentials, apply_transforms=True)
-
-                free_RVs_logp = at.sum(
-                    [at.sum(logpt(var, self.rvs_to_values.get(var, None))) for var in self.free_RVs]
-                    + list(potentials)
-                )
-                observed_RVs_logp = at.sum(
-                    [at.sum(logpt(obs, obs.tag.observations)) for obs in self.observed_RVs]
-                )
-
-            costs = [free_RVs_logp, observed_RVs_logp]
+            # TODO: Should this differ from self.datalogpt,
+            #  where the potential terms are added to the observations?
+            costs = [self.varlogpt + self.potentiallogpt, self.observedlogpt]
         else:
             costs = [self.logpt]
 
         input_vars = {i for i in graph_inputs(costs) if not isinstance(i, Constant)}
         extra_vars = [self.rvs_to_values.get(var, var) for var in self.free_RVs]
+        ip = self.recompute_initial_point(0)
         extra_vars_and_values = {
-            var: self.initial_point[var.name]
-            for var in extra_vars
-            if var in input_vars and var not in grad_vars
+            var: ip[var.name] for var in extra_vars if var in input_vars and var not in grad_vars
         }
         return ValueGradFunction(costs, grad_vars, extra_vars_and_values, **kwargs)
 
     @property
     def logpt(self):
         """Aesara scalar of log-probability of the model"""
-
-        rv_values = {}
-        for var in self.free_RVs:
-            rv_values[var] = self.rvs_to_values.get(var, None)
-        rv_factors = logpt(self.free_RVs, rv_values)
-
-        obs_values = {}
-        for obs in self.observed_RVs:
-            obs_values[obs] = obs.tag.observations
-        obs_factors = logpt(self.observed_RVs, obs_values)
-
-        # Convert random variables into their log-likelihood inputs and
-        # apply their transforms, if any
-        potentials, _ = rvs_to_value_vars(self.potentials, apply_transforms=True)
-        logp_var = at.sum([at.sum(factor) for factor in potentials])
-        if rv_factors is not None:
-            logp_var += rv_factors
-        if obs_factors is not None:
-            logp_var += obs_factors
+        logp_var = self.varlogpt + self.datalogpt
 
         if self.name:
             logp_var.name = f"__logp_{self.name}"
@@ -777,60 +747,65 @@ def logp_nojact(self):
         Note that if there is no transformed variable in the model, logp_nojact
         will be the same as logpt as there is no need for Jacobian correction.
         """
-        with self:
-            rv_values = {}
-            for var in self.free_RVs:
-                rv_values[var] = getattr(var.tag, "value_var", None)
-            rv_factors = logpt(self.free_RVs, rv_values, jacobian=False)
-
-            obs_values = {}
-            for obs in self.observed_RVs:
-                obs_values[obs] = obs.tag.observations
-            obs_factors = logpt(self.observed_RVs, obs_values, jacobian=False)
-
-            # Convert random variables into their log-likelihood inputs and
-            # apply their transforms, if any
-            potentials, _ = rvs_to_value_vars(self.potentials, apply_transforms=True)
-            logp_var = at.sum([at.sum(factor) for factor in potentials])
-
-            if rv_factors is not None:
-                logp_var += rv_factors
-            if obs_factors is not None:
-                logp_var += obs_factors
-
-            if self.name:
-                logp_var.name = f"__logp_nojac_{self.name}"
-            else:
-                logp_var.name = "__logp_nojac"
-            return logp_var
+        logp_var = self.varlogp_nojact + self.datalogpt
+
+        if self.name:
+            logp_var.name = f"__logp_nojac_{self.name}"
+        else:
+            logp_var.name = "__logp_nojac"
+        return logp_var
+
+    @property
+    def datalogpt(self):
+        """Aesara scalar of log-probability of the observed variables and
+        potential terms"""
+        return self.observedlogpt + self.potentiallogpt
 
     @property
     def varlogpt(self):
         """Aesara scalar of log-probability of the unobserved random variables
         (excluding deterministic)."""
-        with self:
-            rv_values = {}
-            for var in self.free_RVs:
-                rv_values[var] = getattr(var.tag, "value_var", None)
+        rv_values = {}
+        for var in self.free_RVs:
+            rv_values[var] = self.rvs_to_values[var]
+        if rv_values:
             return logpt(self.free_RVs, rv_values)
+        else:
+            return 0
 
     @property
-    def datalogpt(self):
-        with self:
-            obs_values = {}
-            for obs in self.observed_RVs:
-                obs_values[obs] = obs.tag.observations
-            obs_factors = logpt(self.observed_RVs, obs_values)
-
-            # Convert random variables into their log-likelihood inputs and
-            # apply their transforms, if any
-            potentials, _ = rvs_to_value_vars(self.potentials, apply_transforms=True)
-            logp_var = at.sum([at.sum(factor) for factor in potentials])
+    def varlogp_nojact(self):
+        """Aesara scalar of log-probability of the unobserved random variables
+        (excluding deterministic) without jacobian term."""
+        rv_values = {}
+        for var in self.free_RVs:
+            rv_values[var] = self.rvs_to_values[var]
+        if rv_values:
+            return logpt(self.free_RVs, rv_values, jacobian=False)
+        else:
+            return 0
 
-            if obs_factors is not None:
-                logp_var += obs_factors
+    @property
+    def observedlogpt(self):
+        """Aesara scalar of log-probability of the observed variables"""
+        obs_values = {}
+        for obs in self.observed_RVs:
+            obs_values[obs] = obs.tag.observations
+        if obs_values:
+            return logpt(self.observed_RVs, obs_values)
+        else:
+            return 0
 
-            return logp_var
+    @property
+    def potentiallogpt(self):
+        """Aesara scalar of log-probability of the Potential terms"""
+        # Convert random variables in Potential expression into their log-likelihood
+        # inputs and apply their transforms, if any
+        potentials, _ = rvs_to_value_vars(self.potentials, apply_transforms=True)
+        if potentials:
+            return at.sum([at.sum(factor) for factor in potentials])
+        else:
+            return 0
 
     @property
     def vars(self):
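
The net effect of the model.py changes is that the joint log-probability now decomposes into reusable properties: logpt = varlogpt + datalogpt, and datalogpt = observedlogpt + potentiallogpt. Below is a minimal usage sketch of that decomposition, not part of the diff itself; it assumes the v4 development API around this commit (pm.Model, Model.value_vars, and aesara.function are assumptions here; only the logp property names come from the diff above).

import aesara
import aesara.tensor as at
import numpy as np
import pymc as pm

# Hypothetical model exercising all three logp components.
with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)                 # free RV      -> varlogpt
    pm.Normal("y", mu, 1.0, observed=np.zeros(3))  # observed RV  -> observedlogpt
    pm.Potential("pot", at.constant(-1.0))         # Potential    -> potentiallogpt

# Since logpt = varlogpt + datalogpt and datalogpt = observedlogpt + potentiallogpt,
# both compiled functions should return the same value.
f_total = aesara.function(model.value_vars, model.logpt)
f_parts = aesara.function(model.value_vars, model.varlogpt + model.datalogpt)
assert np.isclose(f_total(0.1), f_parts(0.1))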