Skip to content

Update test_shape.py #1477

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: labeled_tensors
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ jobs:
install-numba: [0]
install-jax: [0]
install-torch: [0]
install-xarray: [0]
part:
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
- "tests/scan"
Expand Down Expand Up @@ -115,6 +116,7 @@ jobs:
install-numba: 0
install-jax: 0
install-torch: 0
install-xarray: 0
- install-numba: 1
os: "ubuntu-latest"
python-version: "3.10"
Expand Down Expand Up @@ -150,6 +152,13 @@ jobs:
fast-compile: 0
float32: 0
part: "tests/link/pytorch"
- install-xarray: 1
os: "ubuntu-latest"
python-version: "3.13"
numpy-version: ">=2.0"
fast-compile: 0
float32: 0
part: "tests/xtensor"
- os: macos-15
python-version: "3.13"
numpy-version: ">=2.0"
Expand Down Expand Up @@ -196,6 +205,7 @@ jobs:
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi
pip install pytest-sphinx

pip install -e ./
Expand All @@ -212,6 +222,7 @@ jobs:
INSTALL_NUMBA: ${{ matrix.install-numba }}
INSTALL_JAX: ${{ matrix.install-jax }}
INSTALL_TORCH: ${{ matrix.install-torch}}
INSTALL_XARRAY: ${{ matrix.install-xarray }}
OS: ${{ matrix.os}}

- name: Run tests
Expand Down
2 changes: 1 addition & 1 deletion pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4551,7 +4551,7 @@ def ix_(*args):
new = as_tensor(new)
if new.ndim != 1:
raise ValueError("Cross index must be 1 dimensional")
new = new.reshape((1,) * k + (new.size,) + (1,) * (nd - k - 1))
new = new.dimshuffle(*(("x",) * k), 0, *(("x",) * (nd - k - 1)))
out.append(new)
return tuple(out)

Expand Down
18 changes: 0 additions & 18 deletions pytensor/tensor/extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,24 +473,6 @@ def cumprod(x, axis=None):
return CumOp(axis=axis, mode="mul")(x)


class CumsumOp(Op):
__props__ = ("axis",)

def __new__(typ, *args, **kwargs):
obj = object.__new__(CumOp, *args, **kwargs)
obj.mode = "add"
return obj


class CumprodOp(Op):
__props__ = ("axis",)

def __new__(typ, *args, **kwargs):
obj = object.__new__(CumOp, *args, **kwargs)
obj.mode = "mul"
return obj


def diff(x, n=1, axis=-1):
"""Calculate the `n`-th order discrete difference along the given `axis`.

Expand Down
7 changes: 1 addition & 6 deletions pytensor/tensor/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3021,12 +3021,7 @@ def make_node(self, x, y, *inputs):
return Apply(
self,
(x, y, *new_inputs),
[
tensor(
dtype=x.type.dtype,
shape=tuple(1 if s == 1 else None for s in x.type.shape),
)
],
[x.type()],
)

def perform(self, node, inputs, out_):
Expand Down
16 changes: 16 additions & 0 deletions pytensor/xtensor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import warnings

import pytensor.xtensor.rewriting
from pytensor.xtensor import (
linalg,
special,
)
from pytensor.xtensor.shape import concat
from pytensor.xtensor.type import (
as_xtensor,
xtensor,
xtensor_constant,
)


warnings.warn("xtensor module is experimental and full of bugs")
104 changes: 104 additions & 0 deletions pytensor/xtensor/basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from collections.abc import Sequence

from pytensor.compile import ViewOp
from pytensor.graph import Apply, Op
from pytensor.link.c.op import COp
from pytensor.tensor.type import TensorType
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor


class XOp(Op):
"""A base class for XOps that shouldn't be materialized"""

def perform(self, node, inputs, outputs):
raise NotImplementedError(
f"xtensor operation {self} must be lowered to equivalent tensor operations"
)


class XTypeCastOp(COp):
"""Base class for Ops that type cast between TensorType and XTensorType.

This is like a `ViewOp` but without the expectation the input and output have identical types.
"""

view_map = {0: [0]}

def perform(self, node, inputs, output_storage):
output_storage[0][0] = inputs[0]

def c_code(self, node, nodename, inp, out, sub):
(iname,) = inp
(oname,) = out
fail = sub["fail"]

code, _ = ViewOp.c_code_and_version[TensorType]
return code % locals()

def c_code_cache_version(self):
_, version = ViewOp.c_code_and_version[TensorType]
return (version,)


class TensorFromXTensor(XTypeCastOp):
__props__ = ()

def make_node(self, x):
if not isinstance(x.type, XTensorType):
raise TypeError(f"x must be have an XTensorType, got {type(x.type)}")
output = TensorType(x.type.dtype, shape=x.type.shape)()
return Apply(self, [x], [output])


tensor_from_xtensor = TensorFromXTensor()


class XTensorFromTensor(XTypeCastOp):
__props__ = ("dims",)

def __init__(self, dims: Sequence[str]):
super().__init__()
self.dims = tuple(dims)

def make_node(self, x):
if not isinstance(x.type, TensorType):
raise TypeError(f"x must be an TensorType type, got {type(x.type)}")
output = xtensor(dtype=x.type.dtype, dims=self.dims, shape=x.type.shape)
return Apply(self, [x], [output])


def xtensor_from_tensor(x, dims):
return XTensorFromTensor(dims=dims)(x)


class Rename(XTypeCastOp):
__props__ = ("new_dims",)

def __init__(self, new_dims: tuple[str, ...]):
super().__init__()
self.new_dims = new_dims

def make_node(self, x):
x = as_xtensor(x)
output = x.type.clone(dims=self.new_dims)()
return Apply(self, [x], [output])


def rename(x, name_dict: dict[str, str] | None = None, **names: str):
if name_dict is not None:
if names:
raise ValueError("Cannot use both positional and keyword names in rename")
names = name_dict

x = as_xtensor(x)
old_names = x.type.dims
new_names = list(old_names)
for old_name, new_name in names.items():
try:
new_names[old_names.index(old_name)] = new_name
except IndexError:
raise ValueError(
f"Cannot rename {old_name} to {new_name}: {old_name} not in {old_names}"
)

return Rename(tuple(new_names))(x)
Loading