Description
In the following example, calling `x.shape` simply returns `TensorConstant(100)`, which means the `clone_replace` will have no effect.
```python
import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(100,))
y = pt.zeros(x.shape[-1])
new_y = pytensor.clone_replace(y, replace={x: pt.zeros((50,))}, rebuild_strict=False)

print(x.shape)  # TensorConstant{(1,) of 100}
print(y.type)  # TensorType(float64, (100,))
print(new_y.type)  # TensorType(float64, (100,))
print(new_y.eval().shape)  # (100,)
```
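To make the no-op explicit, continuing from the snippet above, one can check that `x` does not even appear in the graph of `y` (a minimal sketch; it assumes the `ancestors` helper from `pytensor.graph.basic`):

```python
from pytensor.graph.basic import ancestors

# Because x.shape[-1] was eagerly turned into a constant, x is not part of
# y's graph at all, so clone_replace has nothing to substitute.
print(x in set(ancestors([y])))  # expected: False
```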
Instead, if we add a single useless `None` dimension, `Shape` will return an object that is still linked to the original variable, and the replacement works as intended (except for the issue discussed in #253).
```python
x = pt.tensor("x", shape=(None, 100,))
y = pt.zeros(x.shape[-1])
new_y = pytensor.clone_replace(y, replace={x: pt.zeros((2, 50,))}, rebuild_strict=False)

print(x.shape)  # Shape.0
print(y.type)  # TensorType(float64, (100,))
print(new_y.type)  # TensorType(float64, (100,))
print(new_y.eval().shape)  # (50,)
```
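The same check as above (again assuming `ancestors` from `pytensor.graph.basic`) shows that the graph link is preserved in this case:

```python
from pytensor.graph.basic import ancestors

# Here x.shape is a symbolic Shape node, so x remains an ancestor of y
# and clone_replace can actually substitute it.
print(x in set(ancestors([y])))  # expected: True
```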
I suggest we don't "sever" the graph dependency between shapes and the original variables until the compilation stage.
This will allow us to keep manipulating graphs, even as our static type inference continues to improve.
Otherwise we will lose much of PyTensor's graph-manipulation appeal.
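As an illustration of the suggestion (a hypothetical sketch, not the current API: `lazy_shape` is a made-up helper that just applies the `Shape` Op directly instead of taking the eager constant shortcut), the replacement works even for fully static shapes when the graph link is kept:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.shape import Shape


def lazy_shape(x):
    # Hypothetical helper: always build a symbolic Shape node linked to x,
    # leaving constant folding to the compilation-stage rewrites.
    return Shape()(x)


x = pt.tensor("x", shape=(100,))
y = pt.zeros(lazy_shape(x)[-1])
new_y = pytensor.clone_replace(y, replace={x: pt.zeros((50,))}, rebuild_strict=False)
print(new_y.eval().shape)  # expected: (50,), since x is still in y's graph
```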
The relevant logic is here:

`pytensor/pytensor/tensor/shape.py`, lines 144 to 156 (at 8606498)

and in the quite redundant:

`pytensor/pytensor/tensor/var.py`, lines 261 to 265 (at 8606498)
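For context, the behaviour those lines produce, paraphrased from the examples above (an assumption only, `shape_of` is a made-up name and this is not the actual `shape.py`/`var.py` code), is roughly:

```python
import pytensor.tensor as pt
from pytensor.tensor.shape import Shape


def shape_of(x):
    # Rough paraphrase (assumption), not the actual library code:
    # when every dimension is statically known, an eager constant is returned,
    # which severs the graph link to x; otherwise a symbolic Shape node is built.
    if all(dim is not None for dim in x.type.shape):
        return pt.as_tensor_variable(list(x.type.shape))
    return Shape()(x)
```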