Skip to content

Commit 40ace8f

Browse files
committed
fix an error for moveaxis
1 parent b050e8c commit 40ace8f

File tree

2 files changed

+99
-43
lines changed

2 files changed

+99
-43
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -785,21 +785,21 @@ def unstack(X, axis=0):
785785
return tuple(Y[i] for i in range(Y.shape[0]))
786786

787787

788-
def moveaxis(X, src, dst):
789-
"""moveaxis(x, src, dst)
788+
def moveaxis(X, source, destination):
789+
"""moveaxis(x, source, destination)
790790
791791
Moves axes of an array to new positions.
792792
793793
Args:
794794
x (usm_ndarray): input array
795795
796-
src (int or a sequence of int):
796+
source (int or a sequence of int):
797797
Original positions of the axes to move.
798798
These must be unique. If `x` has rank (i.e., number of
799799
dimensions) `N`, a valid `axis` must be in the
800800
half-open interval `[-N, N)`.
801801
802-
dst (int or a sequence of int):
802+
destination (int or a sequence of int):
803803
Destination positions for each of the original axes.
804804
These must also be unique. If `x` has rank
805805
(i.e., number of dimensions) `N`, a valid `axis` must be
@@ -818,18 +818,25 @@ def moveaxis(X, src, dst):
818818
if not isinstance(X, dpt.usm_ndarray):
819819
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
820820

821-
if not isinstance(src, (tuple, list)):
822-
src = (src,)
821+
if not isinstance(source, (tuple, list)):
822+
source = (source,)
823823

824-
if not isinstance(dst, (tuple, list)):
825-
dst = (dst,)
824+
if not isinstance(destination, (tuple, list)):
825+
destination = (destination,)
826826

827-
src = normalize_axis_tuple(src, X.ndim, "src")
828-
dst = normalize_axis_tuple(dst, X.ndim, "dst")
829-
ind = list(range(0, X.ndim))
830-
for i in range(len(src)):
831-
ind.remove(src[i]) # using the value here which is the same as index
832-
ind.insert(dst[i], src[i])
827+
source = normalize_axis_tuple(source, X.ndim, "source")
828+
destination = normalize_axis_tuple(destination, X.ndim, "destination")
829+
830+
if len(source) != len(destination):
831+
raise ValueError(
832+
"`source` and `destination` arguments must have "
833+
"the same number of elements"
834+
)
835+
836+
ind = [n for n in range(X.ndim) if n not in source]
837+
838+
for src, dst in sorted(zip(destination, source)):
839+
ind.insert(src, dst)
833840

834841
return dpt.permute_dims(X, tuple(ind))
835842

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 78 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import numpy as np
1919
import pytest
20-
from numpy.testing import assert_array_equal
20+
from numpy.testing import assert_, assert_array_equal, assert_raises_regex
2121

2222
import dpctl
2323
import dpctl.tensor as dpt
@@ -1068,34 +1068,83 @@ def test_swapaxes_2d():
10681068
assert_array_equal(exp, dpt.asnumpy(res))
10691069

10701070

1071-
def test_moveaxis_1axis():
1072-
x = np.arange(60).reshape((3, 4, 5))
1073-
exp = np.moveaxis(x, 0, -1)
1074-
1075-
y = dpt.reshape(dpt.arange(60), (3, 4, 5))
1076-
res = dpt.moveaxis(y, 0, -1)
1077-
1078-
assert_array_equal(exp, dpt.asnumpy(res))
1079-
1080-
1081-
def test_moveaxis_2axes():
1082-
x = np.arange(60).reshape((3, 4, 5))
1083-
exp = np.moveaxis(x, [0, 1], [-1, -2])
1084-
1085-
y = dpt.reshape(dpt.arange(60), (3, 4, 5))
1086-
res = dpt.moveaxis(y, [0, 1], [-1, -2])
1087-
1088-
assert_array_equal(exp, dpt.asnumpy(res))
1089-
1090-
1091-
def test_moveaxis_3axes():
1092-
x = np.arange(60).reshape((3, 4, 5))
1093-
exp = np.moveaxis(x, [0, 1, 2], [-1, -2, -3])
1094-
1095-
y = dpt.reshape(dpt.arange(60), (3, 4, 5))
1096-
res = dpt.moveaxis(y, [0, 1, 2], [-1, -2, -3])
1097-
1098-
assert_array_equal(exp, dpt.asnumpy(res))
1071+
def test_moveaxis_move_to_end():
1072+
x = dpt.reshape(dpt.arange(5 * 6 * 7), (5, 6, 7))
1073+
for source, expected in [
1074+
(0, (6, 7, 5)),
1075+
(1, (5, 7, 6)),
1076+
(2, (5, 6, 7)),
1077+
(-1, (5, 6, 7)),
1078+
]:
1079+
actual = dpt.moveaxis(x, source, -1).shape
1080+
assert_(actual, expected)
1081+
1082+
1083+
def test_moveaxis_new_position():
1084+
x = dpt.reshape(dpt.arange(24), (1, 2, 3, 4))
1085+
for source, destination, expected in [
1086+
(0, 1, (2, 1, 3, 4)),
1087+
(1, 2, (1, 3, 2, 4)),
1088+
(1, -1, (1, 3, 4, 2)),
1089+
]:
1090+
actual = dpt.moveaxis(x, source, destination).shape
1091+
assert_(actual, expected)
1092+
1093+
1094+
def test_moveaxis_preserve_order():
1095+
x = dpt.zeros((1, 2, 3, 4))
1096+
for source, destination in [
1097+
(0, 0),
1098+
(3, -1),
1099+
(-1, 3),
1100+
([0, -1], [0, -1]),
1101+
([2, 0], [2, 0]),
1102+
]:
1103+
actual = dpt.moveaxis(x, source, destination).shape
1104+
assert_(actual, (1, 2, 3, 4))
1105+
1106+
1107+
def test_moveaxis_move_multiples():
1108+
x = dpt.zeros((0, 1, 2, 3))
1109+
for source, destination, expected in [
1110+
([0, 1], [2, 3], (2, 3, 0, 1)),
1111+
([2, 3], [0, 1], (2, 3, 0, 1)),
1112+
([0, 1, 2], [2, 3, 0], (2, 3, 0, 1)),
1113+
([3, 0], [1, 0], (0, 3, 1, 2)),
1114+
([0, 3], [0, 1], (0, 3, 1, 2)),
1115+
]:
1116+
actual = dpt.moveaxis(x, source, destination).shape
1117+
assert_(actual, expected)
1118+
1119+
1120+
def test_moveaxis_errors():
1121+
x = dpt.reshape(dpt.arange(6), (1, 2, 3))
1122+
assert_raises_regex(
1123+
np.AxisError, "source.*out of bounds", dpt.moveaxis, x, 3, 0
1124+
)
1125+
assert_raises_regex(
1126+
np.AxisError, "source.*out of bounds", dpt.moveaxis, x, -4, 0
1127+
)
1128+
assert_raises_regex(
1129+
np.AxisError, "destination.*out of bounds", dpt.moveaxis, x, 0, 5
1130+
)
1131+
assert_raises_regex(
1132+
ValueError, "repeated axis in `source`", dpt.moveaxis, x, [0, 0], [0, 1]
1133+
)
1134+
assert_raises_regex(
1135+
ValueError,
1136+
"repeated axis in `destination`",
1137+
dpt.moveaxis,
1138+
x,
1139+
[0, 1],
1140+
[1, 1],
1141+
)
1142+
assert_raises_regex(
1143+
ValueError, "must have the same number", dpt.moveaxis, x, 0, [0, 1]
1144+
)
1145+
assert_raises_regex(
1146+
ValueError, "must have the same number", dpt.moveaxis, x, [0, 1], [0]
1147+
)
10991148

11001149

11011150
def test_unstack_axis0():

0 commit comments

Comments
 (0)