Skip to content

Commit b91d3e8

Browse files
committed
Optimize Transpose->Slice regardless of Slice's axes
Signed-off-by: Mateusz Tabaka <[email protected]>
1 parent 8aa1127 commit b91d3e8

File tree

2 files changed

+98
-43
lines changed

2 files changed

+98
-43
lines changed

tests/test_optimizers.py

Lines changed: 77 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
from __future__ import unicode_literals
99

1010
import unittest
11+
import itertools
1112
import numpy as np
1213
from onnx import helper, numpy_helper, TensorProto, OperatorSetIdProto
1314
from parameterized import parameterized
15+
1416
from backend_test_base import Tf2OnnxBackendTestBase
1517
from common import unittest_main, group_nodes_by_type, check_opset_min_version, check_opset_max_version, get_test_config
1618
from tf2onnx import utils, constants
@@ -309,33 +311,84 @@ def test_transpose_dequantize_with_axis(self, shape, perm_input, perm_output):
309311
model_proto, remaining_transpose_num=0)
310312

311313
@parameterized.expand([
312-
((2, 3, 4, 5), [1, 2, 1, 2], (1, 2, 2, 1), [0, 2, 3, 1], [0, 3, 1, 2]),
313-
((2, 3, 4, 5, 6), [1, 2, 1, 2, 1], (1, 1, 2, 1, 2), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
314+
([2, 3, 4, 5], [1, 2, 1, 2], [0, 2, 3, 1], [0, 3, 1, 2]),
315+
([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
314316
])
315-
@check_opset_min_version(10, "Slice in opset 10 can accept dymaic 'start' and 'ends'")
316-
def test_transpose_slice(self, input_shape, slice_size, output_shape, perm_input, perm_output):
317-
starts = np.array([0] * len(input_shape), dtype=np.int64)
318-
ends = np.array(slice_size, dtype=np.int64)
319-
axes = np.array(list(range(len(input_shape))), dtype=np.int64)
320-
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans_1")
321-
node2 = helper.make_node("Slice", ["Y", "starts", "ends", "axes"], ["Z"], name="relu")
322-
node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=perm_output, name="trans_2")
317+
@check_opset_max_version(9, "Slice in opset 9 takes 'axes', 'starts' and 'ends' as attributes")
318+
def test_transpose_slice(self, input_shape, slice_size, perm_input, perm_output):
319+
axes_combinations = []
320+
axes = list(range(len(input_shape)))
321+
for i in range(1, len(input_shape) + 1):
322+
axes_combinations.extend(list(itertools.combinations(axes, i)))
323+
for axes in axes_combinations:
324+
axes = np.array(list(axes), dtype=np.int64)
325+
starts = np.array([0] * axes.size, dtype=np.int64)
326+
ends = []
327+
for i in range(axes.size):
328+
ends.append(slice_size[axes[i]])
329+
ends = np.array(ends, dtype=np.int64)
330+
output_shape = input_shape.copy()
331+
for axis in axes:
332+
output_shape[perm_input[axis]] = slice_size[axis]
333+
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans_1")
334+
node2 = helper.make_node("Slice", ["Y"], ["Z"], starts=starts, ends=ends, axes=axes, name="slice")
335+
node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=perm_output, name="trans_2")
323336

324-
graph = helper.make_graph(
325-
[node1, node2, node3],
326-
"relu-test",
327-
[helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
328-
[helper.make_tensor_value_info("Z1", TensorProto.FLOAT, output_shape)],
329-
[
330-
helper.make_tensor("starts", TensorProto.INT64, starts.shape, starts),
331-
helper.make_tensor("ends", TensorProto.INT64, ends.shape, ends),
332-
helper.make_tensor("axes", TensorProto.INT64, axes.shape, axes)
333-
]
334-
)
337+
graph = helper.make_graph(
338+
[node1, node2, node3],
339+
"slice-test",
340+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
341+
[helper.make_tensor_value_info("Z1", TensorProto.FLOAT, output_shape)],
342+
[
343+
helper.make_tensor("starts", TensorProto.INT64, starts.shape, starts),
344+
helper.make_tensor("ends", TensorProto.INT64, ends.shape, ends),
345+
helper.make_tensor("axes", TensorProto.INT64, axes.shape, axes)
346+
]
347+
)
335348

336-
model_proto = self.make_model(graph, producer_name="onnx-tests")
337-
self.run_transpose_compare(["Z1"], {"X": np.random.randn(*input_shape).astype(np.float32)},
338-
model_proto, remaining_transpose_num=0)
349+
model_proto = self.make_model(graph, producer_name="onnx-tests")
350+
self.run_transpose_compare(["Z1"], {"X": np.random.randn(*input_shape).astype(np.float32)},
351+
model_proto, remaining_transpose_num=0)
352+
353+
@parameterized.expand([
354+
([2, 3, 4, 5], [1, 2, 1, 2], [0, 2, 3, 1], [0, 3, 1, 2]),
355+
([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
356+
])
357+
@check_opset_min_version(10, "Slice in opset 10 can accept dynamic 'starts' and 'ends'")
358+
def test_transpose_slice_10(self, input_shape, slice_size, perm_input, perm_output):
359+
axes_combinations = []
360+
axes = list(range(len(input_shape)))
361+
for i in range(1, len(input_shape) + 1):
362+
axes_combinations.extend(list(itertools.combinations(axes, i)))
363+
for axes in axes_combinations:
364+
axes = np.array(list(axes), dtype=np.int32)
365+
starts = np.array([0] * axes.size, dtype=np.int32)
366+
ends = []
367+
for i in range(axes.size):
368+
ends.append(slice_size[axes[i]])
369+
ends = np.array(ends, dtype=np.int32)
370+
output_shape = input_shape.copy()
371+
for axis in axes:
372+
output_shape[perm_input[axis]] = slice_size[axis]
373+
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans_1")
374+
node2 = helper.make_node("Slice", ["Y", "starts", "ends", "axes"], ["Z"], name="slice")
375+
node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=perm_output, name="trans_2")
376+
377+
graph = helper.make_graph(
378+
[node1, node2, node3],
379+
"slice-test",
380+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
381+
[helper.make_tensor_value_info("Z1", TensorProto.FLOAT, output_shape)],
382+
[
383+
helper.make_tensor("starts", TensorProto.INT32, starts.shape, starts),
384+
helper.make_tensor("ends", TensorProto.INT32, ends.shape, ends),
385+
helper.make_tensor("axes", TensorProto.INT32, axes.shape, axes)
386+
]
387+
)
388+
389+
model_proto = self.make_model(graph, producer_name="onnx-tests")
390+
self.run_transpose_compare(["Z1"], {"X": np.random.randn(*input_shape).astype(np.float32)},
391+
model_proto, remaining_transpose_num=0)
339392

340393
@parameterized.expand([
341394
((2, 3, 4, 5), (2, 4, 5, 3), [0, 2, 3, 1], [0, 3, 1, 2]),

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -698,25 +698,27 @@ def _slice_handler(self, trans, node):
698698
if not axes_values:
699699
return False
700700
axes = axes_values.ints
701-
if axes == list(range(trans_rank)):
702-
new_axes = NCHW_TO_NHWC if trans_rank == 4 else NCDHW_TO_NDHWC
703-
node.set_attr("axes", new_axes)
704-
return self._switch_transpose_and_node(node, trans)
705-
else: # in opset 10, axes is input instead of an attribute.
706-
if len(node.inputs) >= 4 and node.inputs[3].is_const():
707-
axes = node.inputs[3].get_tensor_value(as_list=True)
708-
if axes == list(range(trans_rank)):
709-
axes = NCHW_TO_NHWC if trans_rank == 4 else NCDHW_TO_NDHWC
710-
# axes node might be shared
711-
new_axes = np.array(axes, dtype=np.int64)
712-
if self._nodes_has_single_consumer_node([node.inputs[3]]):
713-
node.inputs[3].set_tensor_value(new_axes)
714-
else:
715-
new_axes_const = self._g.make_const(
716-
utils.make_name(node.inputs[3].name), new_axes
717-
)
718-
self._g.replace_input(node, node.input[3], new_axes_const.output[0], 3)
719-
return self._switch_transpose_and_node(node, trans)
701+
perm = NCHW_TO_NHWC if trans_rank == 4 else NCDHW_TO_NDHWC
702+
new_axes = [perm[axes[i]] for i in range(len(axes))]
703+
node.set_attr("axes", new_axes)
704+
return self._switch_transpose_and_node(node, trans)
705+
# in opset 10, axes is input instead of an attribute.
706+
if len(node.inputs) >= 4 and node.inputs[3].is_const():
707+
axes = node.inputs[3].get_tensor_value(as_list=False)
708+
dtype = axes.dtype
709+
axes = axes.tolist()
710+
perm = NCHW_TO_NHWC if trans_rank == 4 else NCDHW_TO_NDHWC
711+
axes = [perm[axes[i]] for i in range(len(axes))]
712+
# axes node might be shared
713+
new_axes = np.array(axes, dtype=dtype)
714+
if self._nodes_has_single_consumer_node([node.inputs[3]]):
715+
node.inputs[3].set_tensor_value(new_axes)
716+
else:
717+
new_axes_const = self._g.make_const(
718+
utils.make_name(node.inputs[3].name), new_axes
719+
)
720+
self._g.replace_input(node, node.input[3], new_axes_const.output[0], 3)
721+
return self._switch_transpose_and_node(node, trans)
720722
return False
721723

722724
def _quantize_handler(self, trans, node):

0 commit comments

Comments (0)