
BUG: local_IncSubtensor_serialize rewriter breaks on inconsistent shape information #160

Closed
@michaelosthege

Description


Describe the issue:

Under some circumstances involving set_subtensor and scan, graph rewriting fails with an AssertionError raised inside the local_IncSubtensor_serialize rewriter:

Reproducible code example:

import pymc as pm
import pytensor
import pytensor.tensor as pt


def _scan_fn(dt, k, Stminus1_Ptminus1, S0):
    """Inner functions for scans.

    Note that this implementation can broadcast.
    """
    Stminus1 = Stminus1_Ptminus1[0]
    Ptminus1 = Stminus1_Ptminus1[1]
    deltaP = Stminus1 * (1 - pt.exp(-k * dt))
    Pt = Ptminus1 + deltaP
    St = S0 - Pt
    return pt.stack([St, Pt])


def test_issue():

    with pm.Model() as pmodel:
        time = pt.as_tensor([2., 2., 5., 5., 4., 4., 2., 4., 5.])
        time_delay = pm.HalfNormal("time_delay")
        time_actual = pm.Deterministic("time_actual", time + time_delay)

        matrix = pt.zeros((3, 3))
        matrix = pt.set_subtensor(matrix[0, : 3], time_actual[[6, 5, 2]])
        matrix = pt.set_subtensor(matrix[1, : 3], time_actual[[1, 7, 8]])
        matrix = pt.set_subtensor(matrix[2, : 3], time_actual[[0, 4, 3]])
        mat_time = matrix
        time = pt.as_tensor(mat_time, dtype="float64").T
        assert time.eval().shape == (3, 3)

        t0 = pt.zeros([1, 3])
        t0time = pt.concatenate([t0, time], axis=0)
        Dt = pt.diff(t0time, axis=0)
        assert Dt.eval().shape == (3, 3)

        y0 = pt.as_tensor([
            [1, 1, 1],
            [0, 0, 0],
        ], dtype="float64")
        S0 = pt.as_tensor(1, dtype="float64")
        k = pt.as_tensor([
            [2.12733265, 2.12034903, 1.08119962],
            [1.05894315, 1.39664043, 0.63348961],
            [0.79245487, 2.74362656, 1.0872284 ],
        ], dtype="float64").T
        # Shapes of all scan inputs
        assert y0.eval().shape == (2, 3)
        assert S0.eval().shape == ()
        assert k.eval().shape == (3, 3)
        outputs, _ = pytensor.scan(
            fn=_scan_fn,
            sequences=[Dt, k],
            outputs_info={
                "initial": y0,
                "taps": [-1],
            },
            non_sequences=[S0],
            n_steps=3,
            strict=True,
        )
        P = outputs[:, 1].flatten()

        obs = pt.as_tensor([3, 3, 4, 4, 3, 3, 3, 3, 4], dtype="float64")
        pm.Normal("L", P, 1, observed=obs)

    with pmodel:
        idata = pm.sample(
            chains=2,
            tune=3,
            draws=5,
            compute_convergence_checks=False,
        )


if __name__ == "__main__":
    test_issue()
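To know what values the model graph should produce, the forward pass of the MRE can be transcribed to plain NumPy. This is a sketch that assumes the usual scan semantics (each sequence is sliced along axis 0, and `taps=[-1]` feeds the previous step's output back in). `reference_P` is a hypothetical helper name, and PyMC's sampling is left out entirely.

```python
import numpy as np

# Same constants as in the MRE above.
TIME = np.array([2., 2., 5., 5., 4., 4., 2., 4., 5.])
K = np.array([
    [2.12733265, 2.12034903, 1.08119962],
    [1.05894315, 1.39664043, 0.63348961],
    [0.79245487, 2.74362656, 1.0872284],
]).T
Y0 = np.array([[1., 1., 1.], [0., 0., 0.]])
S0 = 1.0


def reference_P(time_delay: float) -> np.ndarray:
    """Hypothetical NumPy transcription of the MRE's forward pass."""
    time_actual = TIME + time_delay

    # The three set_subtensor calls each fill one row of the matrix.
    matrix = np.zeros((3, 3))
    matrix[0, :3] = time_actual[[6, 5, 2]]
    matrix[1, :3] = time_actual[[1, 7, 8]]
    matrix[2, :3] = time_actual[[0, 4, 3]]
    time = matrix.T

    # Prepend a row of zeros and take step-to-step differences.
    t0time = np.concatenate([np.zeros((1, 3)), time], axis=0)
    Dt = np.diff(t0time, axis=0)  # shape (3, 3)

    # Loop equivalent of the scan: one step per row of Dt and K.
    state = Y0
    steps = []
    for dt, k in zip(Dt, K):
        S_prev, P_prev = state
        deltaP = S_prev * (1 - np.exp(-k * dt))
        P_t = P_prev + deltaP
        S_t = S0 - P_t
        state = np.stack([S_t, P_t])
        steps.append(state)
    outputs = np.stack(steps)  # shape (n_steps, 2, 3) = (3, 2, 3)
    return outputs[:, 1].flatten()  # shape (9,)
```

With these index lists, all three rows of `matrix` hold the `time` values 2, 4 and 5 (plus `time_delay`), so only the first row of `Dt` depends on `time_delay`.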

Error message:

# stack up the new incsubtensors
                tip = new_add
                for mi in movable_inputs:
                    assert o_type.is_super(tip.type)
>                   assert mi.owner.inputs[0].type.is_super(tip.type), f"{mi.owner.inputs[0].type} of {mi.owner.inputs[0]} is not super of {tip.type} of {tip}"
E                   AssertionError: TensorType(float64, (9,)) of Elemwise{second,no_inplace}.0 is not super of TensorType(float64, (?,)) of AdvancedIncSubtensor{inplace=False,  set_instead_of_inc=False}.0

pytensor\tensor\rewriting\subtensor.py:1197: AssertionError
---------------------------------------------------------------------------------------------------------------------------------- Captured log call ---------------------------------------------------------------------------------------------------------------------------------- 
WARNING  pymc:mcmc.py:433 Only 5 samples in chain.
ERROR    pytensor.graph.rewriting.basic:basic.py:1768 Rewrite failure due to: local_IncSubtensor_serialize
ERROR    pytensor.graph.rewriting.basic:basic.py:1769 node: Elemwise{add,no_inplace}(AdvancedIncSubtensor{inplace=False,  set_instead_of_inc=False}.0, AdvancedIncSubtensor{inplace=False,  set_instead_of_inc=False}.0)
ERROR    pytensor.graph.rewriting.basic:basic.py:1770 TRACEBACK:
ERROR    pytensor.graph.rewriting.basic:basic.py:1771 Traceback (most recent call last):
  File "E:\Source\Repos\pytensor\pytensor\graph\rewriting\basic.py", line 1933, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "E:\Source\Repos\pytensor\pytensor\graph\rewriting\basic.py", line 1092, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "E:\Source\Repos\pytensor\pytensor\tensor\rewriting\subtensor.py", line 1197, in local_IncSubtensor_serialize
    assert mi.owner.inputs[0].type.is_super(tip.type), f"{mi.owner.inputs[0].type} of {mi.owner.inputs[0]} is not super of {tip.type} of {tip}"
AssertionError: TensorType(float64, (9,)) of Elemwise{second,no_inplace}.0 is not super of TensorType(float64, (?,)) of AdvancedIncSubtensor{inplace=False,  set_instead_of_inc=False}.0
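The failing check compares static shapes. According to the log, `tip` at that point is itself an `AdvancedIncSubtensor` output whose length is unknown (`?`). The base of the next increment (`Elemwise{second}.0`) has the static length 9. A type with a concrete length only accepts values of that length, so `(9,)` cannot be a supertype of `(?,)`, although the reverse holds. Below is a simplified pure-Python model of that rule. It assumes dtype and number of dimensions must match and uses `None` for `?`. `StaticType`, `inc` and `serialize_check` are hypothetical names, not PyTensor APIs. The second half uses `np.add.at` as a stand-in for an increment op to illustrate why chaining increments on a shared base gives the same values as adding separate increments. That is the kind of rearrangement the `# stack up the new incsubtensors` loop performs.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np


@dataclass(frozen=True)
class StaticType:
    """Hypothetical stand-in for a tensor type: dtype plus static shape (None = unknown)."""
    dtype: str
    shape: Tuple[Optional[int], ...]

    def is_super(self, other: "StaticType") -> bool:
        # `self` is a supertype if every value valid for `other` is also valid for `self`.
        if self.dtype != other.dtype or len(self.shape) != len(other.shape):
            return False
        return all(s is None or s == o for s, o in zip(self.shape, other.shape))


base_type = StaticType("float64", (9,))    # Elemwise{second,no_inplace}.0 in the log
tip_type = StaticType("float64", (None,))  # AdvancedIncSubtensor{...}.0 in the log


def inc(x: np.ndarray, idx: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Increment x[idx] by y on a copy (stand-in for a non-inplace increment op)."""
    out = x.copy()
    np.add.at(out, idx, y)
    return out


def serialize_check(seed: int = 0) -> bool:
    """Compare add(inc(b, c), inc(b, d)) with the chained inc(inc(add(b, b), c), d)."""
    rng = np.random.default_rng(seed)
    b = rng.normal(size=9)  # shared base of both increments
    c, i = rng.normal(size=3), np.array([6, 5, 2])
    d, j = rng.normal(size=3), np.array([1, 7, 8])
    summed = inc(b, i, c) + inc(b, j, d)
    chained = inc(inc(b + b, i, c), j, d)
    return bool(np.allclose(summed, chained))
```

In this model, `base_type.is_super(tip_type)` is `False` while `tip_type.is_super(base_type)` is `True`, which mirrors the direction of the failing assertion.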

PyTensor version information:

2.8.11

Context for the issue:

@fonnesbeck reported the same failure in Slack.

I managed to reduce my model to a (still complicated) minimal reproducible example (MRE) that triggers the issue.

The MRE itself is actually quite brittle:

  • Replacing time_delay with 0 creates a compilation error.
  • Calling P.eval() creates a compilation error.
  • Replacing time_actual with a 3x3 pm.HalfNormal creates a compilation error.
  • Commenting out the assert line creates a compilation error.

If anybody can come up with a more compact MRE, that would be great!
