13 | 13 | # limitations under the License.
14 | 14 |
15 | 15 | import functools
   | 16 | +import sys
16 | 17 | import threading
17 | 18 | import types
18 | 19 | import warnings

24 | 25 |     Callable,
25 | 26 |     Dict,
26 | 27 |     List,
   | 28 | +    Literal,
27 | 29 |     Optional,
28 | 30 |     Sequence,
29 | 31 |     Tuple,

39 | 41 | import pytensor.tensor as pt
40 | 42 | import scipy.sparse as sps
41 | 43 |
   | 44 | +from pytensor.compile import DeepCopyOp, get_mode
42 | 45 | from pytensor.compile.sharedvalue import SharedVariable
43 | 46 | from pytensor.graph.basic import Constant, Variable, graph_inputs
44 | 47 | from pytensor.graph.fg import FunctionGraph
45 | 48 | from pytensor.scalar import Cast
46 | 49 | from pytensor.tensor.elemwise import Elemwise
47 | 50 | from pytensor.tensor.random.op import RandomVariable
48 | 51 | from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
   | 52 | +from pytensor.tensor.random.type import RandomType
49 | 53 | from pytensor.tensor.sharedvar import ScalarSharedVariable
50 | 54 | from pytensor.tensor.var import TensorConstant, TensorVariable
51 | 55 |

61 | 65 | )
62 | 66 | from pymc.initial_point import make_initial_point_fn
63 | 67 | from pymc.logprob.basic import joint_logp
   | 68 | +from pymc.logprob.utils import ParameterValueError
64 | 69 | from pymc.pytensorf import (
65 | 70 |     PointFunc,
66 | 71 |     SeedSequenceSeed,

@@ -1779,7 +1784,8 @@ def check_start_vals(self, start):
1779 | 1784 |                 raise SamplingError(
1780 | 1785 |                     "Initial evaluation of model at starting point failed!\n"
1781 | 1786 |                     f"Starting values:\n{elem}\n\n"
1782 |      | -                    f"Logp initial evaluation results:\n{initial_eval}"
     | 1787 | +                    f"Logp initial evaluation results:\n{initial_eval}\n"
     | 1788 | +                    "You can call `model.debug()` for more details."
1783 | 1789 |                 )
1784 | 1790 |
1785 | 1791 |     def point_logps(self, point=None, round_vals=2):

@@ -1811,6 +1817,152 @@ def point_logps(self, point=None, round_vals=2):
1811 | 1817 |             )
1812 | 1818 |         }
1813 | 1819 |
     | 1820 | +    def debug(
     | 1821 | +        self,
     | 1822 | +        point: Optional[Dict[str, np.ndarray]] = None,
     | 1823 | +        fn: Literal["logp", "dlogp", "random"] = "logp",
     | 1824 | +        verbose: bool = False,
     | 1825 | +    ):
     | 1826 | +        """Debug model function at point.
     | 1827 | +
     | 1828 | +        The method evaluates `fn` for each model variable, one at a time.
     | 1829 | +        When an evaluation fails or produces a non-finite value, it prints:
     | 1830 | +        1. The graph of the parameters
     | 1831 | +        2. The value of the parameters (if they can be evaluated)
     | 1832 | +        3. The output of `fn` (if it can be evaluated)
     | 1833 | +
     | 1834 | +        This function should help to quickly narrow down invalid parametrizations.
     | 1835 | +
     | 1836 | +        Parameters
     | 1837 | +        ----------
     | 1838 | +        point : Point
     | 1839 | +            Point at which the model function should be evaluated; defaults to the model's initial point.
     | 1840 | +        fn : str, default "logp"
     | 1841 | +            Function to be used for debugging. Must be one of "logp", "dlogp" or "random".
     | 1842 | +        verbose : bool, default False
     | 1843 | +            Whether to show a more verbose PyTensor output when the function cannot be evaluated.
     | 1844 | +        """
     | 1845 | +        print_ = functools.partial(print, file=sys.stdout)
     | 1846 | +
     | 1847 | +        def first_line(exc):
     | 1848 | +            return exc.args[0].split("\n")[0]
     | 1849 | +
     | 1850 | +        def debug_parameters(rv):
     | 1851 | +            if isinstance(rv.owner.op, RandomVariable):
     | 1852 | +                inputs = rv.owner.inputs[3:]  # skip the rng, size and dtype inputs
     | 1853 | +            else:
     | 1854 | +                inputs = [inp for inp in rv.owner.inputs if not isinstance(inp.type, RandomType)]
     | 1855 | +            rv_inputs = pytensor.function(
     | 1856 | +                self.value_vars,
     | 1857 | +                self.replace_rvs_by_values(inputs),
     | 1858 | +                on_unused_input="ignore",
     | 1859 | +                mode=get_mode(None).excluding("inplace", "fusion"),
     | 1860 | +            )
     | 1861 | +
     | 1862 | +            print_(f"The variable {rv} has the following parameters:")
     | 1863 | +            # done and used_ids are used to keep the same ids across distinct dprint calls
     | 1864 | +            done = {}
     | 1865 | +            used_ids = {}
     | 1866 | +            for i, out in enumerate(rv_inputs.maker.fgraph.outputs):
     | 1867 | +                print_(f"{i}: ", end="")
     | 1868 | +                # Don't print useless deepcopys
     | 1869 | +                if out.owner and isinstance(out.owner.op, DeepCopyOp):
     | 1870 | +                    out = out.owner.inputs[0]
     | 1871 | +                pytensor.dprint(out, print_type=True, done=done, used_ids=used_ids)
     | 1872 | +
     | 1873 | +            try:
     | 1874 | +                print_("The parameters evaluate to:")
     | 1875 | +                for i, rv_input_eval in enumerate(rv_inputs(**point)):
     | 1876 | +                    print_(f"{i}: {rv_input_eval}")
     | 1877 | +            except Exception as exc:
     | 1878 | +                print_(
     | 1879 | +                    f"The parameters of the variable {rv} cannot be evaluated: {first_line(exc)}"
     | 1880 | +                )
     | 1881 | +                if verbose:
     | 1882 | +                    print_(exc, "\n")
     | 1883 | +
     | 1884 | +        if fn not in ("logp", "dlogp", "random"):
     | 1885 | +            raise ValueError(f"fn must be one of [logp, dlogp, random], got {fn}")
     | 1886 | +
     | 1887 | +        if point is None:
     | 1888 | +            point = self.initial_point()
     | 1889 | +        print_(f"point={point}\n")
     | 1890 | +
     | 1891 | +        rvs_to_check = list(self.basic_RVs)
     | 1892 | +        if fn in ("logp", "dlogp"):
     | 1893 | +            rvs_to_check += [self.replace_rvs_by_values(p) for p in self.potentials]
     | 1894 | +
     | 1895 | +        found_problem = False
     | 1896 | +        for rv in rvs_to_check:
     | 1897 | +            if fn == "logp":
     | 1898 | +                rv_fn = pytensor.function(
     | 1899 | +                    self.value_vars, self.logp(vars=rv, sum=False)[0], on_unused_input="ignore"
     | 1900 | +                )
     | 1901 | +            elif fn == "dlogp":
     | 1902 | +                rv_fn = pytensor.function(
     | 1903 | +                    self.value_vars, self.dlogp(vars=rv), on_unused_input="ignore"
     | 1904 | +                )
     | 1905 | +            else:
     | 1906 | +                [rv_inputs_replaced] = replace_rvs_by_values(
     | 1907 | +                    [rv],
     | 1908 | +                    # Don't include itself, or the function would just return the value variable
     | 1909 | +                    rvs_to_values={
     | 1910 | +                        rv_key: value
     | 1911 | +                        for rv_key, value in self.rvs_to_values.items()
     | 1912 | +                        if rv_key is not rv
     | 1913 | +                    },
     | 1914 | +                    rvs_to_transforms=self.rvs_to_transforms,
     | 1915 | +                )
     | 1916 | +                rv_fn = pytensor.function(
     | 1917 | +                    self.value_vars, rv_inputs_replaced, on_unused_input="ignore"
     | 1918 | +                )
     | 1919 | +
     | 1920 | +            try:
     | 1921 | +                rv_fn_eval = rv_fn(**point)
     | 1922 | +            except ParameterValueError as exc:
     | 1923 | +                found_problem = True
     | 1924 | +                debug_parameters(rv)
     | 1925 | +                print_(
     | 1926 | +                    f"This does not respect one of the following constraints: {first_line(exc)}\n"
     | 1927 | +                )
     | 1928 | +                if verbose:
     | 1929 | +                    print_(exc)
     | 1930 | +            except Exception as exc:
     | 1931 | +                found_problem = True
     | 1932 | +                debug_parameters(rv)
     | 1933 | +                print_(
     | 1934 | +                    f"The {fn} method of variable {rv} raised the following exception: {first_line(exc)}\n"
     | 1935 | +                )
     | 1936 | +                if verbose:
     | 1937 | +                    print_(exc)
     | 1938 | +            else:
     | 1939 | +                if not np.all(np.isfinite(rv_fn_eval)):
     | 1940 | +                    found_problem = True
     | 1941 | +                    debug_parameters(rv)
     | 1942 | +                    if fn == "random" or rv in self.potentials:
     | 1943 | +                        print_("This combination seems able to generate non-finite values")
     | 1944 | +                    else:
     | 1945 | +                        # Find which values are associated with the non-finite evaluation
     | 1946 | +                        values = self.rvs_to_values[rv]
     | 1947 | +                        if rv in self.observed_RVs:
     | 1948 | +                            values = values.eval()
     | 1949 | +                        else:
     | 1950 | +                            values = point[values.name]
     | 1951 | +
     | 1952 | +                        observed = " observed " if rv in self.observed_RVs else " "
     | 1953 | +                        print_(
     | 1954 | +                            f"Some of the{observed}values of variable {rv} are associated with a non-finite {fn}:"
     | 1955 | +                        )
     | 1956 | +                        mask = ~np.isfinite(rv_fn_eval)
     | 1957 | +                        for value, fn_eval in zip(values[mask], rv_fn_eval[mask]):
     | 1958 | +                            print_(f" value = {value} -> {fn} = {fn_eval}")
     | 1959 | +                        print_()
     | 1960 | +
     | 1961 | +        if not found_problem:
     | 1962 | +            print_("No problems found")
     | 1963 | +        elif not verbose:
     | 1964 | +            print_("You can set `verbose=True` for more details")
     | 1965 | +
1814 | 1966 |
1815 | 1967 | # this is really disgusting, but it breaks a self-loop: I can't pass Model
1816 | 1968 | # itself as context class init arg.
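
Aside (not part of the diff): a hypothetical usage sketch of the new `Model.debug` method. The negative sigma below should violate the `sigma > 0` constraint, so the logp evaluation raises a `ParameterValueError` that `debug` catches and reports along with the parameter graph and values:

    import pymc as pm

    with pm.Model() as model:
        x = pm.Normal("x")
        y = pm.Normal("y", mu=x, sigma=-1)  # invalid: sigma must be positive

    model.debug()                            # evaluates each variable's logp at the initial point
    model.debug(fn="random", verbose=True)   # checks forward sampling instead, with full error output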