From 25397c82d6bd082e1b73a6aa4a4addeb1339ad7b Mon Sep 17 00:00:00 2001
From: ricardoV94
Date: Wed, 21 May 2025 19:11:02 +0200
Subject: [PATCH 01/18] WIP Implement index operations for XTensorVariables

---
 pytensor/xtensor/__init__.py | 1 -
 pytensor/xtensor/indexing.py | 142 +++++++++++++++++++++++++
 pytensor/xtensor/rewriting/__init__.py | 1 +
 pytensor/xtensor/rewriting/indexing.py | 67 ++++++++++++
 pytensor/xtensor/type.py | 111 ++++++++++++++++++-
 tests/xtensor/test_indexing.py | 42 ++++++++
 6 files changed, 361 insertions(+), 3 deletions(-)
 create mode 100644 pytensor/xtensor/indexing.py
 create mode 100644 pytensor/xtensor/rewriting/indexing.py
 create mode 100644 tests/xtensor/test_indexing.py

diff --git a/pytensor/xtensor/__init__.py b/pytensor/xtensor/__init__.py
index a72bf66c79..06265e40de 100644
--- a/pytensor/xtensor/__init__.py
+++ b/pytensor/xtensor/__init__.py
@@ -7,7 +7,6 @@
 )
 from pytensor.xtensor.shape import concat
 from pytensor.xtensor.type import (
-    XTensorType,
     as_xtensor,
     xtensor,
     xtensor_constant,
diff --git a/pytensor/xtensor/indexing.py b/pytensor/xtensor/indexing.py
new file mode 100644
index 0000000000..9c25d9ed1a
--- /dev/null
+++ b/pytensor/xtensor/indexing.py
@@ -0,0 +1,142 @@
+# HERE LIE DRAGONS
+# Useful links to make sense of all the numpy/xarray complexity
+# https://numpy.org/devdocs//user/basics.indexing.html
+# https://numpy.org/neps/nep-0021-advanced-indexing.html
+# https://docs.xarray.dev/en/latest/user-guide/indexing.html
+# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html
+
+from pytensor.graph.basic import Apply, Constant, Variable
+from pytensor.scalar.basic import discrete_dtypes
+from pytensor.tensor.basic import as_tensor
+from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice
+from pytensor.xtensor.basic import XOp
+from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
+
+
+def as_idx_variable(idx):
+    if idx is None or (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT)):
+        raise TypeError(
+            "XTensors do not support indexing with None (np.newaxis), use expand_dims instead"
+        )
+    if isinstance(idx, slice):
+        idx = make_slice(idx)
+    elif isinstance(idx, Variable) and isinstance(idx.type, SliceType):
+        pass
+    else:
+        # Must be integer indices, we already accounted for None and slices
+        try:
+            idx = as_tensor(idx)
+        except TypeError:
+            idx = as_xtensor(idx)
+        if idx.type.dtype == "bool":
+            raise NotImplementedError("Boolean indexing not yet supported")
+        if idx.type.dtype not in discrete_dtypes:
+            raise TypeError("Numerical indices must be integers or boolean")
+        if idx.type.dtype == "bool" and idx.type.ndim == 0:
+            # This can't be triggered right now, but will once we lift the boolean restriction
+            raise NotImplementedError("Scalar boolean indices not supported")
+    return idx
+
+
+def get_static_slice_length(slc: Variable, dim_length: None | int) -> int | None:
+    if dim_length is None:
+        return None
+    if isinstance(slc, Constant):
+        d = slc.data
+        start, stop, step = d.start, d.stop, d.step
+    elif slc.owner is None:
+        # It's a root variable, no way of knowing what we're getting
+        return None
+    else:
+        # It's a MakeSliceOp
+        start, stop, step = slc.owner.inputs
+        if isinstance(start, Constant):
+            start = start.data
+        else:
+            return None
+        if isinstance(stop, Constant):
+            stop = stop.data
+        else:
+            return None
+        if isinstance(step, Constant):
+            step = step.data
+        else:
+            return None
+    return len(range(*slice(start, stop, step).indices(dim_length)))
+
+
+class Index(XOp):
+    __props__ = 
() + + def make_node(self, x, *idxs): + x = as_xtensor(x) + idxs = [as_idx_variable(idx) for idx in idxs] + + x_ndim = x.type.ndim + x_dims = x.type.dims + x_shape = x.type.shape + out_dims = [] + out_shape = [] + has_unlabeled_vector_idx = False + has_labeled_vector_idx = False + for i, idx in enumerate(idxs): + if i == x_ndim: + raise IndexError("Too many indices") + if isinstance(idx.type, SliceType): + out_dims.append(x_dims[i]) + out_shape.append(get_static_slice_length(idx, x_shape[i])) + elif isinstance(idx.type, XTensorType): + if has_unlabeled_vector_idx: + raise NotImplementedError( + "Mixing of labeled and unlabeled vector indexing not implemented" + ) + has_labeled_vector_idx = True + idx_dims = idx.type.dims + for dim in idx_dims: + idx_dim_shape = idx.type.shape[idx_dims.index(dim)] + if dim in out_dims: + # Dim already introduced in output by a previous index + # Update static shape or raise if incompatible + out_dim_pos = out_dims.index(dim) + out_dim_shape = out_shape[out_dim_pos] + if out_dim_shape is None: + # We don't know the size of the dimension yet + out_shape[out_dim_pos] = idx_dim_shape + elif ( + idx_dim_shape is not None and idx_dim_shape != out_dim_shape + ): + raise IndexError( + f"Dimension of indexers mismatch for dim {dim}" + ) + else: + # New dimension + out_dims.append(dim) + out_shape.append(idx_dim_shape) + + else: # TensorType + if idx.type.ndim == 0: + # Scalar, dimension is dropped + pass + elif idx.type.ndim == 1: + if has_labeled_vector_idx: + raise NotImplementedError( + "Mixing of labeled and unlabeled vector indexing not implemented" + ) + has_unlabeled_vector_idx = True + out_dims.append(x_dims[i]) + out_shape.append(idx.type.shape[0]) + else: + # Same error that xarray raises + raise IndexError( + "Unlabeled multi-dimensional array cannot be used for indexing" + ) + for j in range(i + 1, x_ndim): + # Add any unindexed dimensions + out_dims.append(x_dims[j]) + out_shape.append(x_shape[j]) + + output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims) + return Apply(self, [x, *idxs], [output]) + + +index = Index() diff --git a/pytensor/xtensor/rewriting/__init__.py b/pytensor/xtensor/rewriting/__init__.py index 7ce55b9256..a65ad0db85 100644 --- a/pytensor/xtensor/rewriting/__init__.py +++ b/pytensor/xtensor/rewriting/__init__.py @@ -1,4 +1,5 @@ import pytensor.xtensor.rewriting.basic +import pytensor.xtensor.rewriting.indexing import pytensor.xtensor.rewriting.reduction import pytensor.xtensor.rewriting.shape import pytensor.xtensor.rewriting.vectorization diff --git a/pytensor/xtensor/rewriting/indexing.py b/pytensor/xtensor/rewriting/indexing.py new file mode 100644 index 0000000000..dea99543a3 --- /dev/null +++ b/pytensor/xtensor/rewriting/indexing.py @@ -0,0 +1,67 @@ +from pytensor.graph import Constant, node_rewriter +from pytensor.tensor import TensorType, specify_shape +from pytensor.tensor.type_other import NoneTypeT, SliceType +from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor +from pytensor.xtensor.indexing import Index +from pytensor.xtensor.rewriting.utils import register_xcanonicalize +from pytensor.xtensor.type import XTensorType + + +def to_basic_idx(idx): + if isinstance(idx.type, SliceType): + if isinstance(idx, Constant): + return idx.data + elif idx.owner: + # MakeSlice Op + # We transform NoneConsts to regular None so that basic Subtensor can be used if possible + return slice( + *[ + None if isinstance(i.type, NoneTypeT) else i + for i in idx.owner.inputs + ] + ) + else: + return idx + if 
(
+        isinstance(idx.type, XTensorType | TensorType)
+        and idx.type.ndim == 0
+        and idx.type.dtype != "bool"
+    ):
+        return idx
+    raise TypeError("Cannot convert idx to basic idx")
+
+
+def _count_idx_types(idxs):
+    basic, vector, xvector = 0, 0, 0
+    for idx in idxs:
+        if isinstance(idx.type, SliceType):
+            basic += 1
+        elif idx.type.ndim == 0:
+            basic += 1
+        elif isinstance(idx.type, TensorType):
+            vector += 1
+        else:
+            xvector += 1
+    return basic, vector, xvector
+
+
+@register_xcanonicalize
+@node_rewriter(tracks=[Index])
+def lower_index(fgraph, node):
+    x, *idxs = node.inputs
+    [out] = node.outputs
+    x_tensor = tensor_from_xtensor(x)
+    n_basic, n_vector, n_xvector = _count_idx_types(idxs)
+    if n_xvector == 0 and n_vector == 0:
+        x_tensor_indexed = x_tensor[tuple(to_basic_idx(idx) for idx in idxs)]
+    elif n_vector == 1 and n_xvector == 0:
+        # Special case for single vector index, no orthogonal indexing
+        x_tensor_indexed = x_tensor[tuple(idxs)]
+    else:
+        # Not yet implemented
+        return None
+
+    # Add lost shape if any
+    x_tensor_indexed = specify_shape(x_tensor_indexed, out.type.shape)
+    new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.type.dims)
+    return [new_out]
diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py
index 5b79e9ae57..804f79cbaf 100644
--- a/pytensor/xtensor/type.py
+++ b/pytensor/xtensor/type.py
@@ -1,3 +1,5 @@
+import warnings
+
 from pytensor.tensor import TensorType
 from pytensor.tensor.math import variadic_mul
@@ -10,7 +12,7 @@
     XARRAY_AVAILABLE = False

 from collections.abc import Sequence
-from typing import TypeVar
+from typing import Any, Literal, TypeVar

 import numpy as np

@@ -339,7 +341,112 @@ def sel(self, *args, **kwargs):
         raise NotImplementedError("sel not implemented for XTensorVariable")

     def __getitem__(self, idx):
-        raise NotImplementedError("Indexing not yet implemnented")
+        if isinstance(idx, dict):
+            return self.isel(idx)
+
+        # Check for ellipsis not in the last position (last one is useless anyway)
+        if any(idx_item is Ellipsis for idx_item in idx):
+            if idx.count(Ellipsis) > 1:
+                raise IndexError("an index can only have a single ellipsis ('...')")
+            # Convert intermediate Ellipsis to slice(None)
+            ellipsis_loc = idx.index(Ellipsis)
+            n_implied_none_slices = self.type.ndim - (len(idx) - 1)
+            idx = (
+                *idx[:ellipsis_loc],
+                *((slice(None),) * n_implied_none_slices),
+                *idx[ellipsis_loc + 1 :],
+            )
+
+        return px.indexing.index(self, *idx)
+
+    def isel(
+        self,
+        indexers: dict[str, Any] | None = None,
+        drop: bool = False,  # Unused by PyTensor
+        missing_dims: Literal["raise", "warn", "ignore"] = "raise",
+        **indexers_kwargs,
+    ):
+        if indexers_kwargs:
+            if indexers is not None:
+                raise ValueError(
+                    "Cannot pass both indexers and indexers_kwargs to isel"
+                )
+            indexers = indexers_kwargs
+
+        if missing_dims not in {"raise", "warn", "ignore"}:
+            raise ValueError(
+                f"Unrecognized options {missing_dims} for missing_dims argument"
+            )
+
+        # Sort indices and pass them to index
+        dims = self.type.dims
+        indices = [slice(None)] * self.type.ndim
+        for key, idx in indexers.items():
+            if idx is Ellipsis:
+                # Xarray raises a less informative error, suggesting indices must be integer
+                # But slices are also fine
+                raise TypeError("Ellipsis (...) is an invalid labeled index")
+            try:
+                indices[dims.index(key)] = idx
+            except ValueError:
+                if missing_dims == "raise":
+                    raise ValueError(
+                        f"Dimension {key} does not exist. Expected one of {dims}"
+                    )
+                elif missing_dims == "warn":
+                    warnings.warn(
+                        f"Dimension {key} does not exist. Expected one of {dims}",
+                        UserWarning,
+                    )
+
+        return px.indexing.index(self, *indices)
+
+    def _head_tail_or_thin(
+        self,
+        indexers: dict[str, Any] | int | None,
+        indexers_kwargs: dict[str, Any],
+        *,
+        kind: Literal["head", "tail", "thin"],
+    ):
+        if indexers_kwargs:
+            if indexers is not None:
+                raise ValueError(
+                    f"Cannot pass both indexers and indexers_kwargs to {kind}"
+                )
+            indexers = indexers_kwargs
+
+        if indexers is None:
+            if kind == "thin":
+                raise TypeError(
+                    "thin() indexers must be either dict-like or a single integer"
+                )
+            else:
+                # Default to 5 for head and tail
+                indexers = {dim: 5 for dim in self.type.dims}
+
+        elif not isinstance(indexers, dict):
+            indexers = {dim: indexers for dim in self.type.dims}
+
+        if kind == "head":
+            indices = {dim: slice(None, value) for dim, value in indexers.items()}
+        elif kind == "tail":
+            sizes = self.sizes
+            # Can't use slice(-value, None), in case value is zero
+            indices = {
+                dim: slice(sizes[dim] - value, None) for dim, value in indexers.items()
+            }
+        elif kind == "thin":
+            indices = {dim: slice(None, None, value) for dim, value in indexers.items()}
+        return self.isel(indices)
+
+    def head(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
+        return self._head_tail_or_thin(indexers, indexers_kwargs, kind="head")
+
+    def tail(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
+        return self._head_tail_or_thin(indexers, indexers_kwargs, kind="tail")
+
+    def thin(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
+        return self._head_tail_or_thin(indexers, indexers_kwargs, kind="thin")

     # ndarray methods
     # https://docs.xarray.dev/en/latest/api.html#id7
diff --git a/tests/xtensor/test_indexing.py b/tests/xtensor/test_indexing.py
new file mode 100644
index 0000000000..6c83036931
--- /dev/null
+++ b/tests/xtensor/test_indexing.py
@@ -0,0 +1,42 @@
+import numpy as np
+import pytest
+from xarray import DataArray
+
+from pytensor.xtensor import xtensor
+from tests.xtensor.util import xr_assert_allclose, xr_function
+
+
+@pytest.mark.parametrize(
+    "indices",
+    [
+        (0,),
+        (slice(1, None),),
+        (slice(None, -1),),
+        (slice(None, None, -1),),
+        (0, slice(None), -1, slice(1, None)),
+        (..., 0, -1),
+        (0, ..., -1),
+        (0, -1, ...),
+    ],
+)
+@pytest.mark.parametrize("labeled", (False, True), ids=["unlabeled", "labeled"])
+def test_basic_indexing(labeled, indices):
+    if ... 
in indices and labeled:
+        pytest.skip("Ellipsis not supported with labeled indexing")
+
+    dims = ("a", "b", "c", "d")
+    x = xtensor(dims=dims, shape=(2, 3, 5, 7))
+
+    if labeled:
+        shuffled_dims = tuple(np.random.permutation(dims))
+        indices = dict(zip(shuffled_dims, indices, strict=False))
+    out = x[indices]
+
+    fn = xr_function([x], out)
+    x_test_values = np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(
+        x.type.shape
+    )
+    x_test = DataArray(x_test_values, dims=x.type.dims)
+    res = fn(x_test)
+    expected_res = x_test[indices]
+    xr_assert_allclose(res, expected_res)

From e32d86516a8fef9d52c0c0925ba0af5622ba1f64 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 26 May 2025 17:38:05 +0200
Subject: [PATCH 02/18] Add diff method to XTensorVariable

---
 pytensor/xtensor/type.py | 9 +++++++++
 tests/xtensor/test_indexing.py | 20 ++++++++++++++++++++
 2 files changed, 29 insertions(+)

diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py
index 804f79cbaf..877c61d2b6 100644
--- a/pytensor/xtensor/type.py
+++ b/pytensor/xtensor/type.py
@@ -499,6 +499,15 @@ def cumsum(self, dim):
     def cumprod(self, dim):
         return px.reduction.cumprod(self, dim)

+    def diff(self, dim, n=1):
+        """Compute the n-th discrete difference along the given dimension."""
+        slice1 = {dim: slice(1, None)}
+        slice2 = {dim: slice(None, -1)}
+        x = self
+        for _ in range(n):
+            x = x[slice1] - x[slice2]
+        return x
+

 class XTensorConstantSignature(tuple):
     def __eq__(self, other):
diff --git a/tests/xtensor/test_indexing.py b/tests/xtensor/test_indexing.py
index 6c83036931..19d16d1ec5 100644
--- a/tests/xtensor/test_indexing.py
+++ b/tests/xtensor/test_indexing.py
@@ -1,6 +1,7 @@
 import numpy as np
 import pytest
 from xarray import DataArray
+from tests.xtensor.util import xr_arange_like

 from pytensor.xtensor import xtensor
 from tests.xtensor.util import xr_assert_allclose, xr_function
@@ -40,3 +41,22 @@ def test_basic_indexing(labeled, indices):
     res = fn(x_test)
     expected_res = x_test[indices]
     xr_assert_allclose(res, expected_res)
+
+
+@pytest.mark.parametrize("n", ["implicit", 1, 2])
+@pytest.mark.parametrize("dim", ["a", "b"])
+def test_diff(dim, n):
+    x = xtensor(dims=("a", "b"), shape=(7, 11))
+    if n == "implicit":
+        out = x.diff(dim)
+    else:
+        out = x.diff(dim, n=n)
+
+    fn = xr_function([x], out)
+    x_test = xr_arange_like(x)
+    res = fn(x_test)
+    if n == "implicit":
+        expected_res = x_test.diff(dim)
+    else:
+        expected_res = x_test.diff(dim, n=n)
+    xr_assert_allclose(res, expected_res)

From 5988cec91ebd37dffdeeb8ce1fbfa28cd7bb4dc4 Mon Sep 17 00:00:00 2001
From: Allen Downey
Date: Tue, 27 May 2025 13:18:46 -0400
Subject: [PATCH 03/18] Add transpose operation for labeled tensors with ellipsis support

---
 pytensor/xtensor/rewriting/shape.py | 34 ++++++++++++++++++++++-
 pytensor/xtensor/shape.py | 43 +++++++++++++++++++++++++++++
 tests/xtensor/test_shape.py | 5 ++--
 3 files changed, 78 insertions(+), 4 deletions(-)

diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py
index 06b8c40a32..80f1750598 100644
--- a/pytensor/xtensor/rewriting/shape.py
+++ b/pytensor/xtensor/rewriting/shape.py
@@ -2,7 +2,7 @@
 from pytensor.tensor import broadcast_to, join, moveaxis
 from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
 from pytensor.xtensor.rewriting.basic import register_xcanonicalize
-from pytensor.xtensor.shape import Concat, Stack
+from pytensor.xtensor.shape import Concat, Stack, Transpose
@@ -70,3 +70,35 @@ def lower_concat(fgraph, node):
joined_tensor = join(concat_axis, *bcast_tensor_inputs) new_out = xtensor_from_tensor(joined_tensor, dims=out_dims) return [new_out] + + +@register_xcanonicalize +@node_rewriter(tracks=[Transpose]) +def lower_transpose(fgraph, node): + [x] = node.inputs + # Determine the permutation of axes + out_dims = node.op.dims + in_dims = x.type.dims + # Expand ellipsis if present + if out_dims == () or out_dims == (...,): + expanded_dims = tuple(reversed(in_dims)) + elif ... in out_dims: + pre = [] + post = [] + found = False + for d in out_dims: + if d is ...: + found = True + elif not found: + pre.append(d) + else: + post.append(d) + middle = [d for d in in_dims if d not in pre + post] + expanded_dims = tuple(pre + middle + post) + else: + expanded_dims = out_dims + perm = tuple(in_dims.index(d) for d in expanded_dims) + x_tensor = tensor_from_xtensor(x) + x_tensor_transposed = x_tensor.transpose(perm) + new_out = xtensor_from_tensor(x_tensor_transposed, dims=expanded_dims) + return [new_out] diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index f39d495285..ff44ae0503 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -73,6 +73,49 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]) return y +class Transpose(XOp): + __props__ = ("dims",) + + def __init__(self, dims: tuple[str, ...]): + super().__init__() + self.dims = dims + + def make_node(self, x): + x = as_xtensor(x) + # Allow ellipsis for full transpose + if self.dims == () or self.dims == (...,): + dims = tuple(reversed(x.type.dims)) + else: + # Expand ellipsis if present + if ... in self.dims: + pre = [] + post = [] + found = False + for d in self.dims: + if d is ...: + found = True + elif not found: + pre.append(d) + else: + post.append(d) + middle = [d for d in x.type.dims if d not in pre + post] + dims = tuple(pre + middle + post) + else: + dims = self.dims + if set(dims) != set(x.type.dims): + raise ValueError(f"Transpose dims {dims} must match {x.type.dims}") + output = xtensor( + dtype=x.type.dtype, + shape=tuple(x.type.shape[x.type.dims.index(d)] for d in dims), + dims=dims, + ) + return Apply(self, [x], [output]) + + +def transpose(x, *dims): + return Transpose(dims)(x) + + class Concat(XOp): __props__ = ("dim",) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index 79cc2738a2..e21c67b843 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -10,7 +10,7 @@ from xarray import DataArray from xarray import concat as xr_concat -from pytensor.xtensor.shape import concat, stack +from pytensor.xtensor.shape import concat, stack, transpose from pytensor.xtensor.type import xtensor from tests.xtensor.util import xr_assert_allclose, xr_function, xr_random_like @@ -24,9 +24,8 @@ def powerset(iterable, min_group_size=0): ) -@pytest.mark.xfail(reason="Not yet implemented") +# @pytest.mark.xfail(reason="Not yet implemented") def test_transpose(): - transpose = None a, b, c, d, e = "abcde" x = xtensor("x", dims=(a, b, c, d, e), shape=(2, 3, 5, 7, 11)) From 5936ab2492ab0d4c5413801183bbfc9a958edfa5 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Tue, 27 May 2025 13:35:14 -0400 Subject: [PATCH 04/18] Refactor: Extract ellipsis expansion logic into helper function --- pytensor/xtensor/rewriting/shape.py | 21 ++--------- pytensor/xtensor/shape.py | 56 ++++++++++++++++++----------- 2 files changed, 38 insertions(+), 39 deletions(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index 
80f1750598..2811d5705f 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -2,7 +2,7 @@ from pytensor.tensor import broadcast_to, join, moveaxis from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.rewriting.basic import register_xcanonicalize -from pytensor.xtensor.shape import Concat, Stack, Transpose +from pytensor.xtensor.shape import Concat, Stack, Transpose, expand_ellipsis @register_xcanonicalize @@ -79,24 +79,7 @@ def lower_transpose(fgraph, node): # Determine the permutation of axes out_dims = node.op.dims in_dims = x.type.dims - # Expand ellipsis if present - if out_dims == () or out_dims == (...,): - expanded_dims = tuple(reversed(in_dims)) - elif ... in out_dims: - pre = [] - post = [] - found = False - for d in out_dims: - if d is ...: - found = True - elif not found: - pre.append(d) - else: - post.append(d) - middle = [d for d in in_dims if d not in pre + post] - expanded_dims = tuple(pre + middle + post) - else: - expanded_dims = out_dims + expanded_dims = expand_ellipsis(out_dims, in_dims) perm = tuple(in_dims.index(d) for d in expanded_dims) x_tensor = tensor_from_xtensor(x) x_tensor_transposed = x_tensor.transpose(perm) diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index ff44ae0503..0fe02c0d83 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -73,6 +73,41 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]) return y +def expand_ellipsis(dims: tuple[str, ...], all_dims: tuple[str, ...]) -> tuple[str, ...]: + """Expand ellipsis in dimension permutation. + + Parameters + ---------- + dims : tuple[str, ...] + The dimension permutation, which may contain ellipsis + all_dims : tuple[str, ...] + All available dimensions + + Returns + ------- + tuple[str, ...] + The expanded dimension permutation + """ + if dims == () or dims == (...,): + return tuple(reversed(all_dims)) + + if ... not in dims: + return dims + + pre = [] + post = [] + found = False + for d in dims: + if d is ...: + found = True + elif not found: + pre.append(d) + else: + post.append(d) + middle = [d for d in all_dims if d not in pre + post] + return tuple(pre + middle + post) + + class Transpose(XOp): __props__ = ("dims",) @@ -82,26 +117,7 @@ def __init__(self, dims: tuple[str, ...]): def make_node(self, x): x = as_xtensor(x) - # Allow ellipsis for full transpose - if self.dims == () or self.dims == (...,): - dims = tuple(reversed(x.type.dims)) - else: - # Expand ellipsis if present - if ... 
in self.dims: - pre = [] - post = [] - found = False - for d in self.dims: - if d is ...: - found = True - elif not found: - pre.append(d) - else: - post.append(d) - middle = [d for d in x.type.dims if d not in pre + post] - dims = tuple(pre + middle + post) - else: - dims = self.dims + dims = expand_ellipsis(self.dims, x.type.dims) if set(dims) != set(x.type.dims): raise ValueError(f"Transpose dims {dims} must match {x.type.dims}") output = xtensor( From 6fc7b89895c5d780b4fdf7c28ad6d7ae70ddaa13 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Tue, 27 May 2025 15:15:04 -0400 Subject: [PATCH 05/18] Fix lint errors: remove trailing whitespace from docstrings --- pytensor/xtensor/shape.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 0fe02c0d83..01748d0ac6 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -75,14 +75,14 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]) def expand_ellipsis(dims: tuple[str, ...], all_dims: tuple[str, ...]) -> tuple[str, ...]: """Expand ellipsis in dimension permutation. - + Parameters ---------- dims : tuple[str, ...] The dimension permutation, which may contain ellipsis all_dims : tuple[str, ...] All available dimensions - + Returns ------- tuple[str, ...] @@ -90,10 +90,10 @@ def expand_ellipsis(dims: tuple[str, ...], all_dims: tuple[str, ...]) -> tuple[s """ if dims == () or dims == (...,): return tuple(reversed(all_dims)) - + if ... not in dims: return dims - + pre = [] post = [] found = False From 0778cf79a04ee60b1b64a395ef8f9aec2da454db Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Tue, 27 May 2025 15:22:44 -0400 Subject: [PATCH 06/18] Format files with ruff --- pytensor/xtensor/shape.py | 4 +++- pytensor/xtensor/type.py | 3 +-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 01748d0ac6..6f3f148020 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -73,7 +73,9 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]) return y -def expand_ellipsis(dims: tuple[str, ...], all_dims: tuple[str, ...]) -> tuple[str, ...]: +def expand_ellipsis( + dims: tuple[str, ...], all_dims: tuple[str, ...] +) -> tuple[str, ...]: """Expand ellipsis in dimension permutation. 
Parameters diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index 877c61d2b6..6983542ab9 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -586,8 +586,7 @@ def as_xtensor(x, name=None, dims: Sequence[str] | None = None): if isinstance(x, Apply): if len(x.outputs) != 1: raise ValueError( - "It is ambiguous which output of a " - "multi-output Op has to be fetched.", + "It is ambiguous which output of a multi-output Op has to be fetched.", x, ) else: From c7ce0c93d5380f9b4762b62d32a9738e88398b16 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Tue, 27 May 2025 15:28:11 -0400 Subject: [PATCH 07/18] Remove commented out line --- tests/xtensor/test_shape.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index e21c67b843..7bd938e0f2 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -24,7 +24,6 @@ def powerset(iterable, min_group_size=0): ) -# @pytest.mark.xfail(reason="Not yet implemented") def test_transpose(): a, b, c, d, e = "abcde" From bc2cbc02c81f993ae7dbbd5515930f8994d6ebc4 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Wed, 28 May 2025 10:20:10 -0400 Subject: [PATCH 08/18] Add missing_dims parameter to transpose for XTensorVariable and core, matching xarray behavior; update tests and rewrites accordingly --- pytensor/xtensor/rewriting/shape.py | 35 ++++++++- pytensor/xtensor/shape.py | 116 ++++++++++++++++++++++++++++ pytensor/xtensor/type.py | 42 +++++++++- tests/xtensor/test_shape.py | 72 ++++++++++++++++- 4 files changed, 260 insertions(+), 5 deletions(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index 06b8c40a32..bcce1c7c5e 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -2,7 +2,8 @@ from pytensor.tensor import broadcast_to, join, moveaxis from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.rewriting.basic import register_xcanonicalize -from pytensor.xtensor.shape import Concat, Stack +from pytensor.xtensor.shape import Concat, Stack, Transpose, expand_ellipsis +import warnings @register_xcanonicalize @@ -70,3 +71,35 @@ def lower_concat(fgraph, node): joined_tensor = join(concat_axis, *bcast_tensor_inputs) new_out = xtensor_from_tensor(joined_tensor, dims=out_dims) return [new_out] + + +@register_xcanonicalize +@node_rewriter(tracks=[Transpose]) +def lower_transpose(fgraph, node): + [x] = node.inputs + # Determine the permutation of axes + out_dims = node.op.dims + in_dims = x.type.dims + expanded_dims = expand_ellipsis(out_dims, in_dims) + + # Handle missing dimensions based on missing_dims setting + if node.op.missing_dims == "ignore": + # Filter out dimensions that don't exist in in_dims + expanded_dims = tuple(d for d in expanded_dims if d in in_dims) + # Add remaining dimensions in their original order + remaining_dims = tuple(d for d in in_dims if d not in expanded_dims) + expanded_dims = expanded_dims + remaining_dims + elif node.op.missing_dims == "warn": + missing = set(expanded_dims) - set(in_dims) + if missing: + warnings.warn(f"Dimensions {missing} do not exist in {in_dims}") + # Filter out missing dimensions and add remaining ones + expanded_dims = tuple(d for d in expanded_dims if d in in_dims) + remaining_dims = tuple(d for d in in_dims if d not in expanded_dims) + expanded_dims = expanded_dims + remaining_dims + + perm = tuple(in_dims.index(d) for d in expanded_dims) + x_tensor = tensor_from_xtensor(x) + 
x_tensor_transposed = x_tensor.transpose(perm) + new_out = xtensor_from_tensor(x_tensor_transposed, dims=expanded_dims) + return [new_out] diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index f39d495285..0029d35880 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -1,4 +1,6 @@ from collections.abc import Sequence +from typing import Literal +import warnings from pytensor import Variable from pytensor.graph import Apply @@ -73,6 +75,120 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]) return y +def expand_ellipsis( + dims: tuple[str, ...], all_dims: tuple[str, ...] +) -> tuple[str, ...]: + """Expand ellipsis in dimension permutation. + + Parameters + ---------- + dims : tuple[str, ...] + The dimension permutation, which may contain ellipsis + all_dims : tuple[str, ...] + All available dimensions + + Returns + ------- + tuple[str, ...] + The expanded dimension permutation + + Raises + ------ + ValueError + If more than one ellipsis is present in dims. + """ + if dims == () or dims == (...,): + return tuple(reversed(all_dims)) + + if ... not in dims: + return dims + + if sum(d is ... for d in dims) > 1: + raise ValueError("an index can only have a single ellipsis ('...')") + + pre = [] + post = [] + found = False + for d in dims: + if d is ...: + found = True + elif not found: + pre.append(d) + else: + post.append(d) + middle = [d for d in all_dims if d not in pre + post] + return tuple(pre + middle + post) + + +class Transpose(XOp): + __props__ = ("dims", "missing_dims") + + def __init__(self, dims: tuple[str, ...], missing_dims: Literal["raise", "warn", "ignore"] = "raise"): + super().__init__() + self.dims = dims + self.missing_dims = missing_dims + + def make_node(self, x): + x = as_xtensor(x) + dims = expand_ellipsis(self.dims, x.type.dims) + + # Handle missing dimensions based on missing_dims setting + if self.missing_dims == "ignore": + # Filter out dimensions that don't exist in x.type.dims + dims = tuple(d for d in dims if d in x.type.dims) + # Add remaining dimensions in their original order + remaining_dims = tuple(d for d in x.type.dims if d not in dims) + dims = dims + remaining_dims + elif self.missing_dims == "warn": + missing = set(dims) - set(x.type.dims) + if missing: + warnings.warn(f"Dimensions {missing} do not exist in {x.type.dims}") + # Filter out missing dimensions and add remaining ones + dims = tuple(d for d in dims if d in x.type.dims) + remaining_dims = tuple(d for d in x.type.dims if d not in dims) + dims = dims + remaining_dims + else: # "raise" + if set(dims) != set(x.type.dims): + raise ValueError(f"Transpose dims {dims} must match {x.type.dims}") + + output = xtensor( + dtype=x.type.dtype, + shape=tuple(x.type.shape[x.type.dims.index(d)] for d in dims), + dims=dims, + ) + return Apply(self, [x], [output]) + + +def transpose(x, *dims, missing_dims: Literal["raise", "warn", "ignore"] = "raise"): + """Transpose dimensions of the tensor. + + Parameters + ---------- + x : XTensorVariable + The tensor to transpose + *dims : str | Ellipsis + Dimensions to transpose. If empty, performs a full transpose. + Can use ellipsis (...) to represent remaining dimensions. 
+ missing_dims : {"raise", "warn", "ignore"}, default="raise" + How to handle dimensions that don't exist in the tensor: + - "raise": Raise an error if any dimensions don't exist + - "warn": Warn if any dimensions don't exist + - "ignore": Silently ignore any dimensions that don't exist + + Returns + ------- + XTensorVariable + Transposed tensor with reordered dimensions. + + Raises + ------ + ValueError + If missing_dims="raise" and any dimensions don't exist. + If multiple ellipsis are provided. + """ + return Transpose(dims, missing_dims=missing_dims)(x) + + class Concat(XOp): __props__ = ("dim",) diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index 877c61d2b6..4b62c87110 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -12,7 +12,7 @@ XARRAY_AVAILABLE = False from collections.abc import Sequence -from typing import Any, Literal, TypeVar +from typing import Any, Literal, TypeVar, Union import numpy as np @@ -464,6 +464,46 @@ def imag(self): def real(self): return px.math.real(self) + def transpose(self, *dims: Union[str, type(Ellipsis)], missing_dims: Literal["raise", "warn", "ignore"] = "raise"): + """Transpose dimensions of the tensor. + + Parameters + ---------- + *dims : str | Ellipsis + Dimensions to transpose. If empty, performs a full transpose. + Can use ellipsis (...) to represent remaining dimensions. + missing_dims : {"raise", "warn", "ignore"}, default="raise" + How to handle dimensions that don't exist in the tensor: + - "raise": Raise an error if any dimensions don't exist + - "warn": Warn if any dimensions don't exist + - "ignore": Silently ignore any dimensions that don't exist + + Returns + ------- + XTensorVariable + Transposed tensor with reordered dimensions. + + Raises + ------ + ValueError + If missing_dims="raise" and any dimensions don't exist. + If multiple ellipsis are provided. + """ + return px.shape.transpose(self, *dims, missing_dims=missing_dims) + + @property + def T(self) -> "XTensorVariable": + """Return the full transpose of the tensor. + + This is equivalent to calling transpose() with no arguments. + + Returns + ------- + XTensorVariable + Fully transposed tensor. 
+ """ + return self.transpose() + # Aggregation # https://docs.xarray.dev/en/latest/api.html#id6 def all(self, dim): diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index 79cc2738a2..071f8339a0 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -10,7 +10,7 @@ from xarray import DataArray from xarray import concat as xr_concat -from pytensor.xtensor.shape import concat, stack +from pytensor.xtensor.shape import concat, stack, transpose from pytensor.xtensor.type import xtensor from tests.xtensor.util import xr_assert_allclose, xr_function, xr_random_like @@ -24,9 +24,7 @@ def powerset(iterable, min_group_size=0): ) -@pytest.mark.xfail(reason="Not yet implemented") def test_transpose(): - transpose = None a, b, c, d, e = "abcde" x = xtensor("x", dims=(a, b, c, d, e), shape=(2, 3, 5, 7, 11)) @@ -163,3 +161,71 @@ def test_concat_scalar(): res = fn(x1_test, x2_test) expected_res = xr_concat([x1_test, x2_test], dim="new_dim") xr_assert_allclose(res, expected_res) + + +def test_xtensor_variable_transpose(): + """Test the transpose() method of XTensorVariable.""" + x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4)) + + # Test basic transpose + out = x.transpose() + fn = xr_function([x], out) + x_test = DataArray( + np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), + dims=x.type.dims, + ) + xr_assert_allclose(fn(x_test), x_test.transpose()) + + # Test transpose with specific dimensions + out = x.transpose("c", "a", "b") + fn = xr_function([x], out) + xr_assert_allclose(fn(x_test), x_test.transpose("c", "a", "b")) + + # Test transpose with ellipsis + out = x.transpose("c", ...) + fn = xr_function([x], out) + xr_assert_allclose(fn(x_test), x_test.transpose("c", ...)) + + # Test error cases + with pytest.raises(ValueError, match="Transpose dims.*must match"): + x.transpose("d") + + with pytest.raises(ValueError, match="an index can only have a single ellipsis"): + x.transpose("a", ..., "b", ...) 
+ + # Test missing_dims parameter + # Test ignore + out = x.transpose("c", ..., "d", missing_dims="ignore") + fn = xr_function([x], out) + xr_assert_allclose(fn(x_test), x_test.transpose("c", ...)) + + # Test warn + with pytest.warns(UserWarning, match="Dimensions {'d'} do not exist"): + out = x.transpose("c", ..., "d", missing_dims="warn") + fn = xr_function([x], out) + xr_assert_allclose(fn(x_test), x_test.transpose("c", ...)) + + +def test_xtensor_variable_T(): + """Test the T property of XTensorVariable.""" + # Test T property with 3D tensor + x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4)) + out = x.T + + fn = xr_function([x], out) + x_test = DataArray( + np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), + dims=x.type.dims, + ) + xr_assert_allclose(fn(x_test), x_test.transpose()) + + # Test T property with 2D tensor + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + out = x.T + + fn = xr_function([x], out) + x_test = DataArray( + np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), + dims=x.type.dims, + ) + xr_assert_allclose(fn(x_test), x_test.transpose()) From d4f5512b1ed7570e584a5b67a924a40409c8a815 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Wed, 28 May 2025 10:51:02 -0400 Subject: [PATCH 09/18] Fix linting issues: remove unused Union import and use dict.fromkeys() --- pytensor/xtensor/rewriting/shape.py | 31 +++++---------------- pytensor/xtensor/type.py | 42 ++++++++++++++--------------- 2 files changed, 27 insertions(+), 46 deletions(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index bcce1c7c5e..03deb9a91c 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -2,8 +2,7 @@ from pytensor.tensor import broadcast_to, join, moveaxis from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.rewriting.basic import register_xcanonicalize -from pytensor.xtensor.shape import Concat, Stack, Transpose, expand_ellipsis -import warnings +from pytensor.xtensor.shape import Concat, Stack, Transpose @register_xcanonicalize @@ -77,29 +76,13 @@ def lower_concat(fgraph, node): @node_rewriter(tracks=[Transpose]) def lower_transpose(fgraph, node): [x] = node.inputs - # Determine the permutation of axes - out_dims = node.op.dims + # Use the final dimensions that were already computed in make_node + out_dims = node.outputs[0].type.dims in_dims = x.type.dims - expanded_dims = expand_ellipsis(out_dims, in_dims) - - # Handle missing dimensions based on missing_dims setting - if node.op.missing_dims == "ignore": - # Filter out dimensions that don't exist in in_dims - expanded_dims = tuple(d for d in expanded_dims if d in in_dims) - # Add remaining dimensions in their original order - remaining_dims = tuple(d for d in in_dims if d not in expanded_dims) - expanded_dims = expanded_dims + remaining_dims - elif node.op.missing_dims == "warn": - missing = set(expanded_dims) - set(in_dims) - if missing: - warnings.warn(f"Dimensions {missing} do not exist in {in_dims}") - # Filter out missing dimensions and add remaining ones - expanded_dims = tuple(d for d in expanded_dims if d in in_dims) - remaining_dims = tuple(d for d in in_dims if d not in expanded_dims) - expanded_dims = expanded_dims + remaining_dims - - perm = tuple(in_dims.index(d) for d in expanded_dims) + + # Compute the permutation based on the final dimensions + perm = tuple(in_dims.index(d) for d in out_dims) x_tensor = tensor_from_xtensor(x) x_tensor_transposed 
= x_tensor.transpose(perm) - new_out = xtensor_from_tensor(x_tensor_transposed, dims=expanded_dims) + new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims) return [new_out] diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index eec48432ef..812c1c7056 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -12,7 +12,7 @@ XARRAY_AVAILABLE = False from collections.abc import Sequence -from typing import Any, Literal, TypeVar, Union +from typing import Any, Literal, TypeVar import numpy as np @@ -422,10 +422,10 @@ def _head_tail_or_thin( ) else: # Default to 5 for head and tail - indexers = {dim: 5 for dim in self.type.dims} + indexers = dict.fromkeys(self.type.dims, 5) elif not isinstance(indexers, dict): - indexers = {dim: indexers for dim in self.type.dims} + indexers = dict.fromkeys(self.type.dims, indexers) if kind == "head": indices = {dim: slice(None, value) for dim, value in indexers.items()} @@ -464,43 +464,41 @@ def imag(self): def real(self): return px.math.real(self) - def transpose(self, *dims: Union[str, type(Ellipsis)], missing_dims: Literal["raise", "warn", "ignore"] = "raise"): + def transpose(self, *dims, missing_dims: Literal["raise", "warn", "ignore"] = "raise"): """Transpose dimensions of the tensor. - + Parameters ---------- - *dims : str | Ellipsis - Dimensions to transpose. If empty, performs a full transpose. - Can use ellipsis (...) to represent remaining dimensions. - missing_dims : {"raise", "warn", "ignore"}, default="raise" - How to handle dimensions that don't exist in the tensor: - - "raise": Raise an error if any dimensions don't exist + *dims : str + Dimensions to transpose to. Can include ellipsis (...) to represent + remaining dimensions in their original order. + missing_dims : {"raise", "warn", "ignore"}, optional + How to handle dimensions that don't exist in the input tensor: + - "raise": Raise an error if any dimensions don't exist (default) - "warn": Warn if any dimensions don't exist - "ignore": Silently ignore any dimensions that don't exist - + Returns ------- XTensorVariable Transposed tensor with reordered dimensions. - + Raises ------ ValueError - If missing_dims="raise" and any dimensions don't exist. - If multiple ellipsis are provided. + If any dimension in dims doesn't exist in the input tensor and missing_dims is "raise". """ - return px.shape.transpose(self, *dims, missing_dims=missing_dims) + from pytensor.xtensor.shape import transpose + return transpose(self, *dims, missing_dims=missing_dims) @property - def T(self) -> "XTensorVariable": - """Return the full transpose of the tensor. - - This is equivalent to calling transpose() with no arguments. - + def T(self): + """Transpose all dimensions of the tensor, reversing their order. + Returns ------- XTensorVariable - Fully transposed tensor. + Transposed tensor with reversed dimensions. 
""" return self.transpose() From 1ed01c41b1c4093e01442a103e853e1ba6f8c2aa Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Wed, 28 May 2025 10:52:02 -0400 Subject: [PATCH 10/18] Improve expand_ellipsis with validate parameter and update tests --- pytensor/xtensor/shape.py | 44 +++++++++++++++++++++++-------------- tests/xtensor/test_shape.py | 22 +++++++++---------- 2 files changed, 38 insertions(+), 28 deletions(-) diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 0029d35880..dbcdf5ba61 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -1,6 +1,6 @@ +import warnings from collections.abc import Sequence from typing import Literal -import warnings from pytensor import Variable from pytensor.graph import Apply @@ -76,7 +76,7 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]) def expand_ellipsis( - dims: tuple[str, ...], all_dims: tuple[str, ...] + dims: tuple[str, ...], all_dims: tuple[str, ...], validate: bool = True ) -> tuple[str, ...]: """Expand ellipsis in dimension permutation. @@ -86,6 +86,8 @@ def expand_ellipsis( The dimension permutation, which may contain ellipsis all_dims : tuple[str, ...] All available dimensions + validate : bool, default True + Whether to check that all non-ellipsis elements in dims are valid dimension names. Returns ------- @@ -96,11 +98,16 @@ def expand_ellipsis( ------ ValueError If more than one ellipsis is present in dims. + If any non-ellipsis element in dims is not a valid dimension name and validate is True. """ if dims == () or dims == (...,): return tuple(reversed(all_dims)) if ... not in dims: + if validate: + invalid_dims = set(dims) - set(all_dims) + if invalid_dims: + raise ValueError(f"Invalid dimensions: {invalid_dims}. Available dimensions: {all_dims}") return dims if sum(d is ... for d in dims) > 1: @@ -116,6 +123,10 @@ def expand_ellipsis( pre.append(d) else: post.append(d) + if validate: + invalid_dims = set(pre + post) - set(all_dims) + if invalid_dims: + raise ValueError(f"Invalid dimensions: {invalid_dims}. Available dimensions: {all_dims}") middle = [d for d in all_dims if d not in pre + post] return tuple(pre + middle + post) @@ -130,8 +141,8 @@ def __init__(self, dims: tuple[str, ...], missing_dims: Literal["raise", "warn", def make_node(self, x): x = as_xtensor(x) - dims = expand_ellipsis(self.dims, x.type.dims) - + dims = expand_ellipsis(self.dims, x.type.dims, validate=(self.missing_dims == "raise")) + # Handle missing dimensions based on missing_dims setting if self.missing_dims == "ignore": # Filter out dimensions that don't exist in x.type.dims @@ -150,7 +161,7 @@ def make_node(self, x): else: # "raise" if set(dims) != set(x.type.dims): raise ValueError(f"Transpose dims {dims} must match {x.type.dims}") - + output = xtensor( dtype=x.type.dtype, shape=tuple(x.type.shape[x.type.dims.index(d)] for d in dims), @@ -161,30 +172,29 @@ def make_node(self, x): def transpose(x, *dims, missing_dims: Literal["raise", "warn", "ignore"] = "raise"): """Transpose dimensions of the tensor. - + Parameters ---------- x : XTensorVariable - The tensor to transpose - *dims : str | Ellipsis - Dimensions to transpose. If empty, performs a full transpose. - Can use ellipsis (...) to represent remaining dimensions. - missing_dims : {"raise", "warn", "ignore"}, default="raise" - How to handle dimensions that don't exist in the tensor: - - "raise": Raise an error if any dimensions don't exist + Input tensor to transpose. + *dims : str + Dimensions to transpose to. 
Can include ellipsis (...) to represent + remaining dimensions in their original order. + missing_dims : {"raise", "warn", "ignore"}, optional + How to handle dimensions that don't exist in the input tensor: + - "raise": Raise an error if any dimensions don't exist (default) - "warn": Warn if any dimensions don't exist - "ignore": Silently ignore any dimensions that don't exist - + Returns ------- XTensorVariable Transposed tensor with reordered dimensions. - + Raises ------ ValueError - If missing_dims="raise" and any dimensions don't exist. - If multiple ellipsis are provided. + If any dimension in dims doesn't exist in the input tensor and missing_dims is "raise". """ return Transpose(dims, missing_dims=missing_dims)(x) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index 071f8339a0..d5f2cf742d 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -166,7 +166,7 @@ def test_concat_scalar(): def test_xtensor_variable_transpose(): """Test the transpose() method of XTensorVariable.""" x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4)) - + # Test basic transpose out = x.transpose() fn = xr_function([x], out) @@ -175,30 +175,30 @@ def test_xtensor_variable_transpose(): dims=x.type.dims, ) xr_assert_allclose(fn(x_test), x_test.transpose()) - + # Test transpose with specific dimensions out = x.transpose("c", "a", "b") fn = xr_function([x], out) xr_assert_allclose(fn(x_test), x_test.transpose("c", "a", "b")) - + # Test transpose with ellipsis out = x.transpose("c", ...) fn = xr_function([x], out) xr_assert_allclose(fn(x_test), x_test.transpose("c", ...)) - + # Test error cases - with pytest.raises(ValueError, match="Transpose dims.*must match"): + with pytest.raises(ValueError, match="Invalid dimensions: {'d'}. Available dimensions: \\('a', 'b', 'c'\\)"): x.transpose("d") - + with pytest.raises(ValueError, match="an index can only have a single ellipsis"): x.transpose("a", ..., "b", ...) 
- + # Test missing_dims parameter # Test ignore out = x.transpose("c", ..., "d", missing_dims="ignore") fn = xr_function([x], out) xr_assert_allclose(fn(x_test), x_test.transpose("c", ...)) - + # Test warn with pytest.warns(UserWarning, match="Dimensions {'d'} do not exist"): out = x.transpose("c", ..., "d", missing_dims="warn") @@ -211,18 +211,18 @@ def test_xtensor_variable_T(): # Test T property with 3D tensor x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4)) out = x.T - + fn = xr_function([x], out) x_test = DataArray( np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), dims=x.type.dims, ) xr_assert_allclose(fn(x_test), x_test.transpose()) - + # Test T property with 2D tensor x = xtensor("x", dims=("a", "b"), shape=(2, 3)) out = x.T - + fn = xr_function([x], out) x_test = DataArray( np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), From 4f010e02a6820ab9edfc270900e532114462cf6b Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Wed, 28 May 2025 10:57:10 -0400 Subject: [PATCH 11/18] Apply ruff-format to shape.py, type.py, and test_shape.py for consistent formatting --- pytensor/xtensor/shape.py | 18 ++++++++++++++---- pytensor/xtensor/type.py | 5 ++++- tests/xtensor/test_shape.py | 5 ++++- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index dbcdf5ba61..f162435976 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -107,7 +107,9 @@ def expand_ellipsis( if validate: invalid_dims = set(dims) - set(all_dims) if invalid_dims: - raise ValueError(f"Invalid dimensions: {invalid_dims}. Available dimensions: {all_dims}") + raise ValueError( + f"Invalid dimensions: {invalid_dims}. Available dimensions: {all_dims}" + ) return dims if sum(d is ... for d in dims) > 1: @@ -126,7 +128,9 @@ def expand_ellipsis( if validate: invalid_dims = set(pre + post) - set(all_dims) if invalid_dims: - raise ValueError(f"Invalid dimensions: {invalid_dims}. Available dimensions: {all_dims}") + raise ValueError( + f"Invalid dimensions: {invalid_dims}. Available dimensions: {all_dims}" + ) middle = [d for d in all_dims if d not in pre + post] return tuple(pre + middle + post) @@ -134,14 +138,20 @@ def expand_ellipsis( class Transpose(XOp): __props__ = ("dims", "missing_dims") - def __init__(self, dims: tuple[str, ...], missing_dims: Literal["raise", "warn", "ignore"] = "raise"): + def __init__( + self, + dims: tuple[str, ...], + missing_dims: Literal["raise", "warn", "ignore"] = "raise", + ): super().__init__() self.dims = dims self.missing_dims = missing_dims def make_node(self, x): x = as_xtensor(x) - dims = expand_ellipsis(self.dims, x.type.dims, validate=(self.missing_dims == "raise")) + dims = expand_ellipsis( + self.dims, x.type.dims, validate=(self.missing_dims == "raise") + ) # Handle missing dimensions based on missing_dims setting if self.missing_dims == "ignore": diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index 812c1c7056..a4055f2c54 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -464,7 +464,9 @@ def imag(self): def real(self): return px.math.real(self) - def transpose(self, *dims, missing_dims: Literal["raise", "warn", "ignore"] = "raise"): + def transpose( + self, *dims, missing_dims: Literal["raise", "warn", "ignore"] = "raise" + ): """Transpose dimensions of the tensor. 
Parameters @@ -489,6 +491,7 @@ def transpose(self, *dims, missing_dims: Literal["raise", "warn", "ignore"] = "r If any dimension in dims doesn't exist in the input tensor and missing_dims is "raise". """ from pytensor.xtensor.shape import transpose + return transpose(self, *dims, missing_dims=missing_dims) @property diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index d5f2cf742d..fad488e01d 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -187,7 +187,10 @@ def test_xtensor_variable_transpose(): xr_assert_allclose(fn(x_test), x_test.transpose("c", ...)) # Test error cases - with pytest.raises(ValueError, match="Invalid dimensions: {'d'}. Available dimensions: \\('a', 'b', 'c'\\)"): + with pytest.raises( + ValueError, + match="Invalid dimensions: {'d'}. Available dimensions: \\('a', 'b', 'c'\\)", + ): x.transpose("d") with pytest.raises(ValueError, match="an index can only have a single ellipsis"): From f0ea583c49a54dc66430a79e9b59311e5914df54 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Wed, 28 May 2025 11:03:38 -0400 Subject: [PATCH 12/18] Simplify make_node in Transpose class by combining ignore/warn cases --- pytensor/xtensor/shape.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index f162435976..ee29ad8cb9 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -149,22 +149,15 @@ def __init__( def make_node(self, x): x = as_xtensor(x) - dims = expand_ellipsis( - self.dims, x.type.dims, validate=(self.missing_dims == "raise") - ) + dims = expand_ellipsis(self.dims, x.type.dims, validate=(self.missing_dims == "raise")) # Handle missing dimensions based on missing_dims setting - if self.missing_dims == "ignore": - # Filter out dimensions that don't exist in x.type.dims - dims = tuple(d for d in dims if d in x.type.dims) - # Add remaining dimensions in their original order - remaining_dims = tuple(d for d in x.type.dims if d not in dims) - dims = dims + remaining_dims - elif self.missing_dims == "warn": - missing = set(dims) - set(x.type.dims) - if missing: - warnings.warn(f"Dimensions {missing} do not exist in {x.type.dims}") - # Filter out missing dimensions and add remaining ones + if self.missing_dims in ("ignore", "warn"): + if self.missing_dims == "warn": + missing = set(dims) - set(x.type.dims) + if missing: + warnings.warn(f"Dimensions {missing} do not exist in {x.type.dims}") + # Filter out dimensions that don't exist and add remaining ones dims = tuple(d for d in dims if d in x.type.dims) remaining_dims = tuple(d for d in x.type.dims if d not in dims) dims = dims + remaining_dims From 0125bd2cbe6850356ddf866ae478696866135625 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Wed, 28 May 2025 11:07:06 -0400 Subject: [PATCH 13/18] Format expand_ellipsis call for better readability --- pytensor/xtensor/shape.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index ee29ad8cb9..7dce658187 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -149,7 +149,9 @@ def __init__( def make_node(self, x): x = as_xtensor(x) - dims = expand_ellipsis(self.dims, x.type.dims, validate=(self.missing_dims == "raise")) + dims = expand_ellipsis( + self.dims, x.type.dims, validate=(self.missing_dims == "raise") + ) # Handle missing dimensions based on missing_dims setting if self.missing_dims in ("ignore", "warn"): From 
30e1a424fa57264f9929553eba722b091084d48f Mon Sep 17 00:00:00 2001
From: ricardoV94
Date: Wed, 21 May 2025 19:11:02 +0200
Subject: [PATCH 14/18] WIP Implement index operations for XTensorVariables

---
 pytensor/xtensor/__init__.py | 1 -
 pytensor/xtensor/indexing.py | 142 +++++++++++++++++++++++++
 pytensor/xtensor/rewriting/__init__.py | 1 +
 pytensor/xtensor/rewriting/indexing.py | 102 ++++++++++++++++++
 pytensor/xtensor/type.py | 114 +++++++++++++++++++-
 tests/xtensor/test_indexing.py | 119 +++++++++++++++++++++
 6 files changed, 476 insertions(+), 3 deletions(-)
 create mode 100644 pytensor/xtensor/indexing.py
 create mode 100644 pytensor/xtensor/rewriting/indexing.py
 create mode 100644 tests/xtensor/test_indexing.py

diff --git a/pytensor/xtensor/__init__.py b/pytensor/xtensor/__init__.py
index a72bf66c79..06265e40de 100644
--- a/pytensor/xtensor/__init__.py
+++ b/pytensor/xtensor/__init__.py
@@ -7,7 +7,6 @@
 )
 from pytensor.xtensor.shape import concat
 from pytensor.xtensor.type import (
-    XTensorType,
     as_xtensor,
     xtensor,
     xtensor_constant,
diff --git a/pytensor/xtensor/indexing.py b/pytensor/xtensor/indexing.py
new file mode 100644
index 0000000000..33eb29f051
--- /dev/null
+++ b/pytensor/xtensor/indexing.py
@@ -0,0 +1,142 @@
+# HERE LIE DRAGONS
+# Useful links to make sense of all the numpy/xarray complexity
+# https://numpy.org/devdocs//user/basics.indexing.html
+# https://numpy.org/neps/nep-0021-advanced-indexing.html
+# https://docs.xarray.dev/en/latest/user-guide/indexing.html
+# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html
+
+from pytensor.graph.basic import Apply, Constant, Variable
+from pytensor.scalar.basic import discrete_dtypes
+from pytensor.tensor import TensorType
+from pytensor.tensor.basic import as_tensor
+from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice
+from pytensor.xtensor.basic import XOp, xtensor_from_tensor
+from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
+
+
+def as_idx_variable(idx):
+    if idx is None or (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT)):
+        raise TypeError(
+            "XTensors do not support indexing with None (np.newaxis), use expand_dims instead"
+        )
+    if isinstance(idx, slice):
+        idx = make_slice(idx)
+    elif isinstance(idx, Variable) and isinstance(idx.type, SliceType):
+        pass
+    elif isinstance(idx, tuple) and len(idx) == 2 and isinstance(idx[0], str):
+        # Special case for ("x", array) that xarray supports
+        # TODO: Check if this can be used to rename existing xarray dimensions or only for numpy
+        dim, idx = idx
+        idx = xtensor_from_tensor(as_tensor(idx), dims=(dim,))
+    else:
+        # Must be integer indices, we already accounted for None and slices
+        try:
+            idx = as_xtensor(idx)
+        except TypeError:
+            idx = as_tensor(idx)
+        if idx.type.dtype == "bool":
+            raise NotImplementedError("Boolean indexing not yet supported")
+        if idx.type.dtype not in discrete_dtypes:
+            raise TypeError("Numerical indices must be integers or boolean")
+        if idx.type.dtype == "bool" and idx.type.ndim == 0:
+            # This can't be triggered right now, but will once we lift the boolean restriction
+            raise NotImplementedError("Scalar boolean indices not supported")
+    return idx
+
+
+def get_static_slice_length(slc: Variable, dim_length: None | int) -> int | None:
+    if dim_length is None:
+        return None
+    if isinstance(slc, Constant):
+        d = slc.data
+        start, stop, step = d.start, d.stop, d.step
+    elif slc.owner is None:
+        # It's a root variable, no way of knowing what we're getting
+        return None
+    else:
+        # It's a 
MakeSliceOp + start, stop, step = slc.owner.inputs + if isinstance(start, Constant): + start = start.data + else: + return None + if isinstance(stop, Constant): + stop = stop.data + else: + return None + if isinstance(step, Constant): + step = step.data + else: + return None + return len(range(*slice(start, stop, step).indices(dim_length))) + + +class Index(XOp): + __props__ = () + + def make_node(self, x, *idxs): + x = as_xtensor(x) + idxs = [as_idx_variable(idx) for idx in idxs] + + x_ndim = x.type.ndim + x_dims = x.type.dims + x_shape = x.type.shape + out_dims = [] + out_shape = [] + for i, idx in enumerate(idxs): + if i == x_ndim: + raise IndexError("Too many indices") + if isinstance(idx.type, SliceType): + out_dims.append(x_dims[i]) + out_shape.append(get_static_slice_length(idx, x_shape[i])) + else: + if idx.type.ndim == 0: + # Scalar index, dimension is dropped + continue + + if isinstance(idx.type, TensorType): + if idx.type.ndim > 1: + # Same error that xarray raises + raise IndexError( + "Unlabeled multi-dimensional array cannot be used for indexing" + ) + + # This is implicitly an XTensorVariable with dim matching the indexed one + idx = idxs[i] = xtensor_from_tensor(idx, dims=(x_dims[i],)) + + assert isinstance(idx.type, XTensorType) + + idx_dims = idx.type.dims + for dim in idx_dims: + idx_dim_shape = idx.type.shape[idx_dims.index(dim)] + if dim in out_dims: + # Dim already introduced in output by a previous index + # Update static shape or raise if incompatible + out_dim_pos = out_dims.index(dim) + out_dim_shape = out_shape[out_dim_pos] + if out_dim_shape is None: + # We don't know the size of the dimension yet + out_shape[out_dim_pos] = idx_dim_shape + elif ( + idx_dim_shape is not None and idx_dim_shape != out_dim_shape + ): + raise IndexError( + f"Dimension of indexers mismatch for dim {dim}" + ) + else: + # New dimension + out_dims.append(dim) + out_shape.append(idx_dim_shape) + + for dim_i, shape_i in zip(x_dims[i + 1 :], x_shape[i + 1 :]): + # Add back any unindexed dimensions + if dim_i not in out_dims: + # If the dimension was not indexed, we keep it as is + out_dims.append(dim_i) + out_shape.append(shape_i) + + output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims) + return Apply(self, [x, *idxs], [output]) + + +index = Index() diff --git a/pytensor/xtensor/rewriting/__init__.py b/pytensor/xtensor/rewriting/__init__.py index 7ce55b9256..a65ad0db85 100644 --- a/pytensor/xtensor/rewriting/__init__.py +++ b/pytensor/xtensor/rewriting/__init__.py @@ -1,4 +1,5 @@ import pytensor.xtensor.rewriting.basic +import pytensor.xtensor.rewriting.indexing import pytensor.xtensor.rewriting.reduction import pytensor.xtensor.rewriting.shape import pytensor.xtensor.rewriting.vectorization diff --git a/pytensor/xtensor/rewriting/indexing.py b/pytensor/xtensor/rewriting/indexing.py new file mode 100644 index 0000000000..3d9ac3d99b --- /dev/null +++ b/pytensor/xtensor/rewriting/indexing.py @@ -0,0 +1,102 @@ +from itertools import zip_longest + +from pytensor import as_symbolic +from pytensor.graph import Constant, node_rewriter +from pytensor.tensor import arange, specify_shape +from pytensor.tensor.type_other import NoneTypeT, SliceType +from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor +from pytensor.xtensor.indexing import Index +from pytensor.xtensor.rewriting.utils import register_xcanonicalize +from pytensor.xtensor.type import XTensorType + + +def to_basic_idx(idx): + if isinstance(idx.type, SliceType): + if isinstance(idx, Constant): + 
return idx.data
+        elif idx.owner:
+            # MakeSlice Op
+            # We transform NoneConsts to regular None so that basic Subtensor can be used if possible
+            return slice(
+                *[
+                    None if isinstance(i.type, NoneTypeT) else i
+                    for i in idx.owner.inputs
+                ]
+            )
+        else:
+            return idx
+    if (
+        isinstance(idx.type, XTensorType)
+        and idx.type.ndim == 0
+        and idx.type.dtype != "bool"
+    ):
+        return idx.values
+    raise TypeError("Cannot convert idx to basic idx")
+
+
+@register_xcanonicalize
+@node_rewriter(tracks=[Index])
+def lower_index(fgraph, node):
+    x, *idxs = node.inputs
+    [out] = node.outputs
+    x_tensor = tensor_from_xtensor(x)
+
+    if all(
+        (
+            isinstance(idx.type, SliceType)
+            or (isinstance(idx.type, XTensorType) and idx.type.ndim == 0)
+        )
+        for idx in idxs
+    ):
+        # Special case just basic indexing
+        x_tensor_indexed = x_tensor[tuple(to_basic_idx(idx) for idx in idxs)]
+
+    else:
+        # General case, we have to align the indices positionally to achieve vectorized or orthogonal indexing
+        # May need to convert basic indexing to advanced indexing if it acts on a dimension
+        # that is also indexed by an advanced index
+        x_dims = x.type.dims
+        x_shape = tuple(x.shape)
+        out_ndim = out.type.ndim
+        out_xdims = out.type.dims
+        aligned_idxs = []
+        # zip_longest adds the implicit slice(None)
+        for i, (idx, x_dim) in enumerate(
+            zip_longest(idxs, x_dims, fillvalue=as_symbolic(slice(None)))
+        ):
+            if isinstance(idx.type, SliceType):
+                if not any(
+                    (
+                        isinstance(other_idx.type, XTensorType)
+                        and x_dim in other_idx.dims
+                    )
+                    for j, other_idx in enumerate(idxs)
+                    if j != i
+                ):
+                    # We can use basic indexing directly if no other index acts on this dimension
+                    aligned_idxs.append(idx)
+                else:
+                    # Otherwise we need to convert the basic index into an equivalent advanced indexing
+                    # And align it so it interacts correctly with the other advanced indices
+                    adv_idx_equivalent = arange(x_shape[i])[idx]
+                    ds_order = ["x"] * out_ndim
+                    ds_order[out_xdims.index(x_dim)] = 0
+                    aligned_idxs.append(adv_idx_equivalent.dimshuffle(ds_order))
+            else:
+                assert isinstance(idx.type, XTensorType)
+                if idx.type.ndim == 0:
+                    # Scalar index, we can use it directly
+                    aligned_idxs.append(idx.values)
+                else:
+                    # Vector index, we need to align the indexing dimensions with the base_dims
+                    ds_order = ["x"] * out_ndim
+                    for j, idx_dim in enumerate(idx.dims):
+                        ds_order[out_xdims.index(idx_dim)] = j
+                    aligned_idxs.append(idx.values.dimshuffle(ds_order))
+        x_tensor_indexed = x_tensor[tuple(aligned_idxs)]
+        # TODO: Align output dimensions if necessary
+
+    # Add lost shape if any
+    x_tensor_indexed = specify_shape(x_tensor_indexed, out.type.shape)
+    new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.type.dims)
+    return [new_out]
diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py
index 5b79e9ae57..f99b9261cb 100644
--- a/pytensor/xtensor/type.py
+++ b/pytensor/xtensor/type.py
@@ -1,3 +1,5 @@
+import warnings
+
 from pytensor.tensor import TensorType
 from pytensor.tensor.math import variadic_mul
 
@@ -10,7 +12,7 @@
     XARRAY_AVAILABLE = False
 
 from collections.abc import Sequence
-from typing import TypeVar
+from typing import Any, Literal, TypeVar
 
 import numpy as np
 
@@ -339,7 +341,115 @@ def sel(self, *args, **kwargs):
         raise NotImplementedError("sel not implemented for XTensorVariable")
 
     def __getitem__(self, idx):
-        raise NotImplementedError("Indexing not yet implemnented")
+        if isinstance(idx, dict):
+            return self.isel(idx)
+
+        if not isinstance(idx, tuple):
+            idx = (idx,)
+
+        # Check for ellipsis not in the last position (the last one is useless anyway)
+        if any(idx_item is Ellipsis for idx_item in idx):
+            if idx.count(Ellipsis) > 1:
+                raise IndexError("an index can only have a single ellipsis ('...')")
+            # Convert intermediate Ellipsis to slice(None)
+            ellipsis_loc = idx.index(Ellipsis)
+            n_implied_none_slices = self.type.ndim - (len(idx) - 1)
+            idx = (
+                *idx[:ellipsis_loc],
+                *((slice(None),) * n_implied_none_slices),
+                *idx[ellipsis_loc + 1 :],
+            )
+
+        return px.indexing.index(self, *idx)
+
+    def isel(
+        self,
+        indexers: dict[str, Any] | None = None,
+        drop: bool = False,  # Unused by PyTensor
+        missing_dims: Literal["raise", "warn", "ignore"] = "raise",
+        **indexers_kwargs,
+    ):
+        if indexers_kwargs:
+            if indexers is not None:
+                raise ValueError(
+                    "Cannot pass both indexers and indexers_kwargs to isel"
+                )
+            indexers = indexers_kwargs
+
+        if missing_dims not in {"raise", "warn", "ignore"}:
+            raise ValueError(
+                f"Unrecognized option {missing_dims} for missing_dims argument"
+            )
+
+        # Sort indices and pass them to index
+        dims = self.type.dims
+        indices = [slice(None)] * self.type.ndim
+        for key, idx in indexers.items():
+            if idx is Ellipsis:
+                # Xarray raises a less informative error, suggesting indices must be integer
+                # But slices are also fine
+                raise TypeError("Ellipsis (...) is an invalid labeled index")
+            try:
+                indices[dims.index(key)] = idx
+            except ValueError:
+                if missing_dims == "raise":
+                    raise ValueError(
+                        f"Dimension {key} does not exist. Expected one of {dims}"
+                    )
+                elif missing_dims == "warn":
+                    warnings.warn(
+                        f"Dimension {key} does not exist. Expected one of {dims}",
+                        UserWarning,
+                    )
+
+        return px.indexing.index(self, *indices)
+
+    def _head_tail_or_thin(
+        self,
+        indexers: dict[str, Any] | int | None,
+        indexers_kwargs: dict[str, Any],
+        *,
+        kind: Literal["head", "tail", "thin"],
+    ):
+        if indexers_kwargs:
+            if indexers is not None:
+                raise ValueError(
+                    f"Cannot pass both indexers and indexers_kwargs to {kind}"
+                )
+            indexers = indexers_kwargs
+
+        if indexers is None:
+            if kind == "thin":
+                raise TypeError(
+                    "thin() indexers must be either dict-like or a single integer"
+                )
+            else:
+                # Default to 5 for head and tail
+                indexers = {dim: 5 for dim in self.type.dims}
+
+        elif not isinstance(indexers, dict):
+            indexers = {dim: indexers for dim in self.type.dims}
+
+        if kind == "head":
+            indices = {dim: slice(None, value) for dim, value in indexers.items()}
+        elif kind == "tail":
+            sizes = self.sizes
+            # Can't use slice(-value, None), in case value is zero
+            indices = {
+                dim: slice(sizes[dim] - value, None) for dim, value in indexers.items()
+            }
+        elif kind == "thin":
+            indices = {dim: slice(None, None, value) for dim, value in indexers.items()}
+        return self.isel(indices)
+
+    def head(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
+        return self._head_tail_or_thin(indexers, indexers_kwargs, kind="head")
+
+    def tail(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
+        return self._head_tail_or_thin(indexers, indexers_kwargs, kind="tail")
+
+    def thin(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
+        return self._head_tail_or_thin(indexers, indexers_kwargs, kind="thin")
 
     # ndarray methods
     # https://docs.xarray.dev/en/latest/api.html#id7
diff --git a/tests/xtensor/test_indexing.py b/tests/xtensor/test_indexing.py
new file mode 100644
index 0000000000..ae7cb835c0
--- /dev/null
+++ b/tests/xtensor/test_indexing.py
@@ -0,0 +1,119 @@
+import numpy as np
+import pytest
+from xarray import DataArray
+
+from pytensor.tensor import tensor
+from pytensor.xtensor import xtensor
+from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function
+
+
+@pytest.mark.parametrize(
+    "indices",
+    [
+        (0,),
+        (slice(1, None),),
+        (slice(None, -1),),
+        (slice(None, None, -1),),
+        (0, slice(None), -1, slice(1, None)),
+        (..., 0, -1),
+        (0, ..., -1),
+        (0, -1, ...),
+    ],
+)
+@pytest.mark.parametrize("labeled", (False, True), ids=["unlabeled", "labeled"])
+def test_basic_indexing(labeled, indices):
+    if ... in indices and labeled:
+        pytest.skip("Ellipsis not supported with labeled indexing")
+
+    dims = ("a", "b", "c", "d")
+    x = xtensor(dims=dims, shape=(2, 3, 5, 7))
+
+    if labeled:
+        shuffled_dims = tuple(np.random.permutation(dims))
+        indices = dict(zip(shuffled_dims, indices, strict=False))
+    out = x[indices]
+
+    fn = xr_function([x], out)
+    x_test_values = np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(
+        x.type.shape
+    )
+    x_test = DataArray(x_test_values, dims=x.type.dims)
+    res = fn(x_test)
+    expected_res = x_test[indices]
+    xr_assert_allclose(res, expected_res)
+
+
+def test_single_vector_indexing_on_existing_dim():
+    x = xtensor(dims=("a", "b"), shape=(3, 5))
+    idx = tensor("idx", dtype=int, shape=(4,))
+    xidx = xtensor("idx", dtype=int, shape=(4,), dims=("a",))
+
+    x_test = xr_arange_like(x)
+    idx_test = np.array([0, 1, 0, 2], dtype=int)
+    xidx_test = DataArray(idx_test, dims=("a",))
+
+    # Three equivalent ways of indexing a->a
+    y = x[idx]
+    fn = xr_function([x, idx], y)
+    res = fn(x_test, idx_test)
+    expected_res = x_test[idx_test]
+    xr_assert_allclose(res, expected_res)
+
+    y = x[(("a", idx),)]
+    fn = xr_function([x, idx], y)
+    res = fn(x_test, idx_test)
+    expected_res = x_test[(("a", idx_test),)]
+    xr_assert_allclose(res, expected_res)
+
+    y = x[xidx]
+    fn = xr_function([x, xidx], y)
+    res = fn(x_test, xidx_test)
+    expected_res = x_test[xidx_test]
+    xr_assert_allclose(res, expected_res)
+
+
+def test_single_vector_indexing_on_new_dim():
+    x = xtensor(dims=("a", "b"), shape=(3, 5))
+    idx = tensor("idx", dtype=int, shape=(4,))
+    xidx = xtensor("idx", dtype=int, shape=(4,), dims=("a",))
+
+    x_test = xr_arange_like(x)
+    idx_test = np.array([0, 1, 0, 2], dtype=int)
+    xidx_test = DataArray(idx_test, dims=("a",))
+
+    # Two equivalent ways of indexing a->new_a
+    y = x[(("new_a", idx),)]
+    fn = xr_function([x, idx], y)
+    res = fn(x_test, idx_test)
+    expected_res = x_test[(("new_a", idx_test),)]
+    xr_assert_allclose(res, expected_res)
+
+    y = x[xidx.rename(a="new_a")]
+    fn = xr_function([x, xidx], y)
+    res = fn(x_test, xidx_test)
+    expected_res = x_test[xidx_test.rename(a="new_a")]
+    xr_assert_allclose(res, expected_res)
+
+
+def test_single_vector_indexing_interacting_with_existing_dim():
+    x = xtensor(dims=("a", "b"), shape=(3, 5))
+    idx = tensor("idx", dtype=int, shape=(4,))
+    xidx = xtensor("idx", dtype=int, shape=(4,), dims=("a",))
+
+    x_test = xr_arange_like(x)
+    idx_test = np.array([0, 1, 0, 2], dtype=int)
+    xidx_test = DataArray(idx_test, dims=("a",))
+
+    # Two equivalent ways of indexing a->b
+    # By labeling the index on a as "b", we cause pointwise indexing between the two dimensions; the extra basic slice on b mixes basic and advanced indexing.
+    y = x[("b", idx), 1:]
+    fn = xr_function([x, idx], y)
+    res = fn(x_test, idx_test)
+    expected_res = x_test[("b", idx_test), 1:]
+    xr_assert_allclose(res, expected_res)
+
+    y = x[xidx.rename(a="b"), 1:]
+    fn = xr_function([x, xidx], y)
+    res = fn(x_test, xidx_test)
+    expected_res = x_test[xidx_test.rename(a="b"), 1:]
+    xr_assert_allclose(res, expected_res)

From 29b954a25173cbcb4996b1606ecfe4efbee9d628 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira 
Date: Mon, 26 May 2025 17:38:05 +0200
Subject: [PATCH 15/18] Add diff method to XTensorVariable

---
 pytensor/xtensor/type.py       |  9 +++++++++
 tests/xtensor/test_indexing.py | 19 +++++++++++++++++++
 2 files changed, 28 insertions(+)

diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py
index f99b9261cb..6e6ee7c8e5 100644
--- a/pytensor/xtensor/type.py
+++ b/pytensor/xtensor/type.py
@@ -502,6 +502,15 @@ def cumsum(self, dim):
     def cumprod(self, dim):
         return px.reduction.cumprod(self, dim)
 
+    def diff(self, dim, n=1):
+        """Compute the n-th discrete difference along the given dimension."""
+        slice1 = {dim: slice(1, None)}
+        slice2 = {dim: slice(None, -1)}
+        x = self
+        for _ in range(n):
+            x = x[slice1] - x[slice2]
+        return x
+
 
 class XTensorConstantSignature(tuple):
     def __eq__(self, other):
diff --git a/tests/xtensor/test_indexing.py b/tests/xtensor/test_indexing.py
index ae7cb835c0..8010211b23 100644
--- a/tests/xtensor/test_indexing.py
+++ b/tests/xtensor/test_indexing.py
@@ -117,3 +117,22 @@ def test_single_vector_indexing_interacting_with_existing_dim():
     res = fn(x_test, xidx_test)
     expected_res = x_test[xidx_test.rename(a="b"), 1:]
     xr_assert_allclose(res, expected_res)
+
+
+@pytest.mark.parametrize("n", ["implicit", 1, 2])
+@pytest.mark.parametrize("dim", ["a", "b"])
+def test_diff(dim, n):
+    x = xtensor(dims=("a", "b"), shape=(7, 11))
+    if n == "implicit":
+        out = x.diff(dim)
+    else:
+        out = x.diff(dim, n=n)
+
+    fn = xr_function([x], out)
+    x_test = xr_arange_like(x)
+    res = fn(x_test)
+    if n == "implicit":
+        expected_res = x_test.diff(dim)
+    else:
+        expected_res = x_test.diff(dim, n=n)
+    xr_assert_allclose(res, expected_res)

From a76b15ef1c80cf7a610bcfe941d4f998691e50c6 Mon Sep 17 00:00:00 2001
From: Allen Downey 
Date: Wed, 28 May 2025 14:40:35 -0400
Subject: [PATCH 16/18] Format and simplify expand_ellipsis; auto-fix with
 pre-commit; update tests

---
 pytensor/xtensor/shape.py   | 75 +++++++++++++++++--------------------
 tests/xtensor/test_shape.py | 43 +++++++--------------
 2 files changed, 48 insertions(+), 70 deletions(-)

diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py
index 7dce658187..2c7f742a42 100644
--- a/pytensor/xtensor/shape.py
+++ b/pytensor/xtensor/shape.py
@@ -76,7 +76,10 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
 
 
 def expand_ellipsis(
-    dims: tuple[str, ...], all_dims: tuple[str, ...], validate: bool = True
+    dims: tuple[str, ...],
+    all_dims: tuple[str, ...],
+    validate: bool = True,
+    missing_dims: Literal["raise", "warn", "ignore"] = "raise",
 ) -> tuple[str, ...]:
     """Expand ellipsis in dimension permutation.
 
@@ -88,6 +91,11 @@
         All available dimensions
     validate : bool, default True
        Whether to check that all non-ellipsis elements in dims are valid dimension names.
+ missing_dims : {"raise", "warn", "ignore"}, optional + How to handle dimensions that don't exist in all_dims: + - "raise": Raise an error if any dimensions don't exist (default) + - "warn": Warn if any dimensions don't exist + - "ignore": Silently ignore any dimensions that don't exist Returns ------- @@ -99,38 +107,39 @@ def expand_ellipsis( ValueError If more than one ellipsis is present in dims. If any non-ellipsis element in dims is not a valid dimension name and validate is True. + If missing_dims is "raise" and any dimension in dims doesn't exist in all_dims. """ + # Handle empty or full ellipsis case if dims == () or dims == (...,): return tuple(reversed(all_dims)) - if ... not in dims: - if validate: - invalid_dims = set(dims) - set(all_dims) - if invalid_dims: + # Check for multiple ellipses + if dims.count(...) > 1: + raise ValueError("an index can only have a single ellipsis ('...')") + + # Validate dimensions if requested + if validate: + invalid_dims = set(dims) - {..., *all_dims} + if invalid_dims: + if missing_dims == "raise": raise ValueError( f"Invalid dimensions: {invalid_dims}. Available dimensions: {all_dims}" ) - return dims + elif missing_dims == "warn": + warnings.warn(f"Dimensions {invalid_dims} do not exist in {all_dims}") - if sum(d is ... for d in dims) > 1: - raise ValueError("an index can only have a single ellipsis ('...')") + # Handle missing dimensions if not raising + if missing_dims in ("ignore", "warn"): + dims = tuple(d for d in dims if d in all_dims or d is ...) - pre = [] - post = [] - found = False - for d in dims: - if d is ...: - found = True - elif not found: - pre.append(d) - else: - post.append(d) - if validate: - invalid_dims = set(pre + post) - set(all_dims) - if invalid_dims: - raise ValueError( - f"Invalid dimensions: {invalid_dims}. Available dimensions: {all_dims}" - ) + # If no ellipsis, just return the dimensions + if ... not in dims: + return dims + + # Handle ellipsis expansion + ellipsis_idx = dims.index(...) 
+ pre = list(dims[:ellipsis_idx]) + post = list(dims[ellipsis_idx + 1 :]) middle = [d for d in all_dims if d not in pre + post] return tuple(pre + middle + post) @@ -140,7 +149,7 @@ class Transpose(XOp): def __init__( self, - dims: tuple[str, ...], + dims: tuple[str | Literal[...], ...], missing_dims: Literal["raise", "warn", "ignore"] = "raise", ): super().__init__() @@ -150,23 +159,9 @@ def __init__( def make_node(self, x): x = as_xtensor(x) dims = expand_ellipsis( - self.dims, x.type.dims, validate=(self.missing_dims == "raise") + self.dims, x.type.dims, validate=True, missing_dims=self.missing_dims ) - # Handle missing dimensions based on missing_dims setting - if self.missing_dims in ("ignore", "warn"): - if self.missing_dims == "warn": - missing = set(dims) - set(x.type.dims) - if missing: - warnings.warn(f"Dimensions {missing} do not exist in {x.type.dims}") - # Filter out dimensions that don't exist and add remaining ones - dims = tuple(d for d in dims if d in x.type.dims) - remaining_dims = tuple(d for d in x.type.dims if d not in dims) - dims = dims + remaining_dims - else: # "raise" - if set(dims) != set(x.type.dims): - raise ValueError(f"Transpose dims {dims} must match {x.type.dims}") - output = xtensor( dtype=x.type.dtype, shape=tuple(x.type.shape[x.type.dims.index(d)] for d in dims), diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index fad488e01d..467b8e8e7d 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -7,12 +7,16 @@ from itertools import chain, combinations import numpy as np -from xarray import DataArray from xarray import concat as xr_concat from pytensor.xtensor.shape import concat, stack, transpose from pytensor.xtensor.type import xtensor -from tests.xtensor.util import xr_assert_allclose, xr_function, xr_random_like +from tests.xtensor.util import ( + xr_arange_like, + xr_assert_allclose, + xr_function, + xr_random_like, +) def powerset(iterable, min_group_size=0): @@ -40,10 +44,7 @@ def test_transpose(): outs = [transpose(x, *perm) for perm in permutations] fn = xr_function([x], outs) - x_test = DataArray( - np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), - dims=x.type.dims, - ) + x_test = xr_arange_like(x) res = fn(x_test) expected_res = [x_test.transpose(*perm) for perm in permutations] for outs_i, res_i, expected_res_i in zip(outs, res, expected_res): @@ -59,10 +60,7 @@ def test_stack(): ] fn = xr_function([x], outs) - x_test = DataArray( - np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), - dims=x.type.dims, - ) + x_test = xr_arange_like(x) res = fn(x_test) expected_res = [ @@ -79,10 +77,7 @@ def test_stack_single_dim(): assert out.type.dims == ("b", "c", "d") fn = xr_function([x], out) - x_test = DataArray( - np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), - dims=x.type.dims, - ) + x_test = xr_arange_like(x) fn.fn.dprint(print_type=True) res = fn(x_test) expected_res = x_test.stack(d=["a"]) @@ -94,10 +89,7 @@ def test_multiple_stacks(): out = stack(x, new_dim1=("a", "b"), new_dim2=("c", "d")) fn = xr_function([x], [out]) - x_test = DataArray( - np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), - dims=x.type.dims, - ) + x_test = xr_arange_like(x) res = fn(x_test) expected_res = x_test.stack(new_dim1=("a", "b"), new_dim2=("c", "d")) xr_assert_allclose(res[0], expected_res) @@ -170,10 +162,7 @@ def test_xtensor_variable_transpose(): # Test basic transpose out = x.transpose() fn = xr_function([x], out) - 
x_test = DataArray( - np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), - dims=x.type.dims, - ) + x_test = xr_arange_like(x) xr_assert_allclose(fn(x_test), x_test.transpose()) # Test transpose with specific dimensions @@ -216,10 +205,7 @@ def test_xtensor_variable_T(): out = x.T fn = xr_function([x], out) - x_test = DataArray( - np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), - dims=x.type.dims, - ) + x_test = xr_arange_like(x) xr_assert_allclose(fn(x_test), x_test.transpose()) # Test T property with 2D tensor @@ -227,8 +213,5 @@ def test_xtensor_variable_T(): out = x.T fn = xr_function([x], out) - x_test = DataArray( - np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), - dims=x.type.dims, - ) + x_test = xr_arange_like(x) xr_assert_allclose(fn(x_test), x_test.transpose()) From af14c90ec61fc98c85139bb319e7d1812fa900f8 Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Wed, 28 May 2025 15:08:24 -0400 Subject: [PATCH 17/18] Improve expand_dims: add tests, fix reshape usage, and ensure code style compliance --- pytensor/xtensor/rewriting/shape.py | 28 ++++++- pytensor/xtensor/shape.py | 116 ++++++++++++++++++++++++++++ tests/xtensor/test_shape.py | 35 ++++++++- 3 files changed, 174 insertions(+), 5 deletions(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index 03deb9a91c..a33401a5c4 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -2,7 +2,7 @@ from pytensor.tensor import broadcast_to, join, moveaxis from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.rewriting.basic import register_xcanonicalize -from pytensor.xtensor.shape import Concat, Stack, Transpose +from pytensor.xtensor.shape import Concat, ExpandDims, Squeeze, Stack, Transpose @register_xcanonicalize @@ -86,3 +86,29 @@ def lower_transpose(fgraph, node): x_tensor_transposed = x_tensor.transpose(perm) new_out = xtensor_from_tensor(x_tensor_transposed, dims=out_dims) return [new_out] + + +@register_xcanonicalize +@node_rewriter(tracks=[ExpandDims]) +def lower_expand_dims(fgraph, node): + [x] = node.inputs + x_tensor = tensor_from_xtensor(x) + x_tensor_expanded = x_tensor.reshape((*x_tensor.shape, 1)) + new_out = xtensor_from_tensor(x_tensor_expanded, dims=node.outputs[0].type.dims) + return [new_out] + + +@register_xcanonicalize +@node_rewriter(tracks=[Squeeze]) +def lower_squeeze(fgraph, node): + [x] = node.inputs + x_tensor = tensor_from_xtensor(x) + if node.op.dim is not None: + dim_idx = x.type.dims.index(node.op.dim) + x_tensor_squeezed = x_tensor.reshape( + tuple(s for i, s in enumerate(x_tensor.shape) if i != dim_idx) + ) + else: + x_tensor_squeezed = x_tensor.reshape(tuple(s for s in x_tensor.shape if s != 1)) + new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims) + return [new_out] diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 2c7f742a42..4ac371a6ca 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -249,3 +249,119 @@ def make_node(self, *inputs: Variable) -> Apply: def concat(xtensors, dim: str): return Concat(dim=dim)(*xtensors) + + +class ExpandDims(XOp): + """Add a new dimension to an XTensorVariable. + + Parameters + ---------- + dim : str or None + The name of the new dimension. If None, the dimension will be unnamed. 
+ """ + + def __init__(self, dim): + self.dim = dim + + def make_node(self, x): + x = as_xtensor(x) + + # Check if dimension already exists + if self.dim is not None and self.dim in x.type.dims: + raise ValueError(f"Dimension {self.dim} already exists") + + # Create new dimensions list with the new dimension + new_dims = list(x.type.dims) + new_dims.append(self.dim) + + # Create new shape with the new dimension + new_shape = list(x.type.shape) + new_shape.append(1) + + output = xtensor( + dtype=x.type.dtype, shape=tuple(new_shape), dims=tuple(new_dims) + ) + return Apply(self, [x], [output]) + + +def expand_dims(x, dim: str): + """Add a new dimension to an XTensorVariable. + + Parameters + ---------- + x : XTensorVariable + The input tensor + dim : str + The name of the new dimension + + Returns + ------- + XTensorVariable + A new tensor with the expanded dimension + """ + return ExpandDims(dim=dim)(x) + + +class Squeeze(XOp): + """Remove a dimension of size 1 from an XTensorVariable. + + Parameters + ---------- + dim : str or None + The name of the dimension to remove. If None, all dimensions of size 1 will be removed. + """ + + def __init__(self, dim=None): + self.dim = dim + + def make_node(self, x): + x = as_xtensor(x) + + # Get the index of the dimension to remove + if self.dim is not None: + if self.dim not in x.type.dims: + raise ValueError(f"Dimension {self.dim} not found") + dim_idx = x.type.dims.index(self.dim) + if x.type.shape[dim_idx] != 1: + raise ValueError( + f"Dimension {self.dim} has size {x.type.shape[dim_idx]}, not 1" + ) + else: + # Find all dimensions of size 1 + dim_idx = [i for i, s in enumerate(x.type.shape) if s == 1] + if not dim_idx: + raise ValueError("No dimensions of size 1 to remove") + + # Create new dimensions and shape lists + new_dims = list(x.type.dims) + new_shape = list(x.type.shape) + if self.dim is not None: + new_dims.pop(dim_idx) + new_shape.pop(dim_idx) + else: + # Remove all dimensions of size 1 + new_dims = [d for i, d in enumerate(new_dims) if i not in dim_idx] + new_shape = [s for i, s in enumerate(new_shape) if i not in dim_idx] + + output = xtensor( + dtype=x.type.dtype, shape=tuple(new_shape), dims=tuple(new_dims) + ) + return Apply(self, [x], [output]) + + +def squeeze(x, dim=None): + """Remove a dimension of size 1 from an XTensorVariable. + + Parameters + ---------- + x : XTensorVariable + The input tensor + dim : str or None, optional + The name of the dimension to remove. If None, all dimensions of size 1 will be removed. 
+ + Returns + ------- + XTensorVariable + A new tensor with the specified dimension removed + """ + return Squeeze(dim=dim)(x) diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index 467b8e8e7d..3b37cd347c 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -1,15 +1,14 @@ # ruff: noqa: E402 +import numpy as np import pytest +from xarray import concat as xr_concat pytest.importorskip("xarray") from itertools import chain, combinations -import numpy as np -from xarray import concat as xr_concat - -from pytensor.xtensor.shape import concat, stack, transpose +from pytensor.xtensor.shape import concat, expand_dims, stack, transpose from pytensor.xtensor.type import xtensor from tests.xtensor.util import ( xr_arange_like, @@ -215,3 +214,31 @@ def test_xtensor_variable_T(): fn = xr_function([x], out) x_test = xr_arange_like(x) xr_assert_allclose(fn(x_test), x_test.transpose()) + + +def test_expand_dims(): + # Test 1D tensor expansion + x = xtensor("x", dims=("city",), shape=(3,)) + y = expand_dims(x, "country") + assert y.type.dims == ("city", "country") + assert y.type.shape == (3, 1) + + # Test 2D tensor expansion + x2d = xtensor("x2d", dims=("row", "col"), shape=(2, 3)) + y2d = expand_dims(x2d, "batch") + assert y2d.type.dims == ("row", "col", "batch") + assert y2d.type.shape == (2, 3, 1) + + # Test expansion with different dimension name + z = expand_dims(x, "time") + assert z.type.dims == ("city", "time") + assert z.type.shape == (3, 1) + + # Test that expanding with an existing dimension raises an error + with pytest.raises(ValueError): + expand_dims(y, "city") + + # Test that expanding with None dimension works + z = expand_dims(x, None) + assert z.type.dims == ("city", None) + assert z.type.shape == (3, 1) From 15f4c485b541cf5ac32e886fd60232bb5d4a890d Mon Sep 17 00:00:00 2001 From: Allen Downey Date: Wed, 28 May 2025 15:46:24 -0400 Subject: [PATCH 18/18] Implement squeeze --- pytensor/xtensor/rewriting/shape.py | 9 +--- tests/xtensor/test_shape.py | 72 ++++++++++++++++++++++++++++- 2 files changed, 73 insertions(+), 8 deletions(-) diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index a33401a5c4..2e9c9b6f7f 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -103,12 +103,7 @@ def lower_expand_dims(fgraph, node): def lower_squeeze(fgraph, node): [x] = node.inputs x_tensor = tensor_from_xtensor(x) - if node.op.dim is not None: - dim_idx = x.type.dims.index(node.op.dim) - x_tensor_squeezed = x_tensor.reshape( - tuple(s for i, s in enumerate(x_tensor.shape) if i != dim_idx) - ) - else: - x_tensor_squeezed = x_tensor.reshape(tuple(s for s in x_tensor.shape if s != 1)) + expected_shape = tuple(node.outputs[0].type.shape) + x_tensor_squeezed = x_tensor.reshape(expected_shape) new_out = xtensor_from_tensor(x_tensor_squeezed, dims=node.outputs[0].type.dims) return [new_out] diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index 3b37cd347c..57f182650f 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -8,7 +8,7 @@ from itertools import chain, combinations -from pytensor.xtensor.shape import concat, expand_dims, stack, transpose +from pytensor.xtensor.shape import concat, expand_dims, squeeze, stack, transpose from pytensor.xtensor.type import xtensor from tests.xtensor.util import ( xr_arange_like, @@ -242,3 +242,73 @@ def test_expand_dims(): z = expand_dims(x, None) assert z.type.dims == ("city", None) assert z.type.shape == 
(3, 1) + + +def test_squeeze(): + # Test 1D tensor with no squeezable dimensions + x = xtensor("x", dims=("city",), shape=(3,)) + with pytest.raises(ValueError, match="No dimensions of size 1 to remove"): + squeeze(x) + + # Test 2D tensor with one squeezable dimension + x2d = xtensor("x2d", dims=("row", "col"), shape=(2, 1)) + y2d = squeeze(x2d) + assert y2d.type.dims == ("row",) + assert y2d.type.shape == (2,) + + # Test 3D tensor with multiple squeezable dimensions + x3d = xtensor("x3d", dims=("batch", "row", "col"), shape=(1, 2, 1)) + y3d = squeeze(x3d) + assert y3d.type.dims == ("row",) + assert y3d.type.shape == (2,) + + # Test squeezing specific dimension + x3d = xtensor("x3d", dims=("batch", "row", "col"), shape=(1, 2, 1)) + y3d = squeeze(x3d, dim="batch") + assert y3d.type.dims == ("row", "col") + assert y3d.type.shape == (2, 1) + + # Test squeezing non-existent dimension + with pytest.raises(ValueError, match="Dimension time not found"): + squeeze(x3d, dim="time") + + # Test squeezing dimension with size > 1 + x3d = xtensor("x3d", dims=("batch", "row", "col"), shape=(2, 2, 1)) + with pytest.raises(ValueError, match="Dimension batch has size 2, not 1"): + squeeze(x3d, dim="batch") + + # Test functional interface + fn = xr_function([x2d], y2d) + x_test = xr_arange_like(x2d) + res = fn(x_test) + expected_res = x_test.squeeze() + xr_assert_allclose(res, expected_res) + + # Test squeezing a tensor with multiple squeezable dimensions + x_multi = xtensor("x_multi", dims=("batch", "row", "col"), shape=(1, 2, 1)) + y_multi = squeeze(x_multi) + assert y_multi.type.dims == ("row",) + assert y_multi.type.shape == (2,) + fn_multi = xr_function([x_multi], y_multi) + x_multi_test = xr_arange_like(x_multi) + res_multi = fn_multi(x_multi_test) + expected_res_multi = x_multi_test.squeeze() + xr_assert_allclose(res_multi, expected_res_multi) + + +def test_lower_squeeze(): + from pytensor.xtensor.rewriting.shape import lower_squeeze + from pytensor.xtensor.shape import squeeze + from pytensor.xtensor.type import xtensor + + # Create a tensor with a squeezable dimension + x = xtensor("x", dims=("row", "col"), shape=(2, 1)) + y = squeeze(x) + + class DummyFGraph: + pass + + node = type("Node", (), {"inputs": [x], "op": y.owner.op, "outputs": [y]})() + [out] = lower_squeeze.transform(DummyFGraph(), node) + assert out.type.dims == ("row",) + assert out.type.shape == (2,)
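
Notes (reviewer commentary, not part of any patch above):

The lower_index rewrite in patch 14 reproduces xarray's dimension-aware advanced indexing with plain positional tensor indexing. The snippet below is pure numpy/xarray (no PyTensor involved), following the xarray docs linked at the top of indexing.py; it only illustrates the two behaviors the rewrite must emulate by aligning indices on dimension names before calling tensor __getitem__:

import numpy as np
from xarray import DataArray

x = DataArray(np.arange(6).reshape(2, 3), dims=("a", "b"))
ia = DataArray(np.array([0, 1, 1]), dims=("z",))
ib = DataArray(np.array([2, 0, 1]), dims=("z",))

# Indices sharing a dim are zipped pointwise (vectorized indexing)
assert x[ia, ib].dims == ("z",)
# Indices with distinct dims combine orthogonally (outer indexing)
assert x[ia, ib.rename(z="w")].dims == ("z", "w")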
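The static shape inference in get_static_slice_length leans on a standard-library identity: slice.indices clamps (start, stop, step) to a concrete dimension length, and len(range(...)) then counts the selected elements. For example:

assert len(range(*slice(1, None, 2).indices(7))) == 3    # picks elements 1, 3, 5
assert len(range(*slice(None, None, -1).indices(5))) == 5  # full reversal keeps all 5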
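For review convenience, a minimal end-to-end sketch of the user-facing API added by patches 14-15. This is untested and assumes this branch plus the helpers in tests/xtensor/util.py (xr_function, xr_arange_like, xr_assert_allclose), which are test utilities rather than public API:

from pytensor.xtensor import xtensor
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function

x = xtensor("x", dims=("a", "b"), shape=(3, 5))
# dict indexing routes through isel; head and diff are built on top of it
out = x[{"a": slice(1, None)}].head(b=2).diff("b")

fn = xr_function([x], out)
x_test = xr_arange_like(x)
xr_assert_allclose(fn(x_test), x_test.isel(a=slice(1, None)).head(b=2).diff("b"))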
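Likewise, a small sketch of the static dims/shape bookkeeping performed by the new ExpandDims and Squeeze ops from patches 17-18, mirroring the assertions in test_shape.py (again untested, based only on the make_node implementations above):

from pytensor.xtensor.shape import expand_dims, squeeze
from pytensor.xtensor.type import xtensor

x = xtensor("x", dims=("city",), shape=(3,))

y = expand_dims(x, "country")  # the new size-1 dim is appended last
assert y.type.dims == ("city", "country")
assert y.type.shape == (3, 1)

z = squeeze(y, dim="country")  # drops the size-1 dim again
assert z.type.dims == ("city",)
assert z.type.shape == (3,)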