|
17 | 17 |
|
18 | 18 | import numpy as np
|
19 | 19 | import pytest
|
20 |
| -from numpy.testing import assert_array_equal |
| 20 | +from numpy.testing import assert_, assert_array_equal, assert_raises_regex |
21 | 21 |
|
22 | 22 | import dpctl
|
23 | 23 | import dpctl.tensor as dpt
|
@@ -1068,34 +1068,83 @@ def test_swapaxes_2d():
|
1068 | 1068 | assert_array_equal(exp, dpt.asnumpy(res))
|
1069 | 1069 |
|
1070 | 1070 |
|
1071 |
| -def test_moveaxis_1axis(): |
1072 |
| - x = np.arange(60).reshape((3, 4, 5)) |
1073 |
| - exp = np.moveaxis(x, 0, -1) |
1074 |
| - |
1075 |
| - y = dpt.reshape(dpt.arange(60), (3, 4, 5)) |
1076 |
| - res = dpt.moveaxis(y, 0, -1) |
1077 |
| - |
1078 |
| - assert_array_equal(exp, dpt.asnumpy(res)) |
1079 |
| - |
1080 |
| - |
1081 |
| -def test_moveaxis_2axes(): |
1082 |
| - x = np.arange(60).reshape((3, 4, 5)) |
1083 |
| - exp = np.moveaxis(x, [0, 1], [-1, -2]) |
1084 |
| - |
1085 |
| - y = dpt.reshape(dpt.arange(60), (3, 4, 5)) |
1086 |
| - res = dpt.moveaxis(y, [0, 1], [-1, -2]) |
1087 |
| - |
1088 |
| - assert_array_equal(exp, dpt.asnumpy(res)) |
1089 |
| - |
1090 |
| - |
1091 |
| -def test_moveaxis_3axes(): |
1092 |
| - x = np.arange(60).reshape((3, 4, 5)) |
1093 |
| - exp = np.moveaxis(x, [0, 1, 2], [-1, -2, -3]) |
1094 |
| - |
1095 |
| - y = dpt.reshape(dpt.arange(60), (3, 4, 5)) |
1096 |
| - res = dpt.moveaxis(y, [0, 1, 2], [-1, -2, -3]) |
1097 |
| - |
1098 |
| - assert_array_equal(exp, dpt.asnumpy(res)) |
| 1071 | +def test_moveaxis_move_to_end(): |
| 1072 | + x = dpt.reshape(dpt.arange(5 * 6 * 7), (5, 6, 7)) |
| 1073 | + for source, expected in [ |
| 1074 | + (0, (6, 7, 5)), |
| 1075 | + (1, (5, 7, 6)), |
| 1076 | + (2, (5, 6, 7)), |
| 1077 | + (-1, (5, 6, 7)), |
| 1078 | + ]: |
| 1079 | + actual = dpt.moveaxis(x, source, -1).shape |
| 1080 | + assert_(actual, expected) |
| 1081 | + |
| 1082 | + |
| 1083 | +def test_moveaxis_new_position(): |
| 1084 | + x = dpt.reshape(dpt.arange(24), (1, 2, 3, 4)) |
| 1085 | + for source, destination, expected in [ |
| 1086 | + (0, 1, (2, 1, 3, 4)), |
| 1087 | + (1, 2, (1, 3, 2, 4)), |
| 1088 | + (1, -1, (1, 3, 4, 2)), |
| 1089 | + ]: |
| 1090 | + actual = dpt.moveaxis(x, source, destination).shape |
| 1091 | + assert_(actual, expected) |
| 1092 | + |
| 1093 | + |
| 1094 | +def test_moveaxis_preserve_order(): |
| 1095 | + x = dpt.zeros((1, 2, 3, 4)) |
| 1096 | + for source, destination in [ |
| 1097 | + (0, 0), |
| 1098 | + (3, -1), |
| 1099 | + (-1, 3), |
| 1100 | + ([0, -1], [0, -1]), |
| 1101 | + ([2, 0], [2, 0]), |
| 1102 | + ]: |
| 1103 | + actual = dpt.moveaxis(x, source, destination).shape |
| 1104 | + assert_(actual, (1, 2, 3, 4)) |
| 1105 | + |
| 1106 | + |
| 1107 | +def test_moveaxis_move_multiples(): |
| 1108 | + x = dpt.zeros((0, 1, 2, 3)) |
| 1109 | + for source, destination, expected in [ |
| 1110 | + ([0, 1], [2, 3], (2, 3, 0, 1)), |
| 1111 | + ([2, 3], [0, 1], (2, 3, 0, 1)), |
| 1112 | + ([0, 1, 2], [2, 3, 0], (2, 3, 0, 1)), |
| 1113 | + ([3, 0], [1, 0], (0, 3, 1, 2)), |
| 1114 | + ([0, 3], [0, 1], (0, 3, 1, 2)), |
| 1115 | + ]: |
| 1116 | + actual = dpt.moveaxis(x, source, destination).shape |
| 1117 | + assert_(actual, expected) |
| 1118 | + |
| 1119 | + |
| 1120 | +def test_moveaxis_errors(): |
| 1121 | + x = dpt.reshape(dpt.arange(6), (1, 2, 3)) |
| 1122 | + assert_raises_regex( |
| 1123 | + np.AxisError, "source.*out of bounds", dpt.moveaxis, x, 3, 0 |
| 1124 | + ) |
| 1125 | + assert_raises_regex( |
| 1126 | + np.AxisError, "source.*out of bounds", dpt.moveaxis, x, -4, 0 |
| 1127 | + ) |
| 1128 | + assert_raises_regex( |
| 1129 | + np.AxisError, "destination.*out of bounds", dpt.moveaxis, x, 0, 5 |
| 1130 | + ) |
| 1131 | + assert_raises_regex( |
| 1132 | + ValueError, "repeated axis in `source`", dpt.moveaxis, x, [0, 0], [0, 1] |
| 1133 | + ) |
| 1134 | + assert_raises_regex( |
| 1135 | + ValueError, |
| 1136 | + "repeated axis in `destination`", |
| 1137 | + dpt.moveaxis, |
| 1138 | + x, |
| 1139 | + [0, 1], |
| 1140 | + [1, 1], |
| 1141 | + ) |
| 1142 | + assert_raises_regex( |
| 1143 | + ValueError, "must have the same number", dpt.moveaxis, x, 0, [0, 1] |
| 1144 | + ) |
| 1145 | + assert_raises_regex( |
| 1146 | + ValueError, "must have the same number", dpt.moveaxis, x, [0, 1], [0] |
| 1147 | + ) |
1099 | 1148 |
|
1100 | 1149 |
|
1101 | 1150 | def test_unstack_axis0():
|
|
0 commit comments