|
8 | 8 | from __future__ import unicode_literals
|
9 | 9 |
|
10 | 10 | import unittest
|
| 11 | +import itertools |
11 | 12 | import numpy as np
|
12 | 13 | from onnx import helper, numpy_helper, TensorProto, OperatorSetIdProto
|
13 | 14 | from parameterized import parameterized
|
| 15 | + |
14 | 16 | from backend_test_base import Tf2OnnxBackendTestBase
|
15 | 17 | from common import unittest_main, group_nodes_by_type, check_opset_min_version, check_opset_max_version, get_test_config
|
16 | 18 | from tf2onnx import utils, constants
|
@@ -309,33 +311,84 @@ def test_transpose_dequantize_with_axis(self, shape, perm_input, perm_output):
|
309 | 311 | model_proto, remaining_transpose_num=0)
|
310 | 312 |
|
311 | 313 | @parameterized.expand([
|
312 |
| - ((2, 3, 4, 5), [1, 2, 1, 2], (1, 2, 2, 1), [0, 2, 3, 1], [0, 3, 1, 2]), |
313 |
| - ((2, 3, 4, 5, 6), [1, 2, 1, 2, 1], (1, 1, 2, 1, 2), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]), |
| 314 | + ([2, 3, 4, 5], [1, 2, 1, 2], [0, 2, 3, 1], [0, 3, 1, 2]), |
| 315 | + ([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]), |
314 | 316 | ])
|
315 |
| - @check_opset_min_version(10, "Slice in opset 10 can accept dymaic 'start' and 'ends'") |
316 |
| - def test_transpose_slice(self, input_shape, slice_size, output_shape, perm_input, perm_output): |
317 |
| - starts = np.array([0] * len(input_shape), dtype=np.int64) |
318 |
| - ends = np.array(slice_size, dtype=np.int64) |
319 |
| - axes = np.array(list(range(len(input_shape))), dtype=np.int64) |
320 |
| - node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans_1") |
321 |
| - node2 = helper.make_node("Slice", ["Y", "starts", "ends", "axes"], ["Z"], name="relu") |
322 |
| - node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=perm_output, name="trans_2") |
| 317 | + @check_opset_max_version(9, "Slice in opset 9 and takes 'axes, 'start' and 'ends' as attributes") |
| 318 | + def test_transpose_slice(self, input_shape, slice_size, perm_input, perm_output): |
| 319 | + axes_combinations = [] |
| 320 | + axes = list(range(len(input_shape))) |
| 321 | + for i in range(1, len(input_shape) + 1): |
| 322 | + axes_combinations.extend(list(itertools.combinations(axes, i))) |
| 323 | + for axes in axes_combinations: |
| 324 | + axes = np.array(list(axes), dtype=np.int64) |
| 325 | + starts = np.array([0] * axes.size, dtype=np.int64) |
| 326 | + ends = [] |
| 327 | + for i in range(axes.size): |
| 328 | + ends.append(slice_size[axes[i]]) |
| 329 | + ends = np.array(ends, dtype=np.int64) |
| 330 | + output_shape = input_shape.copy() |
| 331 | + for axis in axes: |
| 332 | + output_shape[perm_input[axis]] = slice_size[axis] |
| 333 | + node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans_1") |
| 334 | + node2 = helper.make_node("Slice", ["Y"], ["Z"], starts=starts, ends=ends, axes=axes, name="slice") |
| 335 | + node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=perm_output, name="trans_2") |
323 | 336 |
|
324 |
| - graph = helper.make_graph( |
325 |
| - [node1, node2, node3], |
326 |
| - "relu-test", |
327 |
| - [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)], |
328 |
| - [helper.make_tensor_value_info("Z1", TensorProto.FLOAT, output_shape)], |
329 |
| - [ |
330 |
| - helper.make_tensor("starts", TensorProto.INT64, starts.shape, starts), |
331 |
| - helper.make_tensor("ends", TensorProto.INT64, ends.shape, ends), |
332 |
| - helper.make_tensor("axes", TensorProto.INT64, axes.shape, axes) |
333 |
| - ] |
334 |
| - ) |
| 337 | + graph = helper.make_graph( |
| 338 | + [node1, node2, node3], |
| 339 | + "slice-test", |
| 340 | + [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)], |
| 341 | + [helper.make_tensor_value_info("Z1", TensorProto.FLOAT, output_shape)], |
| 342 | + [ |
| 343 | + helper.make_tensor("starts", TensorProto.INT64, starts.shape, starts), |
| 344 | + helper.make_tensor("ends", TensorProto.INT64, ends.shape, ends), |
| 345 | + helper.make_tensor("axes", TensorProto.INT64, axes.shape, axes) |
| 346 | + ] |
| 347 | + ) |
335 | 348 |
|
336 |
| - model_proto = self.make_model(graph, producer_name="onnx-tests") |
337 |
| - self.run_transpose_compare(["Z1"], {"X": np.random.randn(*input_shape).astype(np.float32)}, |
338 |
| - model_proto, remaining_transpose_num=0) |
| 349 | + model_proto = self.make_model(graph, producer_name="onnx-tests") |
| 350 | + self.run_transpose_compare(["Z1"], {"X": np.random.randn(*input_shape).astype(np.float32)}, |
| 351 | + model_proto, remaining_transpose_num=0) |
| 352 | + |
| 353 | + @parameterized.expand([ |
| 354 | + ([2, 3, 4, 5], [1, 2, 1, 2], [0, 2, 3, 1], [0, 3, 1, 2]), |
| 355 | + ([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]), |
| 356 | + ]) |
| 357 | + @check_opset_min_version(10, "Slice in opset 10 can accept dynamic 'start' and 'ends'") |
| 358 | + def test_transpose_slice_10(self, input_shape, slice_size, perm_input, perm_output): |
| 359 | + axes_combinations = [] |
| 360 | + axes = list(range(len(input_shape))) |
| 361 | + for i in range(1, len(input_shape) + 1): |
| 362 | + axes_combinations.extend(list(itertools.combinations(axes, i))) |
| 363 | + for axes in axes_combinations: |
| 364 | + axes = np.array(list(axes), dtype=np.int32) |
| 365 | + starts = np.array([0] * axes.size, dtype=np.int32) |
| 366 | + ends = [] |
| 367 | + for i in range(axes.size): |
| 368 | + ends.append(slice_size[axes[i]]) |
| 369 | + ends = np.array(ends, dtype=np.int32) |
| 370 | + output_shape = input_shape.copy() |
| 371 | + for axis in axes: |
| 372 | + output_shape[perm_input[axis]] = slice_size[axis] |
| 373 | + node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans_1") |
| 374 | + node2 = helper.make_node("Slice", ["Y", "starts", "ends", "axes"], ["Z"], name="slice") |
| 375 | + node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=perm_output, name="trans_2") |
| 376 | + |
| 377 | + graph = helper.make_graph( |
| 378 | + [node1, node2, node3], |
| 379 | + "slice-test", |
| 380 | + [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)], |
| 381 | + [helper.make_tensor_value_info("Z1", TensorProto.FLOAT, output_shape)], |
| 382 | + [ |
| 383 | + helper.make_tensor("starts", TensorProto.INT32, starts.shape, starts), |
| 384 | + helper.make_tensor("ends", TensorProto.INT32, ends.shape, ends), |
| 385 | + helper.make_tensor("axes", TensorProto.INT32, axes.shape, axes) |
| 386 | + ] |
| 387 | + ) |
| 388 | + |
| 389 | + model_proto = self.make_model(graph, producer_name="onnx-tests") |
| 390 | + self.run_transpose_compare(["Z1"], {"X": np.random.randn(*input_shape).astype(np.float32)}, |
| 391 | + model_proto, remaining_transpose_num=0) |
339 | 392 |
|
340 | 393 | @parameterized.expand([
|
341 | 394 | ((2, 3, 4, 5), (2, 4, 5, 3), [0, 2, 3, 1], [0, 3, 1, 2]),
|
|
0 commit comments