Skip to content

Commit 193a8b1

Browse files
Merge pull request #1174 from IntelPython/fix_moveaxis_error
Fix implementation of dpctl.tensor.moveaxis()
2 parents b050e8c + 15e5e73 commit 193a8b1

File tree

2 files changed

+104
-35
lines changed

2 files changed

+104
-35
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -785,21 +785,21 @@ def unstack(X, axis=0):
785785
return tuple(Y[i] for i in range(Y.shape[0]))
786786

787787

788-
def moveaxis(X, src, dst):
789-
"""moveaxis(x, src, dst)
788+
def moveaxis(X, source, destination):
789+
"""moveaxis(x, source, destination)
790790
791791
Moves axes of an array to new positions.
792792
793793
Args:
794794
x (usm_ndarray): input array
795795
796-
src (int or a sequence of int):
796+
source (int or a sequence of int):
797797
Original positions of the axes to move.
798798
These must be unique. If `x` has rank (i.e., number of
799799
dimensions) `N`, a valid `axis` must be in the
800800
half-open interval `[-N, N)`.
801801
802-
dst (int or a sequence of int):
802+
destination (int or a sequence of int):
803803
Destination positions for each of the original axes.
804804
These must also be unique. If `x` has rank
805805
(i.e., number of dimensions) `N`, a valid `axis` must be
@@ -814,22 +814,30 @@ def moveaxis(X, src, dst):
814814
815815
Raises:
816816
AxisError: if `axis` value is invalid.
817+
ValueError: if `src` and `dst` have not equal number of elements.
817818
"""
818819
if not isinstance(X, dpt.usm_ndarray):
819820
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
820821

821-
if not isinstance(src, (tuple, list)):
822-
src = (src,)
822+
if not isinstance(source, (tuple, list)):
823+
source = (source,)
823824

824-
if not isinstance(dst, (tuple, list)):
825-
dst = (dst,)
825+
if not isinstance(destination, (tuple, list)):
826+
destination = (destination,)
826827

827-
src = normalize_axis_tuple(src, X.ndim, "src")
828-
dst = normalize_axis_tuple(dst, X.ndim, "dst")
829-
ind = list(range(0, X.ndim))
830-
for i in range(len(src)):
831-
ind.remove(src[i]) # using the value here which is the same as index
832-
ind.insert(dst[i], src[i])
828+
source = normalize_axis_tuple(source, X.ndim, "source")
829+
destination = normalize_axis_tuple(destination, X.ndim, "destination")
830+
831+
if len(source) != len(destination):
832+
raise ValueError(
833+
"`source` and `destination` arguments must have "
834+
"the same number of elements"
835+
)
836+
837+
ind = [n for n in range(X.ndim) if n not in source]
838+
839+
for src, dst in sorted(zip(destination, source)):
840+
ind.insert(src, dst)
833841

834842
return dpt.permute_dims(X, tuple(ind))
835843

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 82 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import numpy as np
1919
import pytest
20-
from numpy.testing import assert_array_equal
20+
from numpy.testing import assert_, assert_array_equal, assert_raises_regex
2121

2222
import dpctl
2323
import dpctl.tensor as dpt
@@ -1068,34 +1068,95 @@ def test_swapaxes_2d():
10681068
assert_array_equal(exp, dpt.asnumpy(res))
10691069

10701070

1071-
def test_moveaxis_1axis():
1072-
x = np.arange(60).reshape((3, 4, 5))
1073-
exp = np.moveaxis(x, 0, -1)
1074-
1075-
y = dpt.reshape(dpt.arange(60), (3, 4, 5))
1076-
res = dpt.moveaxis(y, 0, -1)
1077-
1078-
assert_array_equal(exp, dpt.asnumpy(res))
1071+
@pytest.mark.parametrize(
1072+
"source, expected",
1073+
[
1074+
(0, (6, 7, 5)),
1075+
(1, (5, 7, 6)),
1076+
(2, (5, 6, 7)),
1077+
(-1, (5, 6, 7)),
1078+
],
1079+
)
1080+
def test_moveaxis_move_to_end(source, expected):
1081+
x = dpt.reshape(dpt.arange(5 * 6 * 7), (5, 6, 7))
1082+
actual = dpt.moveaxis(x, source, -1).shape
1083+
assert_(actual, expected)
10791084

10801085

1081-
def test_moveaxis_2axes():
1082-
x = np.arange(60).reshape((3, 4, 5))
1083-
exp = np.moveaxis(x, [0, 1], [-1, -2])
1086+
@pytest.mark.parametrize(
1087+
"source, destination, expected",
1088+
[
1089+
(0, 1, (2, 1, 3, 4)),
1090+
(1, 2, (1, 3, 2, 4)),
1091+
(1, -1, (1, 3, 4, 2)),
1092+
],
1093+
)
1094+
def test_moveaxis_new_position(source, destination, expected):
1095+
x = dpt.reshape(dpt.arange(24), (1, 2, 3, 4))
1096+
actual = dpt.moveaxis(x, source, destination).shape
1097+
assert_(actual, expected)
10841098

1085-
y = dpt.reshape(dpt.arange(60), (3, 4, 5))
1086-
res = dpt.moveaxis(y, [0, 1], [-1, -2])
10871099

1088-
assert_array_equal(exp, dpt.asnumpy(res))
1100+
@pytest.mark.parametrize(
1101+
"source, destination",
1102+
[
1103+
(0, 0),
1104+
(3, -1),
1105+
(-1, 3),
1106+
([0, -1], [0, -1]),
1107+
([2, 0], [2, 0]),
1108+
],
1109+
)
1110+
def test_moveaxis_preserve_order(source, destination):
1111+
x = dpt.zeros((1, 2, 3, 4))
1112+
actual = dpt.moveaxis(x, source, destination).shape
1113+
assert_(actual, (1, 2, 3, 4))
10891114

10901115

1091-
def test_moveaxis_3axes():
1092-
x = np.arange(60).reshape((3, 4, 5))
1093-
exp = np.moveaxis(x, [0, 1, 2], [-1, -2, -3])
1116+
@pytest.mark.parametrize(
1117+
"source, destination, expected",
1118+
[
1119+
([0, 1], [2, 3], (2, 3, 0, 1)),
1120+
([2, 3], [0, 1], (2, 3, 0, 1)),
1121+
([0, 1, 2], [2, 3, 0], (2, 3, 0, 1)),
1122+
([3, 0], [1, 0], (0, 3, 1, 2)),
1123+
([0, 3], [0, 1], (0, 3, 1, 2)),
1124+
],
1125+
)
1126+
def test_moveaxis_move_multiples(source, destination, expected):
1127+
x = dpt.zeros((0, 1, 2, 3))
1128+
actual = dpt.moveaxis(x, source, destination).shape
1129+
assert_(actual, expected)
10941130

1095-
y = dpt.reshape(dpt.arange(60), (3, 4, 5))
1096-
res = dpt.moveaxis(y, [0, 1, 2], [-1, -2, -3])
10971131

1098-
assert_array_equal(exp, dpt.asnumpy(res))
1132+
def test_moveaxis_errors():
1133+
x = dpt.reshape(dpt.arange(6), (1, 2, 3))
1134+
assert_raises_regex(
1135+
np.AxisError, "source.*out of bounds", dpt.moveaxis, x, 3, 0
1136+
)
1137+
assert_raises_regex(
1138+
np.AxisError, "source.*out of bounds", dpt.moveaxis, x, -4, 0
1139+
)
1140+
assert_raises_regex(
1141+
np.AxisError, "destination.*out of bounds", dpt.moveaxis, x, 0, 5
1142+
)
1143+
assert_raises_regex(
1144+
ValueError, "repeated axis in `source`", dpt.moveaxis, x, [0, 0], [0, 1]
1145+
)
1146+
assert_raises_regex(
1147+
ValueError,
1148+
"repeated axis in `destination`",
1149+
dpt.moveaxis,
1150+
x,
1151+
[0, 1],
1152+
[1, 1],
1153+
)
1154+
assert_raises_regex(
1155+
ValueError, "must have the same number", dpt.moveaxis, x, 0, [0, 1]
1156+
)
1157+
assert_raises_regex(
1158+
ValueError, "must have the same number", dpt.moveaxis, x, [0, 1], [0]
1159+
)
10991160

11001161

11011162
def test_unstack_axis0():

0 commit comments

Comments
 (0)