From da12f6a8f4f78f94834863786de2f3619711d69c Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Thu, 4 Feb 2021 17:54:12 -0800 Subject: [PATCH 1/2] fix StridedSlice for ellipsis+newaxis Signed-off-by: Guenther Schmuelling --- tests/backend_test_base.py | 7 ++++++- tests/test_backend.py | 17 ++++++++++++++++- tf2onnx/onnx_opset/tensor.py | 32 ++++++++++++++++++++++---------- 3 files changed, 44 insertions(+), 12 deletions(-) diff --git a/tests/backend_test_base.py b/tests/backend_test_base.py index 567774fe8..e46da40b2 100644 --- a/tests/backend_test_base.py +++ b/tests/backend_test_base.py @@ -76,12 +76,17 @@ def run_onnxcaffe2(self, onnx_graph, inputs): def run_onnxruntime(self, model_path, inputs, output_names): """Run test against onnxruntime backend.""" import onnxruntime as rt + providers = ['CPUExecutionProvider'] + if rt.get_device() == "GPU": + gpus = os.environ.get("CUDA_VISIBLE_DEVICES") + if gpus is None or len(gpus) > 1: + providers = ['CUDAExecutionProvider'] opt = rt.SessionOptions() # in case of issues with the runtime, one can enable more logging # opt.log_severity_level = 0 # opt.log_verbosity_level = 255 # opt.enable_profiling = True - m = rt.InferenceSession(model_path, opt) + m = rt.InferenceSession(model_path, opt, providers=providers) results = m.run(output_names, inputs) return results diff --git a/tests/test_backend.py b/tests/test_backend.py index c69ebd019..3b9bb0e79 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 - """Unit tests using onnx backends.""" from __future__ import division @@ -2230,6 +2229,22 @@ def func(x, y): y_val = np.array(9, dtype=np.int32) self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val}) + @check_opset_min_version(10, "Slice") + def test_strided_slice_ellipse(self): + def func1(x): + x_ = x[..., tf.newaxis] + return tf.identity(x_, name=_TFOUTPUT) + shape = [1, 8, 64] + x_val = np.arange(np.prod(shape)).astype("float32").reshape(shape) + self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val}) + + def func2(x): + x_ = x[:, tf.newaxis, ..., :, tf.newaxis] + return tf.identity(x_, name=_TFOUTPUT) + shape = [2, 3, 4, 5] + x_val = np.arange(np.prod(shape)).astype("float32").reshape(shape) + self._run_test_case(func2, [_OUTPUT], {_INPUT: x_val}) + @check_opset_min_version(7, "batchnorm") def test_fused_batchnorm(self): x_shape = [1, 28, 28, 2] diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py index 52b352369..4f7dd233a 100644 --- a/tf2onnx/onnx_opset/tensor.py +++ b/tf2onnx/onnx_opset/tensor.py @@ -72,7 +72,6 @@ def version_1(cls, ctx, node, **kwargs): ctx.copy_shape(output_name, output_cast.output[0]) - @tf_op("Flatten") class Flatten: @classmethod @@ -630,6 +629,7 @@ def version_13(cls, ctx, node, **kwargs): # Default axis is not -1 but doesn't matter since we always set it. cls.version_1(ctx, node, **kwargs) + @tf_op("SplitV") class SplitV: @classmethod @@ -874,15 +874,6 @@ def any_version_after10(cls, opset, ctx, node, **kwargs): ellipsis_mask = ellipsis_mask.i if ellipsis_mask is not None else 0 shrink_axis_mask = node.get_attr("shrink_axis_mask") shrink_axis_mask = shrink_axis_mask.i if shrink_axis_mask is not None else 0 - if new_axis_mask != 0: - unqueeze_at = [] - for bit in range(32): - if (new_axis_mask >> bit) & 1 == 1: - unqueeze_at.append(bit) - begin_mask |= 1 << bit - end_mask |= 1 << bit - input_x = GraphBuilder(ctx).make_unsqueeze( - {'data': input_x.output[0], 'axes': unqueeze_at}, return_node=True) param_shape = ctx.get_shape(node.input[1]) or \ ctx.get_shape(node.input[2]) or \ @@ -892,6 +883,27 @@ def any_version_after10(cls, opset, ctx, node, **kwargs): "StridedSlice op {} requires the shape of begin/end/strides".format(node.name) ) param_rank = param_shape[0] + + if new_axis_mask != 0: + unqueeze_at = [] + ellipsis_gap = 0 + num_new = 0 + for bit in range(32): + if (new_axis_mask >> bit) & 1 == 1: + num_new += 1 + if (ellipsis_mask >> bit) & 1: + input_shape = ctx.get_shape(input_x.output[0]) + # calculate what rank for ellipsis: input rank - (being rank - all new_axis - 1) + ellipsis_gap = len(input_shape) - param_rank + num_new + 1 + if (new_axis_mask >> bit) & 1 == 1: + unqueeze_at.append(bit + ellipsis_gap) + begin_mask |= 1 << bit + end_mask |= 1 << bit + + input_x = GraphBuilder(ctx).make_unsqueeze( + {'data': input_x.output[0], 'axes': unqueeze_at}, return_node=True) + + # use in onnx graph to mask begin new_begin_mask = [1] * param_rank # use in onnx graph to mask end From 4fc2abeddf45d863aba1935511196d35f1d583e4 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Fri, 5 Feb 2021 09:09:35 -0800 Subject: [PATCH 2/2] skip ut for tflite Signed-off-by: Guenther Schmuelling --- tests/test_backend.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_backend.py b/tests/test_backend.py index 3b9bb0e79..b9f3d7a3e 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -2230,6 +2230,7 @@ def func(x, y): self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val}) @check_opset_min_version(10, "Slice") + @skip_tflite("not supported in tflite") def test_strided_slice_ellipse(self): def func1(x): x_ = x[..., tf.newaxis]