diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py
index 5111607a0a..90d054f6ba 100644
--- a/pytensor/graph/basic.py
+++ b/pytensor/graph/basic.py
@@ -266,14 +266,24 @@ def clone_with_new_inputs(
         assert isinstance(inputs, (list, tuple))
         remake_node = False
         new_inputs: List["Variable"] = list(inputs)
+
+        # Some Ops like Alloc require the node to always be rebuilt in non-strict mode,
+        # as the output type depends on the input values and not just their types
+        output_type_depends_on_input_value = self.op._output_type_depends_on_input_value
+
         for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)):
-            if curr.type != new.type:
+            # Check if the input type changed or if the Op has output types that depend on input values
+            if (curr.type != new.type) or output_type_depends_on_input_value:
+                # In strict mode, the cloned graph is assumed to be mathematically equivalent to the original one.
+                # We only need to rebuild a node when the new input has a different, but compatible, type.
+                # This can happen, e.g., when we provide a new input with a more specialized static shape.
                 if strict:
                     new_i = curr.type.filter_variable(new)
                     new_inputs[i] = new_i
                     if curr.type != new_i.type:
                         remake_node = True
+                # Otherwise, we always rebuild the node
                 else:
                     remake_node = True
diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py
index 0cda6f1327..d6785f9246 100644
--- a/pytensor/graph/op.py
+++ b/pytensor/graph/op.py
@@ -207,6 +207,15 @@ class Op(MetaObject):
     otypes: Optional[Sequence["Type"]] = None
     params_type: Optional[ParamsType] = None

+    _output_type_depends_on_input_value = False
+    """
+    Whether the static output type depends on the inferred value of one of the inputs
+    (e.g., via constant folding or static shape inference).
+
+    This information is needed when rebuilding a graph with new inputs,
+    as nodes with these Ops must be rebuilt even if the input types haven't changed.
+    """
+
    def make_node(self, *inputs: Variable) -> Apply:
        """Construct an `Apply` node that represent the application of this
        operation to the given inputs.
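The effect of the flag is easiest to see with `Alloc` (e.g. via `at.zeros`): once the scalar shape values have been extracted, the inputs feeding `Alloc` keep the same *type* even when the shapes they describe change, so only a value-aware rebuild can refresh the static output type. A minimal sketch, mirroring the `test_rebuild` cases added further down in this diff:

```python
import pytensor.tensor as at
from pytensor.graph.replace import clone_replace

x = at.vector(shape=(50,))
y = at.zeros(x.shape)  # Alloc: the (50,) output type is inferred from the *value* of x.shape
assert y.type.shape == (50,)

# After the replacement, the scalar shape input feeding Alloc still has the same
# type (a scalar integer), so without _output_type_depends_on_input_value a
# non-strict clone would keep the stale (50,) output type. With it, the node is
# rebuilt and the output type is re-inferred as (100,).
x_new = at.vector(shape=(100,))
y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
assert y_new.type.shape == (100,)
```

Strict cloning is unchanged: there the clone is assumed to remain mathematically equivalent to the original graph.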
diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index 8f5c9181d0..1aa1ad2f0a 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -1418,6 +1418,8 @@ class Alloc(COp):
     """

     _f16_ok = True
+    _output_type_depends_on_input_value = True
+
     __props__ = ()

     def make_node(self, value, *shape):
@@ -3819,6 +3821,8 @@ def perform(self, node, inputs, outputs):
 class AllocEmpty(COp):
     """Implement Alloc on the cpu, but without initializing memory."""

+    _output_type_depends_on_input_value = True
+
     __props__ = ("dtype",)
     params_type = ParamsType(typecode=int32)

diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py
index ca57fee85c..9a6dbfa365 100644
--- a/pytensor/tensor/extra_ops.py
+++ b/pytensor/tensor/extra_ops.py
@@ -1561,6 +1561,8 @@ def broadcast_shape_iter(
 class BroadcastTo(COp):
     """An `Op` for `numpy.broadcast_to`."""

+    _output_type_depends_on_input_value = True
+
     __props__ = ()

     view_map = {0: [0]}
diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py
index 1e4e44274f..56210538c0 100644
--- a/pytensor/tensor/random/op.py
+++ b/pytensor/tensor/random/op.py
@@ -91,6 +91,8 @@ class RandomVariable(Op):

     """

+    _output_type_depends_on_input_value = True
+
     __props__ = ("name", "ndim_supp", "ndims_params", "dtype", "inplace")
     default_output = 1

diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py
index 2e273a1dce..6ed10a3e79 100644
--- a/pytensor/tensor/shape.py
+++ b/pytensor/tensor/shape.py
@@ -146,14 +146,7 @@ def shape(x: Union[np.ndarray, Number, Variable]) -> Variable:
     if not isinstance(x, Variable):
         x = at.as_tensor_variable(x)

-    x_type = x.type
-
-    if isinstance(x_type, TensorType) and all(s is not None for s in x_type.shape):
-        res = at.as_tensor_variable(x_type.shape, ndim=1, dtype=np.int64)
-    else:
-        res = _shape(x)
-
-    return res
+    return _shape(x)


 @_get_vector_length.register(Shape)
@@ -395,6 +388,7 @@ class SpecifyShape(COp):
     view_map = {0: [0]}
     __props__ = ()
     _f16_ok = True
+    _output_type_depends_on_input_value = True

     def make_node(self, x, *shape):
         from pytensor.tensor.basic import get_underlying_scalar_constant_value
@@ -594,6 +588,7 @@ class Reshape(COp):
     view_map = {0: [0]}  # output 0 is potentially aliased to inputs [0]
     _f16_ok = True
+    _output_type_depends_on_input_value = True

     check_input = False
     __props__ = ("ndim",)
diff --git a/pytensor/tensor/var.py b/pytensor/tensor/var.py
index 7cfd4cef87..bd07bfbab7 100644
--- a/pytensor/tensor/var.py
+++ b/pytensor/tensor/var.py
@@ -12,7 +12,7 @@
 from pytensor.graph.basic import Constant, OptionalApplyType, Variable
 from pytensor.graph.utils import MetaType
 from pytensor.scalar import ComplexError, IntegerDivisionError
-from pytensor.tensor import _get_vector_length, as_tensor_variable
+from pytensor.tensor import _get_vector_length
 from pytensor.tensor.exceptions import AdvancedIndexingError
 from pytensor.tensor.type import TensorType
 from pytensor.tensor.type_other import NoneConst
@@ -259,9 +259,6 @@ def transpose(self, *axes):

     @property
     def shape(self):
-        if not any(s is None for s in self.type.shape):
-            return as_tensor_variable(self.type.shape, ndim=1, dtype=np.int64)
-
         return at.shape(self)

     @property
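With the eager constant-shape shortcut removed from `shape()` and the `TensorVariable.shape` property, `.shape` now always builds a symbolic `Shape` node, even for variables whose static shape is fully known; the literal value is recovered by constant folding when a function is compiled (see the updated `test_fixed_shape_variable_basic` below). A small sketch of the new semantics:

```python
import pytensor
import pytensor.tensor as at
from pytensor.tensor.shape import Shape

x = at.tensor(shape=(4,), dtype="int64")
assert x.type.shape == (4,)                 # static shape info still lives on the type
assert isinstance(x.shape.owner.op, Shape)  # but x.shape is now always symbolic

# The literal (4,) is only materialized by the rewriter at compile time
fn = pytensor.function([x], x.shape)
```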
diff --git a/tests/scan/test_rewriting.py b/tests/scan/test_rewriting.py
index 21b23b3433..8f362b4e50 100644
--- a/tests/scan/test_rewriting.py
+++ b/tests/scan/test_rewriting.py
@@ -1477,13 +1477,12 @@ def test_while_scan_taps_and_map(self):
         f(x0=0, seq=test_seq, n_steps=0)

         # Evaluate the shape of ys_trace and len_zs to confirm the rewrite worked correctly.
-        # If a MissingInputError is raised, it means the rewrite failed
         [scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
         _, _, ys_trace, len_zs = scan_node.inputs
         debug_fn = pytensor.function(
-            [n_steps], [ys_trace.shape[0], len_zs], accept_inplace=True
+            [x0, n_steps], [ys_trace.shape[0], len_zs], accept_inplace=True
         )
-        stored_ys_steps, stored_zs_steps = debug_fn(n_steps=200)
+        stored_ys_steps, stored_zs_steps = debug_fn(x0=0, n_steps=200)
         assert stored_ys_steps == 2
         assert stored_zs_steps == 1

diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py
index 107d861d65..f68357a98a 100644
--- a/tests/tensor/random/test_basic.py
+++ b/tests/tensor/random/test_basic.py
@@ -14,6 +14,7 @@
 from pytensor.graph.basic import Constant, Variable, graph_inputs
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import get_test_value
+from pytensor.graph.replace import clone_replace
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.tensor.random.basic import (
     bernoulli,
@@ -57,7 +58,7 @@
     weibull,
 )
 from pytensor.tensor.rewriting.shape import ShapeFeature
-from pytensor.tensor.type import iscalar, scalar, tensor
+from pytensor.tensor.type import iscalar, scalar, tensor, vector
 from tests.unittest_tools import create_pytensor_param
@@ -1422,3 +1423,19 @@ def test_pickle():
     a_unpkl = pickle.loads(a_pkl)

     assert a_unpkl.owner.op._props() == sample_a.owner.op._props()
+
+
+def test_rebuild():
+    x = vector(shape=(50,))
+    x_test = np.zeros((50,), dtype=config.floatX)
+    y = normal(size=x.shape)
+    assert y.type.shape == (50,)
+    assert y.shape.eval({x: x_test}) == (50,)
+    assert y.eval({x: x_test}).shape == (50,)
+
+    x_new = vector(shape=(100,))
+    x_new_test = np.zeros((100,), dtype=config.floatX)
+    y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
+    assert y_new.type.shape == (100,)
+    assert y_new.shape.eval({x_new: x_new_test}) == (100,)
+    assert y_new.eval({x_new: x_new_test}).shape == (100,)
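`RandomVariable` needs the flag for the same reason as `Alloc`: its static output type is inferred from the value flowing into `size`, whose type (an integer vector) does not change when the implied shape does. The `test_while_scan_taps_and_map` update above follows from the same shape change: `ys_trace.shape[0]` is now a genuinely symbolic graph that reaches back to `x0`, so `x0` must be an input of the debug function. A sketch mirroring the `test_rebuild` test above:

```python
import pytensor.tensor as at
from pytensor.graph.replace import clone_replace
from pytensor.tensor.random.basic import normal

x = at.vector(shape=(50,))
y = normal(size=x.shape)
assert y.type.shape == (50,)  # static shape is read off the value feeding `size`

# `size` keeps the same type (an integer vector) after the replacement, so the
# RandomVariable node is only rebuilt because of the new Op-level flag
x_new = at.vector(shape=(100,))
y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
assert y_new.type.shape == (100,)
```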
diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py
index 0be24733f1..79703fb761 100644
--- a/tests/tensor/test_basic.py
+++ b/tests/tensor/test_basic.py
@@ -16,6 +16,7 @@
 from pytensor.gradient import grad, hessian
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
+from pytensor.graph.replace import clone_replace
 from pytensor.misc.safe_asarray import _asarray
 from pytensor.raise_op import Assert
 from pytensor.scalar import autocast_float, autocast_float_as
@@ -818,6 +819,22 @@ def test_full(self):
         res = pytensor.function([], full_at, mode=self.mode)()
         assert np.array_equal(res, np.full((2, 3), 3, dtype="int64"))

+    @pytest.mark.parametrize("func", (at.zeros, at.empty))
+    def test_rebuild(self, func):
+        x = vector(shape=(50,))
+        x_test = np.zeros((50,), dtype=config.floatX)
+        y = func(x.shape)
+        assert y.type.shape == (50,)
+        assert y.shape.eval({x: x_test}) == (50,)
+        assert y.eval({x: x_test}).shape == (50,)
+
+        x_new = vector(shape=(100,))
+        x_new_test = np.zeros((100,), dtype=config.floatX)
+        y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
+        assert y_new.type.shape == (100,)
+        assert y_new.shape.eval({x_new: x_new_test}) == (100,)
+        assert y_new.eval({x_new: x_new_test}).shape == (100,)
+

 def test_infer_shape():
     with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):
@@ -3506,7 +3523,7 @@ def test_vector(self):
     def test_scalar(self):
         x = scalar()
         y = np.array(7, dtype=config.floatX)
-        assert y.size == function([], x.size)()
+        assert y.size == function([x], x.size)(y)

     def test_shared(self):
         # NB: we also test higher order tensors at the same time.
diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py
index 079688cc68..0723498486 100644
--- a/tests/tensor/test_extra_ops.py
+++ b/tests/tensor/test_extra_ops.py
@@ -9,6 +9,7 @@
 from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Constant, applys_between
+from pytensor.graph.replace import clone_replace
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.raise_op import Assert
 from pytensor.tensor.elemwise import DimShuffle
@@ -1399,6 +1400,22 @@ def test_inplace(self):

         assert advincsub_node.op.inplace is False

+    def test_rebuild(self):
+        x = vector(shape=(50,))
+        x_test = np.zeros((50,), dtype=config.floatX)
+        i = 0
+        y = broadcast_to(i, x.shape)
+        assert y.type.shape == (50,)
+        assert y.shape.eval({x: x_test}) == (50,)
+        assert y.eval({x: x_test}).shape == (50,)
+
+        x_new = vector(shape=(100,))
+        x_new_test = np.zeros((100,), dtype=config.floatX)
+        y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
+        assert y_new.type.shape == (100,)
+        assert y_new.shape.eval({x_new: x_new_test}) == (100,)
+        assert y_new.eval({x_new: x_new_test}).shape == (100,)
+

 def test_broadcast_arrays():
     x, y = at.tensor(shape=(1,), dtype="float64"), at.dmatrix()
diff --git a/tests/tensor/test_nlinalg.py b/tests/tensor/test_nlinalg.py
index 7f8901902a..f34cb5c3cd 100644
--- a/tests/tensor/test_nlinalg.py
+++ b/tests/tensor/test_nlinalg.py
@@ -7,7 +7,6 @@
 import pytensor
 from pytensor import function
 from pytensor.configdefaults import config
-from pytensor.graph.basic import Constant
 from pytensor.tensor.math import _allclose
 from pytensor.tensor.nlinalg import (
     SVD,
@@ -274,9 +273,7 @@ def test_det_grad():

 def test_det_shape():
     x = matrix()
-    det_shape = det(x).shape
-    assert isinstance(det_shape, Constant)
-    assert tuple(det_shape.data) == ()
+    assert det(x).type.shape == ()


 def test_slogdet():
diff --git a/tests/tensor/test_shape.py b/tests/tensor/test_shape.py
index 4fb5f61041..8d8e79511b 100644
--- a/tests/tensor/test_shape.py
+++ b/tests/tensor/test_shape.py
@@ -7,6 +7,7 @@
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Variable
 from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.replace import clone_replace
 from pytensor.graph.type import Type
 from pytensor.misc.safe_asarray import _asarray
 from pytensor.scalar.basic import ScalarConstant
@@ -16,6 +17,7 @@
 from pytensor.tensor.rewriting.shape import ShapeFeature
 from pytensor.tensor.shape import (
     Reshape,
+    Shape,
     Shape_i,
     SpecifyShape,
     Unbroadcast,
@@ -336,6 +338,21 @@ def test_more_shapes(self):
             Reshape,
         )

+    def test_rebuild(self):
+        x = as_tensor_variable(50)
+        i = vector("i")
+        i_test = np.zeros((100,), dtype=config.floatX)
+        y = reshape(i, (100 // x, x))
+        assert y.type.shape == (2, 50)
+        assert tuple(y.shape.eval({i: i_test})) == (2, 50)
+        assert y.eval({i: i_test}).shape == (2, 50)
+
+        x_new = as_tensor_variable(25)
+        y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
+        assert y_new.type.shape == (4, 25)
+        assert tuple(y_new.shape.eval({i: i_test})) == (4, 25)
+        assert y_new.eval({i: i_test}).shape == (4, 25)
+

 def test_shape_i_hash():
     assert isinstance(Shape_i(np.int64(1)).__hash__(), int)
@@ -397,7 +414,7 @@ def test_fixed_shapes(self):
         shape = as_tensor_variable([2])
         y = specify_shape(x, shape)
         assert y.type.shape == (2,)
-        assert y.shape.equals(shape)
+        assert isinstance(y.shape.owner.op, Shape)

     def test_fixed_partial_shapes(self):
         x = TensorType("floatX", (None, None))("x")
@@ -523,6 +540,22 @@ def test_specify_shape_in_grad(self):
         z_grad = grad(z.sum(), wrt=x)
         assert isinstance(z_grad.owner.op, SpecifyShape)

+    def test_rebuild(self):
+        x = as_tensor_variable(50)
+        i = matrix("i")
+        i_test = np.zeros((4, 50), dtype=config.floatX)
+        y = specify_shape(i, (None, x))
+        assert y.type.shape == (None, 50)
+        assert tuple(y.shape.eval({i: i_test})) == (4, 50)
+        assert y.eval({i: i_test}).shape == (4, 50)
+
+        x_new = as_tensor_variable(100)
+        i_test = np.zeros((4, 100), dtype=config.floatX)
+        y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
+        assert y_new.type.shape == (None, 100)
+        assert tuple(y_new.shape.eval({i: i_test})) == (4, 100)
+        assert y_new.eval({i: i_test}).shape == (4, 100)
+

 class TestSpecifyBroadcastable:
     def test_basic(self):
diff --git a/tests/tensor/test_var.py b/tests/tensor/test_var.py
index c17f524797..510b18d56b 100644
--- a/tests/tensor/test_var.py
+++ b/tests/tensor/test_var.py
@@ -6,12 +6,14 @@
 import pytensor
 import tests.unittest_tools as utt
+from pytensor.compile import DeepCopyOp
 from pytensor.compile.mode import get_default_mode
 from pytensor.graph.basic import Constant, equal_computations
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import constant
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import dot, eq
+from pytensor.tensor.shape import Shape
 from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor
 from pytensor.tensor.type import (
     TensorType,
@@ -245,8 +247,14 @@ def test__getitem__newaxis(x, indices, new_order):

 def test_fixed_shape_variable_basic():
     x = TensorVariable(TensorType("int64", shape=(4,)), None)
-    assert isinstance(x.shape, Constant)
-    assert np.array_equal(x.shape.data, (4,))
+    assert x.type.shape == (4,)
+    assert isinstance(x.shape.owner.op, Shape)
+
+    shape_fn = pytensor.function([x], x.shape)
+    opt_shape = shape_fn.maker.fgraph.outputs[0]
+    assert isinstance(opt_shape.owner.op, DeepCopyOp)
+    assert isinstance(opt_shape.owner.inputs[0], Constant)
+    assert np.array_equal(opt_shape.owner.inputs[0].data, (4,))

     x = TensorConstant(
         TensorType("int64", shape=(None, None)), np.array([[1, 2], [2, 3]])
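`Reshape` is the subtler case covered by `TestReshape.test_rebuild` above: its static output shape is computed from the values of the shape inputs, so replacing one constant with another of the identical type must still trigger a rebuild. A commented sketch of that scenario:

```python
import pytensor.tensor as at
from pytensor.graph.replace import clone_replace

x = at.as_tensor_variable(50)  # a constant used as a shape component
i = at.vector("i")
y = at.reshape(i, (100 // x, x))
# Reshape infers its static output shape from the *values* of the shape inputs
assert y.type.shape == (2, 50)

# as_tensor_variable(25) has exactly the same type as x (a scalar integer
# constant), so only the new flag makes the non-strict clone rebuild the node
# and recompute the static shape as (4, 25)
y_new = clone_replace(y, {x: at.as_tensor_variable(25)}, rebuild_strict=False)
assert y_new.type.shape == (4, 25)
```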