|
17 | 17 |
|
18 | 18 | import numpy as np
|
19 | 19 | import pytest
|
20 |
| -from numpy.testing import assert_array_equal |
| 20 | +from numpy.testing import assert_, assert_array_equal, assert_raises_regex |
21 | 21 |
|
22 | 22 | import dpctl
|
23 | 23 | import dpctl.tensor as dpt
|
@@ -1068,34 +1068,95 @@ def test_swapaxes_2d():
|
1068 | 1068 | assert_array_equal(exp, dpt.asnumpy(res))
|
1069 | 1069 |
|
1070 | 1070 |
|
1071 |
| -def test_moveaxis_1axis(): |
1072 |
| - x = np.arange(60).reshape((3, 4, 5)) |
1073 |
| - exp = np.moveaxis(x, 0, -1) |
1074 |
| - |
1075 |
| - y = dpt.reshape(dpt.arange(60), (3, 4, 5)) |
1076 |
| - res = dpt.moveaxis(y, 0, -1) |
1077 |
| - |
1078 |
| - assert_array_equal(exp, dpt.asnumpy(res)) |
| 1071 | +@pytest.mark.parametrize( |
| 1072 | + "source, expected", |
| 1073 | + [ |
| 1074 | + (0, (6, 7, 5)), |
| 1075 | + (1, (5, 7, 6)), |
| 1076 | + (2, (5, 6, 7)), |
| 1077 | + (-1, (5, 6, 7)), |
| 1078 | + ], |
| 1079 | +) |
| 1080 | +def test_moveaxis_move_to_end(source, expected): |
| 1081 | + x = dpt.reshape(dpt.arange(5 * 6 * 7), (5, 6, 7)) |
| 1082 | + actual = dpt.moveaxis(x, source, -1).shape |
| 1083 | + assert_(actual, expected) |
1079 | 1084 |
|
1080 | 1085 |
|
1081 |
| -def test_moveaxis_2axes(): |
1082 |
| - x = np.arange(60).reshape((3, 4, 5)) |
1083 |
| - exp = np.moveaxis(x, [0, 1], [-1, -2]) |
| 1086 | +@pytest.mark.parametrize( |
| 1087 | + "source, destination, expected", |
| 1088 | + [ |
| 1089 | + (0, 1, (2, 1, 3, 4)), |
| 1090 | + (1, 2, (1, 3, 2, 4)), |
| 1091 | + (1, -1, (1, 3, 4, 2)), |
| 1092 | + ], |
| 1093 | +) |
| 1094 | +def test_moveaxis_new_position(source, destination, expected): |
| 1095 | + x = dpt.reshape(dpt.arange(24), (1, 2, 3, 4)) |
| 1096 | + actual = dpt.moveaxis(x, source, destination).shape |
| 1097 | + assert_(actual, expected) |
1084 | 1098 |
|
1085 |
| - y = dpt.reshape(dpt.arange(60), (3, 4, 5)) |
1086 |
| - res = dpt.moveaxis(y, [0, 1], [-1, -2]) |
1087 | 1099 |
|
1088 |
| - assert_array_equal(exp, dpt.asnumpy(res)) |
| 1100 | +@pytest.mark.parametrize( |
| 1101 | + "source, destination", |
| 1102 | + [ |
| 1103 | + (0, 0), |
| 1104 | + (3, -1), |
| 1105 | + (-1, 3), |
| 1106 | + ([0, -1], [0, -1]), |
| 1107 | + ([2, 0], [2, 0]), |
| 1108 | + ], |
| 1109 | +) |
| 1110 | +def test_moveaxis_preserve_order(source, destination): |
| 1111 | + x = dpt.zeros((1, 2, 3, 4)) |
| 1112 | + actual = dpt.moveaxis(x, source, destination).shape |
| 1113 | + assert_(actual, (1, 2, 3, 4)) |
1089 | 1114 |
|
1090 | 1115 |
|
1091 |
| -def test_moveaxis_3axes(): |
1092 |
| - x = np.arange(60).reshape((3, 4, 5)) |
1093 |
| - exp = np.moveaxis(x, [0, 1, 2], [-1, -2, -3]) |
| 1116 | +@pytest.mark.parametrize( |
| 1117 | + "source, destination, expected", |
| 1118 | + [ |
| 1119 | + ([0, 1], [2, 3], (2, 3, 0, 1)), |
| 1120 | + ([2, 3], [0, 1], (2, 3, 0, 1)), |
| 1121 | + ([0, 1, 2], [2, 3, 0], (2, 3, 0, 1)), |
| 1122 | + ([3, 0], [1, 0], (0, 3, 1, 2)), |
| 1123 | + ([0, 3], [0, 1], (0, 3, 1, 2)), |
| 1124 | + ], |
| 1125 | +) |
| 1126 | +def test_moveaxis_move_multiples(source, destination, expected): |
| 1127 | + x = dpt.zeros((0, 1, 2, 3)) |
| 1128 | + actual = dpt.moveaxis(x, source, destination).shape |
| 1129 | + assert_(actual, expected) |
1094 | 1130 |
|
1095 |
| - y = dpt.reshape(dpt.arange(60), (3, 4, 5)) |
1096 |
| - res = dpt.moveaxis(y, [0, 1, 2], [-1, -2, -3]) |
1097 | 1131 |
|
1098 |
| - assert_array_equal(exp, dpt.asnumpy(res)) |
| 1132 | +def test_moveaxis_errors(): |
| 1133 | + x = dpt.reshape(dpt.arange(6), (1, 2, 3)) |
| 1134 | + assert_raises_regex( |
| 1135 | + np.AxisError, "source.*out of bounds", dpt.moveaxis, x, 3, 0 |
| 1136 | + ) |
| 1137 | + assert_raises_regex( |
| 1138 | + np.AxisError, "source.*out of bounds", dpt.moveaxis, x, -4, 0 |
| 1139 | + ) |
| 1140 | + assert_raises_regex( |
| 1141 | + np.AxisError, "destination.*out of bounds", dpt.moveaxis, x, 0, 5 |
| 1142 | + ) |
| 1143 | + assert_raises_regex( |
| 1144 | + ValueError, "repeated axis in `source`", dpt.moveaxis, x, [0, 0], [0, 1] |
| 1145 | + ) |
| 1146 | + assert_raises_regex( |
| 1147 | + ValueError, |
| 1148 | + "repeated axis in `destination`", |
| 1149 | + dpt.moveaxis, |
| 1150 | + x, |
| 1151 | + [0, 1], |
| 1152 | + [1, 1], |
| 1153 | + ) |
| 1154 | + assert_raises_regex( |
| 1155 | + ValueError, "must have the same number", dpt.moveaxis, x, 0, [0, 1] |
| 1156 | + ) |
| 1157 | + assert_raises_regex( |
| 1158 | + ValueError, "must have the same number", dpt.moveaxis, x, [0, 1], [0] |
| 1159 | + ) |
1099 | 1160 |
|
1100 | 1161 |
|
1101 | 1162 | def test_unstack_axis0():
|
|
0 commit comments