Skip to content

Commit 6404805

Browse files
committed
Add model debug helper
1 parent 4c64eb9 commit 6404805

File tree

2 files changed

+225
-1
lines changed

2 files changed

+225
-1
lines changed

pymc/model.py

Lines changed: 153 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import functools
16+
import sys
1617
import threading
1718
import types
1819
import warnings
@@ -24,6 +25,7 @@
2425
Callable,
2526
Dict,
2627
List,
28+
Literal,
2729
Optional,
2830
Sequence,
2931
Tuple,
@@ -39,13 +41,15 @@
3941
import pytensor.tensor as pt
4042
import scipy.sparse as sps
4143

44+
from pytensor.compile import DeepCopyOp, get_mode
4245
from pytensor.compile.sharedvalue import SharedVariable
4346
from pytensor.graph.basic import Constant, Variable, graph_inputs
4447
from pytensor.graph.fg import FunctionGraph
4548
from pytensor.scalar import Cast
4649
from pytensor.tensor.elemwise import Elemwise
4750
from pytensor.tensor.random.op import RandomVariable
4851
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
52+
from pytensor.tensor.random.type import RandomType
4953
from pytensor.tensor.sharedvar import ScalarSharedVariable
5054
from pytensor.tensor.var import TensorConstant, TensorVariable
5155

@@ -61,6 +65,7 @@
6165
)
6266
from pymc.initial_point import make_initial_point_fn
6367
from pymc.logprob.basic import joint_logp
68+
from pymc.logprob.utils import ParameterValueError
6469
from pymc.pytensorf import (
6570
PointFunc,
6671
SeedSequenceSeed,
@@ -1779,7 +1784,8 @@ def check_start_vals(self, start):
17791784
raise SamplingError(
17801785
"Initial evaluation of model at starting point failed!\n"
17811786
f"Starting values:\n{elem}\n\n"
1782-
f"Logp initial evaluation results:\n{initial_eval}"
1787+
f"Logp initial evaluation results:\n{initial_eval}\n"
1788+
"You can call `model.debug()` for more details."
17831789
)
17841790

17851791
def point_logps(self, point=None, round_vals=2):
@@ -1811,6 +1817,152 @@ def point_logps(self, point=None, round_vals=2):
18111817
)
18121818
}
18131819

1820+
def debug(
    self,
    point: Optional[Dict[str, np.ndarray]] = None,
    fn: Literal["logp", "dlogp", "random"] = "logp",
    verbose: bool = False,
) -> None:
    """Debug model function at point.

    The method will evaluate the `fn` for each variable at a time.
    When an evaluation fails or produces a non-finite value we print:
    1. The graph of the parameters
    2. The value of the parameters (if those can be evaluated)
    3. The output of `fn` (if it can be evaluated)

    This function should help to quickly narrow down invalid parametrizations.

    Parameters
    ----------
    point : Point
        Point at which model function should be evaluated
    fn : str, default "logp"
        Function to be used for debugging. Can be one of [logp, dlogp, random].
    verbose : bool, default False
        Whether to show a more verbose PyTensor output when function cannot be evaluated
    """
    print_ = functools.partial(print, file=sys.stdout)

    def first_line(exc):
        # Exceptions raised by PyTensor can be very long; only the first line
        # carries the human-readable summary.
        return exc.args[0].split("\n")[0]

    def debug_parameters(rv):
        # Extract the distribution parameters, skipping the rng/size/dtype
        # inputs of a RandomVariable and any RandomType inputs otherwise.
        if isinstance(rv.owner.op, RandomVariable):
            inputs = rv.owner.inputs[3:]
        else:
            inputs = [inp for inp in rv.owner.inputs if not isinstance(inp.type, RandomType)]
        rv_inputs = pytensor.function(
            self.value_vars,
            self.replace_rvs_by_values(inputs),
            on_unused_input="ignore",
            # Keep the graph close to the user-specified one for readable dprint
            mode=get_mode(None).excluding("inplace", "fusion"),
        )

        print_(f"The variable {rv} has the following parameters:")
        # done and used_ids are used to keep the same ids across distinct dprint calls
        done = {}
        used_ids = {}
        for i, out in enumerate(rv_inputs.maker.fgraph.outputs):
            print_(f"{i}: ", end="")
            # Don't print useless deepcopys
            if out.owner and isinstance(out.owner.op, DeepCopyOp):
                out = out.owner.inputs[0]
            pytensor.dprint(out, print_type=True, done=done, used_ids=used_ids)

        try:
            print_("The parameters evaluate to:")
            for i, rv_input_eval in enumerate(rv_inputs(**point)):
                print_(f"{i}: {rv_input_eval}")
        except Exception as exc:
            print_(
                f"The parameters of the variable {rv} cannot be evaluated: {first_line(exc)}"
            )
            if verbose:
                print_(exc, "\n")

    if fn not in ("logp", "dlogp", "random"):
        raise ValueError(f"fn must be one of [logp, dlogp, random], got {fn}")

    if point is None:
        point = self.initial_point()
    print_(f"point={point}\n")

    rvs_to_check = list(self.basic_RVs)
    if fn in ("logp", "dlogp"):
        rvs_to_check += [self.replace_rvs_by_values(p) for p in self.potentials]

    found_problem = False
    for rv in rvs_to_check:
        if fn == "logp":
            rv_fn = pytensor.function(
                self.value_vars, self.logp(vars=rv, sum=False)[0], on_unused_input="ignore"
            )
        elif fn == "dlogp":
            rv_fn = pytensor.function(
                self.value_vars, self.dlogp(vars=rv), on_unused_input="ignore"
            )
        else:
            [rv_inputs_replaced] = replace_rvs_by_values(
                [rv],
                # Don't include itself, or the function will just use the value variable
                rvs_to_values={
                    rv_key: value
                    for rv_key, value in self.rvs_to_values.items()
                    if rv_key is not rv
                },
                rvs_to_transforms=self.rvs_to_transforms,
            )
            rv_fn = pytensor.function(
                self.value_vars, rv_inputs_replaced, on_unused_input="ignore"
            )

        try:
            rv_fn_eval = rv_fn(**point)
        except ParameterValueError as exc:
            found_problem = True
            debug_parameters(rv)
            print_(
                f"This does not respect one of the following constraints: {first_line(exc)}\n"
            )
            if verbose:
                print_(exc)
        except Exception as exc:
            found_problem = True
            debug_parameters(rv)
            print_(
                f"The variable {rv} {fn} method raised the following exception: {first_line(exc)}\n"
            )
            if verbose:
                print_(exc)
        else:
            if not np.all(np.isfinite(rv_fn_eval)):
                found_problem = True
                debug_parameters(rv)
                # BUGFIX: was `rv is self.potentials`, which compared a variable
                # to the *list* of potentials and was therefore always False,
                # sending potentials into the branch below where
                # `self.rvs_to_values[rv]` raises a KeyError.
                if fn == "random" or rv in self.potentials:
                    print_("This combination seems able to generate non-finite values")
                else:
                    # Find which values are associated with non-finite evaluation
                    values = self.rvs_to_values[rv]
                    if rv in self.observed_RVs:
                        values = values.eval()
                    else:
                        values = point[values.name]

                    observed = " observed " if rv in self.observed_RVs else " "
                    print_(
                        f"Some of the{observed}values of variable {rv} are associated with a non-finite {fn}:"
                    )
                    mask = ~np.isfinite(rv_fn_eval)
                    for value, fn_eval in zip(values[mask], rv_fn_eval[mask]):
                        print_(f" value = {value} -> {fn} = {fn_eval}")
                    print_()

    if not found_problem:
        print_("No problems found")
    elif not verbose:
        print_("You can set `verbose=True` for more details")
1965+
18141966

18151967
# this is really disgusting, but it breaks a self-loop: I can't pass Model
18161968
# itself as context class init arg.

tests/test_model.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import scipy.stats as st
3131

3232
from pytensor.graph import graph_inputs
33+
from pytensor.raise_op import Assert, assert_op
3334
from pytensor.tensor import TensorVariable
3435
from pytensor.tensor.random.op import RandomVariable
3536
from pytensor.tensor.sharedvar import ScalarSharedVariable
@@ -1553,3 +1554,74 @@ def test_tag_future_warning_model():
15531554
assert y_value.eval() == 5
15541555

15551556
assert isinstance(y_value.tag, _FutureWarningValidatingScratchpad)
1557+
1558+
1559+
class TestModelDebug:
    """Tests for `Model.debug`, checking its printed diagnostics via capfd."""

    @pytest.mark.parametrize("fn", ("logp", "dlogp", "random"))
    def test_no_problems(self, fn, capfd):
        with pm.Model() as m:
            x = pm.Normal("x", [1, -1, 1])
        m.debug(fn=fn)

        out, _ = capfd.readouterr()
        assert out == "point={'x': array([ 1., -1., 1.])}\n\nNo problems found\n"

    @pytest.mark.parametrize("fn", ("logp", "dlogp", "random"))
    def test_invalid_parameter(self, fn, capfd):
        with pm.Model() as m:
            x = pm.Normal("x", [1, -1, 1])
            y = pm.HalfNormal("y", tau=x)
        m.debug(fn=fn)

        out, _ = capfd.readouterr()
        if fn == "dlogp":
            # var dlogp is 0 or 1 without a likelihood
            assert "No problems found" in out
        else:
            assert "The parameters evaluate to:\n0: 0.0\n1: [ 1. -1. 1.]" in out
            if fn == "logp":
                assert "This does not respect one of the following constraints: sigma > 0" in out
            else:
                assert (
                    "The variable y random method raised the following exception: Domain error in arguments."
                    in out
                )

    @pytest.mark.parametrize("verbose", (True, False))
    @pytest.mark.parametrize("fn", ("logp", "dlogp", "random"))
    def test_invalid_parameter_cant_be_evaluated(self, fn, verbose, capfd):
        with pm.Model() as m:
            x = pm.Normal("x", [1, 1, 1])
            sigma = Assert(msg="x > 0")(pm.math.abs(x), (x > 0).all())
            y = pm.HalfNormal("y", sigma=sigma)
        m.debug(point={"x": [-1, -1, -1], "y_log__": [0, 0, 0]}, fn=fn, verbose=verbose)

        out, _ = capfd.readouterr()
        assert "{'x': [-1, -1, -1], 'y_log__': [0, 0, 0]}" in out
        assert "The parameters of the variable y cannot be evaluated: x > 0" in out
        verbose_str = "Apply node that caused the error:" in out
        assert verbose_str if verbose else not verbose_str

    def test_invalid_value(self, capfd):
        with pm.Model() as m:
            x = pm.Normal("x", [1, -1, 1])
            y = pm.HalfNormal("y", tau=pm.math.abs(x), initval=[-1, 1, -1], transform=None)
        m.debug()

        out, _ = capfd.readouterr()
        # BUGFIX: this assertion was a bare string literal (always truthy,
        # checked nothing); assert the header `debug` is guaranteed to print.
        assert "The parameters evaluate to:" in out
        assert "Some of the values of variable y are associated with a non-finite logp" in out
        assert "value = -1.0 -> logp = -inf" in out

    def test_invalid_observed_value(self, capfd):
        with pm.Model() as m:
            theta = pm.Uniform("theta", lower=0, upper=1)
            y = pm.Uniform("y", lower=0, upper=theta, observed=[0.49, 0.27, 0.53, 0.19])
        m.debug()

        out, _ = capfd.readouterr()
        # BUGFIX: this assertion was a bare string literal (always truthy,
        # checked nothing); assert the header `debug` is guaranteed to print.
        assert "The parameters evaluate to:" in out
        assert (
            "Some of the observed values of variable y are associated with a non-finite logp" in out
        )
        assert "value = 0.53 -> logp = -inf" in out

0 commit comments

Comments
 (0)