Skip to content

Commit cdf9d59

Browse files
committed
Parameterize axes for test_transpose_slice
1 parent 0c0a59c commit cdf9d59

File tree

1 file changed

+66
-69
lines changed

1 file changed

+66
-69
lines changed

tests/test_optimizers.py

Lines changed: 66 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from __future__ import unicode_literals
99

1010
import unittest
11-
import itertools
1211
import numpy as np
1312
from onnx import helper, numpy_helper, TensorProto, OperatorSetIdProto
1413
from parameterized import parameterized
@@ -311,84 +310,82 @@ def test_transpose_dequantize_with_axis(self, shape, perm_input, perm_output):
311310
model_proto, remaining_transpose_num=0)
312311

313312
@parameterized.expand([
314-
([2, 3, 4, 5], [1, 2, 1, 2], [0, 2, 3, 1], [0, 3, 1, 2]),
315-
([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
313+
([2, 3, 4, 5], [1, 2, 1, 2], [1], [0, 2, 3, 1], [0, 3, 1, 2]),
314+
([2, 3, 4, 5], [1, 2, 1, 2], [1, 2], [0, 2, 3, 1], [0, 3, 1, 2]),
315+
([2, 3, 4, 5], [1, 2, 1, 2], [0, 1, 2, 3], [0, 2, 3, 1], [0, 3, 1, 2]),
316+
([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [2], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
317+
([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [2, 3], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
318+
([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [0, 1, 2, 3, 4], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
316319
])
317320
@check_opset_max_version(9, "Slice in opset 9 and takes 'axes, 'start' and 'ends' as attributes")
318-
def test_transpose_slice(self, input_shape, slice_size, perm_input, perm_output):
319-
axes_combinations = []
320-
axes = list(range(len(input_shape)))
321-
for i in range(1, len(input_shape) + 1):
322-
axes_combinations.extend(list(itertools.combinations(axes, i)))
323-
for axes in axes_combinations:
324-
axes = np.array(list(axes), dtype=np.int64)
325-
starts = np.array([0] * axes.size, dtype=np.int64)
326-
ends = []
327-
for i in range(axes.size):
328-
ends.append(slice_size[axes[i]])
329-
ends = np.array(ends, dtype=np.int64)
330-
output_shape = input_shape.copy()
331-
for axis in axes:
332-
output_shape[perm_input[axis]] = slice_size[axis]
333-
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans_1")
334-
node2 = helper.make_node("Slice", ["Y"], ["Z"], starts=starts, ends=ends, axes=axes, name="slice")
335-
node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=perm_output, name="trans_2")
321+
def test_transpose_slice(self, input_shape, slice_size, axes, perm_input, perm_output):
322+
axes = np.array(axes, dtype=np.int64)
323+
starts = np.array([0] * axes.size, dtype=np.int64)
324+
ends = []
325+
for i in range(axes.size):
326+
ends.append(slice_size[axes[i]])
327+
ends = np.array(ends, dtype=np.int64)
328+
output_shape = input_shape.copy()
329+
for axis in axes:
330+
output_shape[perm_input[axis]] = slice_size[axis]
331+
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans_1")
332+
node2 = helper.make_node("Slice", ["Y"], ["Z"], starts=starts, ends=ends, axes=axes, name="slice")
333+
node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=perm_output, name="trans_2")
336334

337-
graph = helper.make_graph(
338-
[node1, node2, node3],
339-
"slice-test",
340-
[helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
341-
[helper.make_tensor_value_info("Z1", TensorProto.FLOAT, output_shape)],
342-
[
343-
helper.make_tensor("starts", TensorProto.INT64, starts.shape, starts),
344-
helper.make_tensor("ends", TensorProto.INT64, ends.shape, ends),
345-
helper.make_tensor("axes", TensorProto.INT64, axes.shape, axes)
346-
]
347-
)
335+
graph = helper.make_graph(
336+
[node1, node2, node3],
337+
"slice-test",
338+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
339+
[helper.make_tensor_value_info("Z1", TensorProto.FLOAT, output_shape)],
340+
[
341+
helper.make_tensor("starts", TensorProto.INT64, starts.shape, starts),
342+
helper.make_tensor("ends", TensorProto.INT64, ends.shape, ends),
343+
helper.make_tensor("axes", TensorProto.INT64, axes.shape, axes)
344+
]
345+
)
348346

349-
model_proto = self.make_model(graph, producer_name="onnx-tests")
350-
self.run_transpose_compare(["Z1"], {"X": np.random.randn(*input_shape).astype(np.float32)},
351-
model_proto, remaining_transpose_num=0)
347+
model_proto = self.make_model(graph, producer_name="onnx-tests")
348+
self.run_transpose_compare(["Z1"], {"X": np.random.randn(*input_shape).astype(np.float32)},
349+
model_proto, remaining_transpose_num=0)
352350

353351
@parameterized.expand([
354-
([2, 3, 4, 5], [1, 2, 1, 2], [0, 2, 3, 1], [0, 3, 1, 2]),
355-
([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
352+
([2, 3, 4, 5], [1, 2, 1, 2], [1], [0, 2, 3, 1], [0, 3, 1, 2]),
353+
([2, 3, 4, 5], [1, 2, 1, 2], [1, 2], [0, 2, 3, 1], [0, 3, 1, 2]),
354+
([2, 3, 4, 5], [1, 2, 1, 2], [0, 1, 2, 3], [0, 2, 3, 1], [0, 3, 1, 2]),
355+
([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [2], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
356+
([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [2, 3], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
357+
([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [0, 1, 2, 3, 4], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
356358
])
357359
@check_opset_min_version(10, "Slice in opset 10 can accept dynamic 'start' and 'ends'")
358-
def test_transpose_slice_opset_10(self, input_shape, slice_size, perm_input, perm_output):
359-
axes_combinations = []
360-
axes = list(range(len(input_shape)))
361-
for i in range(1, len(input_shape) + 1):
362-
axes_combinations.extend(list(itertools.combinations(axes, i)))
363-
for axes in axes_combinations:
364-
axes = np.array(list(axes), dtype=np.int32)
365-
starts = np.array([0] * axes.size, dtype=np.int32)
366-
ends = []
367-
for i in range(axes.size):
368-
ends.append(slice_size[axes[i]])
369-
ends = np.array(ends, dtype=np.int32)
370-
output_shape = input_shape.copy()
371-
for axis in axes:
372-
output_shape[perm_input[axis]] = slice_size[axis]
373-
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans_1")
374-
node2 = helper.make_node("Slice", ["Y", "starts", "ends", "axes"], ["Z"], name="slice")
375-
node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=perm_output, name="trans_2")
360+
def test_transpose_slice_opset_10(self, input_shape, slice_size, axes, perm_input, perm_output):
361+
axes = np.array(axes, dtype=np.int32)
362+
starts = np.array([0] * axes.size, dtype=np.int32)
363+
ends = []
364+
for i in range(axes.size):
365+
ends.append(slice_size[axes[i]])
366+
ends = np.array(ends, dtype=np.int32)
367+
output_shape = input_shape.copy()
368+
for axis in axes:
369+
output_shape[perm_input[axis]] = slice_size[axis]
370+
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans_1")
371+
node2 = helper.make_node("Slice", ["Y", "starts", "ends", "axes"], ["Z"], name="slice")
372+
node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=perm_output, name="trans_2")
376373

377-
graph = helper.make_graph(
378-
[node1, node2, node3],
379-
"slice-test",
380-
[helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
381-
[helper.make_tensor_value_info("Z1", TensorProto.FLOAT, output_shape)],
382-
[
383-
helper.make_tensor("starts", TensorProto.INT32, starts.shape, starts),
384-
helper.make_tensor("ends", TensorProto.INT32, ends.shape, ends),
385-
helper.make_tensor("axes", TensorProto.INT32, axes.shape, axes)
386-
]
387-
)
374+
graph = helper.make_graph(
375+
[node1, node2, node3],
376+
"slice-test",
377+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
378+
[helper.make_tensor_value_info("Z1", TensorProto.FLOAT, output_shape)],
379+
[
380+
helper.make_tensor("starts", TensorProto.INT32, starts.shape, starts),
381+
helper.make_tensor("ends", TensorProto.INT32, ends.shape, ends),
382+
helper.make_tensor("axes", TensorProto.INT32, axes.shape, axes)
383+
]
384+
)
388385

389-
model_proto = self.make_model(graph, producer_name="onnx-tests")
390-
self.run_transpose_compare(["Z1"], {"X": np.random.randn(*input_shape).astype(np.float32)},
391-
model_proto, remaining_transpose_num=0)
386+
model_proto = self.make_model(graph, producer_name="onnx-tests")
387+
self.run_transpose_compare(["Z1"], {"X": np.random.randn(*input_shape).astype(np.float32)},
388+
model_proto, remaining_transpose_num=0)
392389

393390
@parameterized.expand([
394391
((2, 3, 4, 5), (2, 4, 5, 3), [0, 2, 3, 1], [0, 3, 1, 2]),

0 commit comments

Comments
 (0)