84 changes: 84 additions & 0 deletions README.rst
@@ -0,0 +1,84 @@
|Tests Status| |Coverage|


``aeppl`` provides tools for a[e]PPL, i.e. a probabilistic programming language (PPL), written in `Aesara <https://github.com/pymc-devs/aesara>`_.


Features
========
- Convert graphs containing Aesara ``RandomVariable`` nodes into joint log-probability graphs
- Tools for traversing and transforming graphs containing ``RandomVariable`` nodes
- ``RandomVariable``-aware pretty printing and LaTeX output


Examples
========

.. code-block:: python

import aesara
from aesara import tensor as at

from aeppl import joint_logprob, pprint


# A simple scale mixture model
S_rv = at.random.invgamma(0.5, 0.5)
Y_rv = at.random.normal(0.0, at.sqrt(S_rv))

print(pprint(Y_rv))
# S ~ invgamma(0.5, 0.5) in R, Y ~ N(0.0, sqrt(S)**2) in R
# Y


# Compute the joint log-probability
y = at.scalar("y")
s = at.scalar("s")
logprob = joint_logprob(Y_rv, {Y_rv: y, S_rv: s})


# Simplify the graph so that it's easier to read
from aesara.graph.opt_utils import optimize_graph
from aesara.tensor.basic_opt import topo_constant_folding


logprob = optimize_graph(logprob, custom_opt=topo_constant_folding)


print(pprint(logprob))
# s in R, y in R
# (switch(s >= 0.0,
# ((-0.9189385175704956 +
# switch(s == 0, -inf, (-1.5 * log(s)))) - (0.5 / s)),
# -inf) +
# ((-0.9189385332046727 + (-0.5 * ((y / sqrt(s)) ** 2))) - log(sqrt(s))))


# Create a finite mixture model with a Bernoulli distributed
# mixing distribution
Z_rv = at.random.normal([-100, 100], 1.0, name="Z")
I_rv = at.random.bernoulli(0.5, name="I")

M_rv = Z_rv[I_rv]
M_rv.name = "M"

z = at.vector("z")
i = at.lscalar("i")
m = at.scalar("m")
# Compute the joint log-probability for the mixture
logprob = joint_logprob(M_rv, {M_rv: m, Z_rv: z, I_rv: i})


logprob = optimize_graph(logprob, custom_opt=topo_constant_folding)

print(pprint(logprob))
# i in Z, m in R, a in Z
# (switch((0 <= i and i <= 1), -0.6931472, -inf) +
# ((-0.9189385332046727 + (-0.5 * (((m - [-100 100][a]) / [1. 1.][a]) ** 2))) -
# log([1. 1.][a])))
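
``latex_pprint`` is exported alongside ``pprint`` and renders the same model
description as LaTeX source. A minimal sketch, assuming ``latex_pprint``
accepts the same arguments as ``pprint``:

.. code-block:: python

    from aeppl import latex_pprint

    # Assumes `latex_pprint` mirrors `pprint`'s interface
    print(latex_pprint(M_rv))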


.. |Tests Status| image:: https://github.com/aesara-devs/aeppl/actions/workflows/test.yml/badge.svg?branch=main
:target: https://github.com/aesara-devs/aeppl/actions/workflows/test.yml
.. |Coverage| image:: https://codecov.io/gh/aesara-devs/aeppl/branch/main/graph/badge.svg?token=L2i59LsFc0
:target: https://codecov.io/gh/aesara-devs/aeppl
6 changes: 6 additions & 0 deletions aeppl/__init__.py
Expand Up @@ -2,3 +2,9 @@

__version__ = get_versions()["version"]
del get_versions


from .logprob import logprob # isort: split

from .joint_logprob import joint_logprob
from .printing import latex_pprint, pprint
160 changes: 160 additions & 0 deletions aeppl/joint_logprob.py
@@ -0,0 +1,160 @@
import warnings
from collections import deque
from typing import Dict, Optional

from aesara import config
from aesara.graph.basic import graph_inputs, io_toposort
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import compute_test_value
from aesara.graph.opt_utils import optimize_graph
from aesara.tensor.basic_opt import ShapeFeature
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorVariable

from aeppl.logprob import _logprob
from aeppl.opt import PreserveRVMappings, RVSinker
from aeppl.utils import rvs_to_value_vars


def joint_logprob(
var: TensorVariable,
rv_values: Optional[Dict[TensorVariable, TensorVariable]] = None,
warn_missing_rvs: bool = True,
**kwargs,
) -> TensorVariable:
r"""Create a graph representing the joint log-probability/measure of a graph.

The input `var` determines which graph is used, and `rv_values` specifies
the resulting measure-space graph's input parameters.

For example, consider the following

.. code-block:: python

import aesara.tensor as at

sigma2_rv = at.random.invgamma(0.5, 0.5)
Y_rv = at.random.normal(0, at.sqrt(sigma2_rv))

This graph for ``Y_rv`` is equivalent to the following hierarchical model:

.. math::

Y \sim& \operatorname{N}(0, \sigma^2) \\
\sigma^2 \sim& \operatorname{InvGamma}(0.5, 0.5)

If we create a value variable for ``Y_rv``, i.e. ``y = at.scalar("y")``,
the graph of ``joint_logprob(Y_rv, {Y_rv: y})`` is equivalent to the
conditional probability :math:`\log p(Y = y \mid \sigma^2)`. If we specify
a value variable for ``sigma2_rv``, i.e. ``s = at.scalar("s2")``, then
``joint_logprob(Y_rv, {Y_rv: y, sigma2_rv: s})`` yields the joint
log-probability

.. math::

\log p(Y = y, \sigma^2 = s) =
\log p(Y = y \mid \sigma^2 = s) + \log p(\sigma^2 = s)


Parameters
==========
var
The graph containing the stochastic/`RandomVariable` elements for
which we want to compute a joint log-probability. This graph
effectively represents a statistical model.
rv_values
A ``dict`` of variables that maps stochastic elements
(e.g. `RandomVariable`\s) to symbolic `Variable`\s representing their
values in a log-probability.
warn_missing_rvs
When ``True``, issue a warning when a `RandomVariable` is found in
the graph and doesn't have a corresponding value variable specified in
`rv_values`.

"""
if rv_values is None:
    rv_values = {}

# Since we're going to clone the entire graph, we need to keep a map from
# the old nodes to the new ones; otherwise, we won't be able to use
# `rv_values`.
# We start the `dict` with mappings from the value variables to themselves,
# to prevent them from being cloned.
memo = {v: v for v in rv_values.values()}

# We add `ShapeFeature` because it will get rid of references to the old
# `RandomVariable`s that have been lifted; otherwise, it will be difficult
# to give good warnings when an unaccounted-for `RandomVariable` is
# encountered.
fgraph = FunctionGraph(
outputs=[var],
clone=True,
memo=memo,
copy_orphans=False,
features=[ShapeFeature()],
)

# Update `rv_values` so that it uses the new cloned variables
rv_values = {memo[k]: v for k, v in rv_values.items()}

# This `Feature` preserves the relationships between the original
# random variables (i.e. keys in `rv_values`) and the new ones
# produced when `Op`s are lifted through them.
rv_remapper = PreserveRVMappings(rv_values)

fgraph.attach_feature(rv_remapper)

_ = optimize_graph(fgraph, custom_opt=RVSinker())

# This is the updated random-to-value-vars map with the
# lifted variables
lifted_rv_values = rv_remapper.rv_values
replacements = lifted_rv_values.copy()

# Walk the graph from its inputs to its outputs and construct the
# log-probability
q = deque(fgraph.toposort())

logprob_var = None

while q:
node = q.popleft()

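# Nodes with no associated value variable contribute no log-probability
# term; warn when such a node is itself an unvalued `RandomVariable`,
# then skip it.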
if not any(o in lifted_rv_values for o in node.outputs):
if isinstance(node.op, RandomVariable) and warn_missing_rvs:
warnings.warn(
"Found a random variable that was neither among the observations "
f"nor the conditioned variables: {node}"
)
continue

if isinstance(node.op, RandomVariable):
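# A `RandomVariable` node's outputs are the updated RNG state and the
# drawn value; the log-probability is computed for the drawn value.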
q_rv_var = node.outputs[1]
q_rv_value_var = replacements[q_rv_var]

# Replace `RandomVariable`s in the inputs with value variables.
# Also, store the results in the `replacements` map so that we
# don't need to redo these replacements.
value_var_inputs, _ = rvs_to_value_vars(
node.inputs,
initial_replacements=replacements,
)

q_logprob_var = _logprob(
node.op, q_rv_value_var, *value_var_inputs, **kwargs
)

else:
raise NotImplementedError(
f"A measure/probability could not be derived for {node}"
)

if logprob_var is None:
logprob_var = q_logprob_var
else:
logprob_var += q_logprob_var

# Recompute test values for the changes introduced by the replacements
# above.
if config.compute_test_value != "off":
for node in io_toposort(graph_inputs((logprob_var,)), (logprob_var,)):
compute_test_value(node)

return logprob_var