
Commit f846621

Refactor XElemwise and XBlockwise
1 parent eefd563 commit f846621

4 files changed, +67 -78 lines changed

pytensor/xtensor/linalg.py

Lines changed: 5 additions & 9 deletions
@@ -28,7 +28,7 @@ def cholesky(
         ((dims[0], dims[1]),),
         ((dims[0], dims[1]),),
     )
-    x_op = XBlockwise(core_op, signature=core_op.gufunc_signature, core_dims=core_dims)
+    x_op = XBlockwise(core_op, core_dims=core_dims)
     return x_op(x)


@@ -48,18 +48,15 @@ def solve(
         [m1_dim] = [dim for dim in dims if dim not in b.type.dims]
         m2_dim = dims[0] if dims[0] != m1_dim else dims[1]
         input_core_dims = ((m1_dim, m2_dim), (m2_dim,))
-        output_core_dims = ((m2_dim,),)
+        # The shared dim disappears in the output
+        output_core_dims = ((m1_dim,),)
     elif len(dims) == 3:
         b_ndim = 2
         [n_dim] = [dim for dim in dims if dim not in a.type.dims]
         [m1_dim, m2_dim] = [dim for dim in dims if dim != n_dim]
         input_core_dims = ((m1_dim, m2_dim), (m2_dim, n_dim))
-        output_core_dims = (
-            (
-                m2_dim,
-                n_dim,
-            ),
-        )
+        # The shared dim disappears in the output
+        output_core_dims = ((m1_dim, n_dim),)
     else:
         raise ValueError("Solve dims must have length 2 or 3")

@@ -68,7 +65,6 @@ def solve(
     )
     x_op = XBlockwise(
         core_op,
-        signature=core_op.gufunc_signature,
         core_dims=(input_core_dims, output_core_dims),
     )
     return x_op(a, b)
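
After this refactor, callers build an XBlockwise from core_dims alone; a gufunc signature is optional and only consulted when the Op is lowered. Below is a hedged sketch of the new call pattern, mirroring the cholesky case above (the Cholesky core Op and its lower= keyword come from PyTensor's standard linalg module and are assumed here, not shown in this diff):

from pytensor.tensor.slinalg import Cholesky
from pytensor.xtensor.vectorization import XBlockwise

# One matrix input and one matrix output, both with core dims ("b", "a"),
# as in the cholesky() helper above.
core_op = Cholesky(lower=True)
core_dims = ((("b", "a"),), (("b", "a"),))
x_op = XBlockwise(core_op, core_dims=core_dims)  # no signature argument needed anymore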

pytensor/xtensor/rewriting/vectorization.py

Lines changed: 14 additions & 2 deletions
@@ -37,7 +37,7 @@ def lower_elemwise(fgraph, node):
 @node_rewriter(tracks=[XBlockwise])
 def lower_blockwise(fgraph, node):
     op: XBlockwise = node.op
-    batch_ndim = node.outputs[0].type.ndim - len(op.outputs_sig[0])
+    batch_ndim = node.outputs[0].type.ndim - len(op.core_dims[1][0])
     batch_dims = node.outputs[0].type.dims[:batch_ndim]

     # Convert input Tensors to XTensors, align batch dimensions and place core dimension at the end
@@ -53,7 +53,19 @@ def lower_blockwise(fgraph, node):
         tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
         tensor_inputs.append(tensor_inp)

-    tensor_op = Blockwise(core_op=node.op.core_op, signature=op.signature)
+    signature = op.signature or getattr(op.core_op, "gufunc_signature", None)
+    if signature is None:
+        # Build a signature based on the core dimensions
+        # The Op signature could be more strict, as core_dims will never be repeated, but no functionality depends greatly on it
+        inputs_core_dims, outputs_core_dims = op.core_dims
+        inputs_signature = ",".join(
+            f"({', '.join(inp_core_dims)})" for inp_core_dims in inputs_core_dims
+        )
+        outputs_signature = ",".join(
+            f"({', '.join(out_core_dims)})" for out_core_dims in outputs_core_dims
+        )
+        signature = f"{inputs_signature}->{outputs_signature}"
+    tensor_op = Blockwise(core_op=op.core_op, signature=signature)
     tensor_outs = tensor_op(*tensor_inputs, return_list=True)

     # Convert output Tensors to XTensors
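
The fallback added to lower_blockwise derives a gufunc signature string straight from the op's core_dims when neither op.signature nor core_op.gufunc_signature is available. A standalone sketch of what that construction yields, using illustrative dim names from the solve case (not part of the commit):

# Reproduces the signature-building fallback from lower_blockwise above.
inputs_core_dims = (("m1", "m2"), ("m2",))
outputs_core_dims = (("m1",),)

inputs_signature = ",".join(f"({', '.join(dims)})" for dims in inputs_core_dims)
outputs_signature = ",".join(f"({', '.join(dims)})" for dims in outputs_core_dims)
signature = f"{inputs_signature}->{outputs_signature}"
assert signature == "(m1, m2),(m2)->(m1)"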

pytensor/xtensor/vectorization.py

Lines changed: 28 additions & 40 deletions
@@ -3,11 +3,27 @@
 from pytensor import scalar as ps
 from pytensor.graph import Apply, Op
 from pytensor.tensor import tensor
-from pytensor.tensor.utils import _parse_gufunc_signature
 from pytensor.xtensor.basic import XOp
 from pytensor.xtensor.type import as_xtensor, xtensor


+def combine_dims_and_shape(inputs):
+    dims_and_shape: dict[str, int | None] = {}
+    for inp in inputs:
+        for dim, dim_length in zip(inp.type.dims, inp.type.shape):
+            if dim not in dims_and_shape:
+                dims_and_shape[dim] = dim_length
+            elif dim_length is not None:
+                # Check for conflicting shapes
+                if (dims_and_shape[dim] is not None) and (
+                    dims_and_shape[dim] != dim_length
+                ):
+                    raise ValueError(f"Dimension {dim} has conflicting shapes")
+                # Keep the non-None shape
+                dims_and_shape[dim] = dim_length
+    return dims_and_shape
+
+
 class XElemwise(XOp):
     __props__ = ("scalar_op",)

@@ -22,20 +38,7 @@ def make_node(self, *inputs):
                 f"Wrong number of inputs, expected {self.scalar_op.nin}, got {len(inputs)}"
             )

-        dims_and_shape: dict[str, int | None] = {}
-        for inp in inputs:
-            for dim, dim_length in zip(inp.type.dims, inp.type.shape):
-                if dim not in dims_and_shape:
-                    dims_and_shape[dim] = dim_length
-                elif dim_length is not None:
-                    # Check for conflicting shapes
-                    if (dims_and_shape[dim] is not None) and (
-                        dims_and_shape[dim] != dim_length
-                    ):
-                        raise ValueError(f"Dimension {dim} has conflicting shapes")
-                    # Keep the non-None shape
-                    dims_and_shape[dim] = dim_length
-
+        dims_and_shape = combine_dims_and_shape(inputs)
         if dims_and_shape:
             output_dims, output_shape = zip(*dims_and_shape.items())
         else:
@@ -53,48 +56,33 @@ def make_node(self, *inputs):


 class XBlockwise(XOp):
-    __props__ = ("core_op", "signature", "core_dims")
+    __props__ = ("core_op", "core_dims")

     def __init__(
         self,
         core_op: Op,
-        signature: str,
         core_dims: tuple[tuple[tuple[str, ...], ...], tuple[tuple[str, ...], ...]],
+        signature: str | None = None,
     ):
         super().__init__()
         self.core_op = core_op
-        self.signature = signature
-        self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
         self.core_dims = core_dims
+        self.signature = signature  # Only used for lowering, not for validation

     def make_node(self, *inputs):
         inputs = [as_xtensor(i) for i in inputs]
-        if len(inputs) != len(self.inputs_sig):
+        if len(inputs) != len(self.core_dims[0]):
             raise ValueError(
-                f"Wrong number of inputs, expected {len(self.inputs_sig)}, got {len(inputs)}"
+                f"Wrong number of inputs, expected {len(self.core_dims[0])}, got {len(inputs)}"
             )

-        dims_and_shape: dict[str, int | None] = {}
-        for inp in inputs:
-            for dim, dim_length in zip(inp.type.dims, inp.type.shape):
-                if dim not in dims_and_shape:
-                    dims_and_shape[dim] = dim_length
-                elif dim_length is not None:
-                    # Check for conflicting shapes
-                    if (dims_and_shape[dim] is not None) and (
-                        dims_and_shape[dim] != dim_length
-                    ):
-                        raise ValueError(f"Dimension {dim} has conflicting shapes")
-                    # Keep the non-None shape
-                    dims_and_shape[dim] = dim_length
+        dims_and_shape = combine_dims_and_shape(inputs)

         core_inputs_dims, core_outputs_dims = self.core_dims
-        # TODO: Avoid intermediate dict
-        core_dims = set(chain.from_iterable(core_inputs_dims))
-        batched_dims_and_shape = {
-            k: v for k, v in dims_and_shape.items() if k not in core_dims
-        }
-        batch_dims, batch_shape = zip(*batched_dims_and_shape.items())
+        core_input_dims_set = set(chain.from_iterable(core_inputs_dims))
+        batch_dims, batch_shape = zip(
+            *((k, v) for k, v in dims_and_shape.items() if k not in core_input_dims_set)
+        )

         dummy_core_inputs = []
         for inp, core_inp_dims in zip(inputs, core_inputs_dims):
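
The new combine_dims_and_shape helper merges the named dims of all inputs into one ordered mapping, filling unknown (None) lengths with known ones and raising on conflicts. A quick usage sketch (the variable names are illustrative; the helper and the xtensor constructor are the ones appearing in this commit):

from pytensor.xtensor.type import xtensor
from pytensor.xtensor.vectorization import combine_dims_and_shape

x = xtensor("x", dims=("a", "b"), shape=(2, None))
y = xtensor("y", dims=("b", "c"), shape=(3, 4))

# Known lengths override None; if "b" had two different known lengths,
# a ValueError("Dimension b has conflicting shapes") would be raised instead.
assert combine_dims_and_shape([x, y]) == {"a": 2, "b": 3, "c": 4}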

tests/xtensor/test_linalg.py

Lines changed: 20 additions & 27 deletions
@@ -1,5 +1,6 @@
 # ruff: noqa: E402
 import pytest
+from xtensor.util import xr_assert_allclose, xr_function


 pytest.importorskip("xarray")
@@ -14,7 +15,6 @@
     solve as xr_solve,
 )

-from pytensor import function
 from pytensor.xtensor.linalg import cholesky, solve
 from pytensor.xtensor.type import xtensor

@@ -25,59 +25,52 @@ def test_cholesky():
     assert y.type.dims == ("batch", "b", "a")
     assert y.type.shape == (3, 4, 4)

-    fn = function([x], y)
+    fn = xr_function([x], y)
     rng = np.random.default_rng(25)
-    x_ = rng.random(size=(4, 3, 3))
+    x_ = rng.random(size=(3, 4, 4))
     x_ = x_ @ x_.mT
     x_test = DataArray(x_.transpose(1, 0, 2), dims=x.type.dims)
-    np.testing.assert_allclose(
-        fn(x_test.values),
-        xr_cholesky(x_test, dims=["b", "a"]).values,
+    xr_assert_allclose(
+        fn(x_test),
+        xr_cholesky(x_test, dims=["b", "a"]),
     )


 def test_solve_vector_b():
     a = xtensor("a", dims=("city", "country", "galaxy"), shape=(None, 4, 1))
     b = xtensor("b", dims=("city", "planet"), shape=(None, 2))
     x = solve(a, b, dims=["country", "city"])
-    assert x.type.dims == ("galaxy", "planet", "city")
-    assert x.type.shape == (
-        1,
-        2,
-        None,
-    )  # Core Solve doesn't make use of the fact A must be square in the static shape
+    assert x.type.dims == ("galaxy", "planet", "country")
+    # Core Solve doesn't make use of the fact A must be square in the static shape
+    assert x.type.shape == (1, 2, None)

-    fn = function([a, b], x)
+    fn = xr_function([a, b], x)

     rng = np.random.default_rng(25)
     a_test = DataArray(rng.random(size=(4, 4, 1)), dims=a.type.dims)
     b_test = DataArray(rng.random(size=(4, 2)), dims=b.type.dims)

-    np.testing.assert_allclose(
-        fn(a_test.values, b_test.values),
-        xr_solve(a_test, b_test, dims=["country", "city"]).values,
+    xr_assert_allclose(
+        fn(a_test, b_test),
+        xr_solve(a_test, b_test, dims=["country", "city"]),
     )


 def test_solve_matrix_b():
     a = xtensor("a", dims=("city", "country", "galaxy"), shape=(None, 4, 1))
     b = xtensor("b", dims=("district", "city", "planet"), shape=(5, None, 2))
     x = solve(a, b, dims=["country", "city", "district"])
-    assert x.type.dims == ("galaxy", "planet", "city", "district")
-    assert x.type.shape == (
-        1,
-        2,
-        None,
-        5,
-    )  # Core Solve doesn't make use of the fact A must be square in the static shape
+    assert x.type.dims == ("galaxy", "planet", "country", "district")
+    # Core Solve doesn't make use of the fact A must be square in the static shape
+    assert x.type.shape == (1, 2, None, 5)

-    fn = function([a, b], x)
+    fn = xr_function([a, b], x)

     rng = np.random.default_rng(25)
     a_test = DataArray(rng.random(size=(4, 4, 1)), dims=a.type.dims)
     b_test = DataArray(rng.random(size=(5, 4, 2)), dims=b.type.dims)

-    np.testing.assert_allclose(
-        fn(a_test.values, b_test.values),
-        xr_solve(a_test, b_test, dims=["country", "city", "district"]).values,
+    xr_assert_allclose(
+        fn(a_test, b_test),
+        xr_solve(a_test, b_test, dims=["country", "city", "district"]),
     )