diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index f36f8888ba..b81439f960 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1982,6 +1982,62 @@ def transpose(x, axes=None): return ret +def matrix_transpose(x: "TensorLike") -> TensorVariable: + """ + Transposes each 2-dimensional matrix tensor along the last two dimensions of a higher-dimensional tensor. + + Parameters + ---------- + x : array_like + Input tensor with shape (..., M, N), where `M` and `N` represent the dimensions + of the matrices. Each matrix is of shape (M, N). + + Returns + ------- + out : tensor + Transposed tensor with the shape (..., N, M), where each 2-dimensional matrix + in the input tensor has been transposed along the last two dimensions. + + Examples + -------- + >>> import pytensor as pt + >>> import numpy as np + >>> x = np.arange(24).reshape((2, 3, 4)) + [[[ 0 1 2 3] + [ 4 5 6 7] + [ 8 9 10 11]] + + [[12 13 14 15] + [16 17 18 19] + [20 21 22 23]]] + + + >>> pt.matrix_transpose(x).eval() + [[[ 0 4 8] + [ 1 5 9] + [ 2 6 10] + [ 3 7 11]] + + [[12 16 20] + [13 17 21] + [14 18 22] + [15 19 23]]] + + + Notes + ----- + This function transposes each 2-dimensional matrix within the input tensor along + the last two dimensions. If the input tensor has more than two dimensions, it + transposes each 2-dimensional matrix independently while preserving other dimensions. + """ + x = as_tensor_variable(x) + if x.ndim < 2: + raise ValueError( + f"Input array must be at least 2-dimensional, but it is {x.ndim}" + ) + return swapaxes(x, -1, -2) + + def split(x, splits_size, n_splits, axis=0): the_split = Split(n_splits) return the_split(x, axis, splits_size) @@ -4302,6 +4358,7 @@ def ix_(*args): "join", "split", "transpose", + "matrix_transpose", "extract_constant", "default", "tensor_copy", diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 717a7af884..ea83d9356a 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -2,7 +2,7 @@ from typing import cast from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter -from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes +from pytensor.tensor.basic import TensorVariable, diagonal from pytensor.tensor.blas import Dot22 from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle @@ -43,11 +43,6 @@ def is_matrix_transpose(x: TensorVariable) -> bool: return False -def _T(x: TensorVariable) -> TensorVariable: - """Matrix transpose for potentially higher dimensionality tensors""" - return swapaxes(x, -1, -2) - - @register_canonicalize @node_rewriter([DimShuffle]) def transinv_to_invtrans(fgraph, node): @@ -83,9 +78,9 @@ def inv_as_solve(fgraph, node): ): x = r.owner.inputs[0] if getattr(x.tag, "symmetric", None) is True: - return [_T(solve(x, _T(l)))] + return [solve(x, (l.mT)).mT] else: - return [_T(solve(_T(x), _T(l)))] + return [solve((x.mT), (l.mT)).mT] @register_stabilize @@ -216,7 +211,7 @@ def psd_solve_with_chol(fgraph, node): # __if__ no other Op makes use of the L matrix during the # stabilization Li_b = solve_triangular(L, b, lower=True, b_ndim=2) - x = solve_triangular(_T(L), Li_b, lower=False, b_ndim=2) + x = solve_triangular((L.mT), Li_b, lower=False, b_ndim=2) return [x] diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index c1dc3d2de3..6100108380 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -232,6 +232,10 @@ def __trunc__(self): def T(self): return pt.basic.transpose(self) + @property + def mT(self): + return pt.basic.matrix_transpose(self) + def transpose(self, *axes): """Transpose this array. diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index d96dc3fd0c..0f161760bd 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -3813,6 +3813,7 @@ def test_transpose(): ) t1, t2, t3, t1b, t2b, t3b, t2c, t3c, t2d, t3d = f(x1v, x2v, x3v) + assert t1.shape == np.transpose(x1v).shape assert t2.shape == np.transpose(x2v).shape assert t3.shape == np.transpose(x3v).shape @@ -3838,6 +3839,23 @@ def test_transpose(): assert ptb.transpose(dmatrix()).name is None +def test_matrix_transpose(): + with pytest.raises(ValueError, match="Input array must be at least 2-dimensional"): + ptb.matrix_transpose(dvector("x1")) + + x2 = dmatrix("x2") + x3 = dtensor3("x3") + + var1 = ptb.matrix_transpose(x2) + expected_var1 = swapaxes(x2, -1, -2) + + var2 = x3.mT + expected_var2 = swapaxes(x3, -1, -2) + + assert equal_computations([var1], [expected_var1]) + assert equal_computations([var2], [expected_var2]) + + def test_stacklists(): a, b, c, d = map(scalar, "abcd") X = stacklists([[a, b], [c, d]])