Closed
Description
Describe the issue:
Under some circumstances, when `set_subtensor` and `scan` are involved, one can run into an `AssertionError` during graph rewriting:
Reproducible code example:
import pymc as pm
import pytensor
import pytensor.tensor as pt
def _scan_fn(dt, k, Stminus1_Ptminus1, S0):
    """Advance the (S, P) state by one scan step.

    ``pytensor.scan`` calls this once per step with, in order, the current
    elements of the sequences (``dt``, ``k``), the previous output
    (``Stminus1_Ptminus1``) and the non-sequence ``S0``.

    Parameters
    ----------
    dt
        Time increment for this step.
    k
        Rate for this step.
    Stminus1_Ptminus1
        Previous state; element 0 is S_{t-1}, element 1 is P_{t-1}.
    S0
        Constant that S is measured against (S_t = S0 - P_t).

    Returns
    -------
    The new state ``[S_t, P_t]`` stacked along a new leading axis.

    Note that this implementation can broadcast.
    """
    prev_s, prev_p = Stminus1_Ptminus1[0], Stminus1_Ptminus1[1]
    # Portion of the previous S that is transferred into P over ``dt``.
    transferred = prev_s * (1 - pt.exp(-k * dt))
    new_p = prev_p + transferred
    return pt.stack([S0 - new_p, new_p])
def test_issue():
    """Reproduce the ``local_IncSubtensor_serialize`` AssertionError.

    Builds a PyMC model whose 3x3 time matrix is assembled with repeated
    ``pt.set_subtensor`` calls from fancy-indexed entries of a random
    variable. The matrix's step differences are then fed into a
    ``pytensor.scan``. Sampling compiles the model, and the failure is
    reported during graph rewriting at that point.

    NOTE(review): per the report, this reproduction is brittle. Small edits
    (e.g. replacing ``time_delay`` by 0 or commenting out an ``assert``
    line) reportedly change what happens, so keep the statements and their
    order exactly as they are.
    """
    with pm.Model() as pmodel:
        # Nine fixed time points; shifted below by a single learned delay.
        time = pt.as_tensor([2., 2., 5., 5., 4., 4., 2., 4., 5.])
        time_delay = pm.HalfNormal("time_delay")
        time_actual = pm.Deterministic("time_actual", time + time_delay)
        # Fill a 3x3 matrix row by row from advanced-indexed reads of the
        # length-9 ``time_actual``. Presumably the gradient of these reads is
        # what produces the ``AdvancedIncSubtensor`` nodes named in the error
        # (the failing node adds two of them with a (9,) type) - verify.
        matrix = pt.zeros((3, 3))
        matrix = pt.set_subtensor(matrix[0, : 3], time_actual[[6, 5, 2]])
        matrix = pt.set_subtensor(matrix[1, : 3], time_actual[[1, 7, 8]])
        matrix = pt.set_subtensor(matrix[2, : 3], time_actual[[0, 4, 3]])
        mat_time = matrix
        time = pt.as_tensor(mat_time, dtype="float64").T
        # Per the report, removing this eval-based check changes the outcome.
        assert time.eval().shape == (3, 3)
        # Prepend a zero row and difference along axis 0 to get per-step time
        # increments (one row per scan step).
        t0 = pt.zeros([1, 3])
        t0time = pt.concatenate([t0, time], axis=0)
        Dt = pt.diff(t0time, axis=0)
        assert Dt.eval().shape == (3, 3)
        # Initial scan state: row 0 is S, row 1 is P (the order that
        # ``_scan_fn`` reads and writes).
        y0 = pt.as_tensor([
            [1, 1, 1],
            [0, 0, 0],
        ], dtype="float64")
        S0 = pt.as_tensor(1, dtype="float64")
        # Fixed per-step rates, transposed so that axis 0 indexes scan steps.
        k = pt.as_tensor([
            [2.12733265, 2.12034903, 1.08119962],
            [1.05894315, 1.39664043, 0.63348961],
            [0.79245487, 2.74362656, 1.0872284 ],
        ], dtype="float64").T
        # Shapes of all scan inputs
        assert y0.eval().shape == (2, 3)
        assert S0.eval().shape == ()
        assert k.eval().shape == (3, 3)
        outputs, _ = pytensor.scan(
            fn=_scan_fn,
            sequences=[Dt, k],
            outputs_info={
                "initial": y0,
                "taps": [-1],
            },
            non_sequences=[S0],
            n_steps=3,
            strict=True,
        )
        # Select the P component of every step and flatten it. Presumably
        # outputs is (3, 2, 3), giving 9 values that match the 9 observations.
        P = outputs[:, 1].flatten()
        obs = pt.as_tensor([3, 3, 4, 4, 3, 3, 3, 3, 4], dtype="float64")
        pm.Normal("L", P, 1, observed=obs)
    with pmodel:
        # Tiny tune/draw counts: sampling is only needed to trigger
        # compilation (and thus graph rewriting), not for usable results.
        idata = pm.sample(
            chains=2,
            tune=3,
            draws=5,
            compute_convergence_checks=False,
        )
if __name__ == "__main__":
    # Allow running the reproduction directly as a script, not only via a
    # test runner collecting ``test_issue``.
    test_issue()
Error message:
# stack up the new incsubtensors
tip = new_add
for mi in movable_inputs:
assert o_type.is_super(tip.type)
> assert mi.owner.inputs[0].type.is_super(tip.type), f"{mi.owner.inputs[0].type} of {mi.owner.inputs[0]} is not super of {tip.type} of {tip}"
E AssertionError: TensorType(float64, (9,)) of Elemwise{second,no_inplace}.0 is not super of TensorType(float64, (?,)) of AdvancedIncSubtensor{inplace=False, set_instead_of_inc=False}.0
pytensor\tensor\rewriting\subtensor.py:1197: AssertionError
---------------------------------------------------------------------------------------------------------------------------------- Captured log call ----------------------------------------------------------------------------------------------------------------------------------
WARNING pymc:mcmc.py:433 Only 5 samples in chain.
ERROR pytensor.graph.rewriting.basic:basic.py:1768 Rewrite failure due to: local_IncSubtensor_serialize
ERROR pytensor.graph.rewriting.basic:basic.py:1769 node: Elemwise{add,no_inplace}(AdvancedIncSubtensor{inplace=False, set_instead_of_inc=False}.0, AdvancedIncSubtensor{inplace=False, set_instead_of_inc=False}.0)
ERROR pytensor.graph.rewriting.basic:basic.py:1770 TRACEBACK:
ERROR pytensor.graph.rewriting.basic:basic.py:1771 Traceback (most recent call last):
File "E:\Source\Repos\pytensor\pytensor\graph\rewriting\basic.py", line 1933, in process_node
replacements = node_rewriter.transform(fgraph, node)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "E:\Source\Repos\pytensor\pytensor\graph\rewriting\basic.py", line 1092, in transform
return self.fn(fgraph, node)
^^^^^^^^^^^^^^^^^^^^^
File "E:\Source\Repos\pytensor\pytensor\tensor\rewriting\subtensor.py", line 1197, in local_IncSubtensor_serialize
assert mi.owner.inputs[0].type.is_super(tip.type), f"{mi.owner.inputs[0].type} of {mi.owner.inputs[0]} is not super of {tip.type} of {tip}"
AssertionError: TensorType(float64, (9,)) of Elemwise{second,no_inplace}.0 is not super of TensorType(float64, (?,)) of AdvancedIncSubtensor{inplace=False, set_instead_of_inc=False}.0
PyTensor version information:
2.8.11
Context for the issue:
@fonnesbeck reported the same failure in Slack.
I managed to reduce my model to a (still complicated) MRE that triggers the issue.
The MRE itself is actually quite brittle:
- Replacing `time_delay` by `0` creates a compilation error.
- `P.eval()` creates a compilation error.
- Replacing `time_actual` by a 3x3 `pm.HalfNormal` creates a compilation error.
- Commenting out the `assert` line creates a compilation error.
If anybody can come up with a more compact MRE, that would be great!