Static type shapes break graph dependency #254

Closed
@ricardoV94

Description

In the following example, calling x.shape simply returns a TensorConstant (the vector [100]), so the clone_replace has no effect: y no longer depends on x, so there is nothing for the replacement to act on.

import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(100,))
y = pt.zeros(x.shape[-1])
new_y = pytensor.clone_replace(y, replace={x: pt.zeros((50,))}, rebuild_strict=False)

print(x.shape)  # TensorConstant{(1,) of 100}
print(y.type)  # TensorType(float64, (100,))
print(new_y.type)  # TensorType(float64, (100,))
print(new_y.eval().shape)  # (100,)

Instead, if we add a single, otherwise useless None dimension, x.shape returns the output of a Shape Op that is still linked to the original variable, and the replacement works as intended (except for the issue discussed in #253).

x = pt.tensor("x", shape=(None, 100,))
y = pt.zeros(x.shape[-1])
new_y = pytensor.clone_replace(y, replace={x: pt.zeros((2, 50,))}, rebuild_strict=False)

print(x.shape)  # Shape.0
print(y.type)  # TensorType(float64, (100,))
print(new_y.type)  # TensorType(float64, (100,))
print(new_y.eval().shape)  # (50,)
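To make the mechanism explicit without depending on PyTensor, here is a toy sketch. Everything in it (Node, shape, zeros_last, clone_replace, evaluate and the op names) is hypothetical and only loosely mirrors the PyTensor calls above; arrays are modelled purely by their shape tuples. The toy shape function copies the eager "all static dimensions known" check: in that case it returns a constant with no inputs, so clone_replace finds nothing to substitute.

```python
# Hypothetical toy graph (not PyTensor's API). Arrays are modelled only by
# their shape tuples, which is all this issue is about.


class Node:
    def __init__(self, op, inputs=(), payload=None):
        self.op = op                 # "input", "array", "const", "shape" or "zeros"
        self.inputs = tuple(inputs)  # parent nodes
        self.payload = payload       # static shape ("input"/"array") or value ("const")


def shape(x):
    # Mirrors the current eager logic: a fully static shape becomes a
    # constant, so the returned node no longer depends on `x`.
    if x.op == "input" and None not in x.payload:
        return Node("const", payload=tuple(x.payload))
    return Node("shape", [x])


def zeros_last(s):
    # Stands in for pt.zeros(s[-1]).
    return Node("zeros", [s])


def clone_replace(node, replace):
    # Rebuild the graph, substituting any node found in `replace`.
    if node in replace:
        return replace[node]
    if not node.inputs:
        return node
    return Node(node.op, [clone_replace(i, replace) for i in node.inputs], node.payload)


def evaluate(node):
    if node.op in ("array", "const"):
        return node.payload
    if node.op == "shape":
        return evaluate(node.inputs[0])  # in this toy, an array's value is its shape
    if node.op == "zeros":
        return (evaluate(node.inputs[0])[-1],)
    raise ValueError(f"cannot evaluate unbound {node.op!r} node")


# Fully static input: the shape is folded and the replacement is ignored.
x = Node("input", payload=(100,))
new_y = clone_replace(zeros_last(shape(x)), {x: Node("array", payload=(50,))})
print(evaluate(new_y))  # (100,)

# One unknown dimension: the shape stays symbolic and the replacement applies.
x2 = Node("input", payload=(None, 100))
new_y2 = clone_replace(zeros_last(shape(x2)), {x2: Node("array", payload=(2, 50))})
print(evaluate(new_y2))  # (50,)
```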

I suggest we don't sever the graph dependency between shapes and their original variables until the compilation stage.
This would let us keep manipulating graphs even as our static type inference continues to improve.

Otherwise, we will lose much of the appeal of PyTensor's graph manipulation.
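The proposal can be illustrated with a toy sketch; every name in it (Node, shape, zeros_last, clone_replace, fold_shapes and the op names) is hypothetical and is not PyTensor's API, and arrays are modelled by their shape tuples only. Here shape always returns a symbolic node, and a separate fold_shapes pass, standing in for whatever compile-time rewrite would do this in practice, replaces a shape node with a constant only when its input's static shape is complete.

```python
# Hypothetical toy graph (not PyTensor's API): shape() always returns a
# symbolic node, and a separate "compile-time" pass folds static shapes.


class Node:
    def __init__(self, op, inputs=(), payload=None):
        self.op = op                 # "input", "array", "const", "shape" or "zeros"
        self.inputs = tuple(inputs)  # parent nodes
        self.payload = payload       # static shape ("input"/"array") or value ("const")


def shape(x):
    return Node("shape", [x])  # never folded while the graph is being built


def zeros_last(s):
    return Node("zeros", [s])  # stands in for pt.zeros(s[-1])


def clone_replace(node, replace):
    # Rebuild the graph, substituting any node found in `replace`.
    if node in replace:
        return replace[node]
    if not node.inputs:
        return node
    return Node(node.op, [clone_replace(i, replace) for i in node.inputs], node.payload)


def fold_shapes(node):
    # Compile-time rewrite: shape(v) -> constant when v's static shape is complete.
    if node.op == "shape":
        v = node.inputs[0]
        if v.op in ("input", "array") and None not in v.payload:
            return Node("const", payload=tuple(v.payload))
    if not node.inputs:
        return node
    return Node(node.op, [fold_shapes(i) for i in node.inputs], node.payload)


x = Node("input", payload=(100,))
y = zeros_last(shape(x))
replacement = Node("array", payload=(50,))
new_y = clone_replace(y, {x: replacement})

print(new_y.inputs[0].inputs[0] is replacement)  # True: the dependency survived
print(fold_shapes(y).inputs[0].payload)          # (100,)
print(fold_shapes(new_y).inputs[0].payload)      # (50,)
```

The static information is not discarded, only used later: folding the original y still yields the constant 100, while folding new_y yields 50.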

The relevant logic is here:

def shape(x: Union[np.ndarray, Number, Variable]) -> Variable:
    """Return the shape of `x`."""
    if not isinstance(x, Variable):
        x = at.as_tensor_variable(x)
    x_type = x.type
    if isinstance(x_type, TensorType) and all(s is not None for s in x_type.shape):
        # Fully static shape: return a constant, which drops the link to `x`
        res = at.as_tensor_variable(x_type.shape, ndim=1, dtype=np.int64)
    else:
        res = _shape(x)
    return res

And in the rather redundant shape property of tensor variables:

def shape(self):
    if not any(s is None for s in self.type.shape):
        # Same constant fold as above, duplicated here
        return as_tensor_variable(self.type.shape, ndim=1, dtype=np.int64)
    return at.shape(self)
