Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion tests/backend_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,17 @@ def run_onnxcaffe2(self, onnx_graph, inputs):
def run_onnxruntime(self, model_path, inputs, output_names):
"""Run test against onnxruntime backend."""
import onnxruntime as rt
providers = ['CPUExecutionProvider']
if rt.get_device() == "GPU":
gpus = os.environ.get("CUDA_VISIBLE_DEVICES")
if gpus is None or len(gpus) > 1:
providers = ['CUDAExecutionProvider']
opt = rt.SessionOptions()
# in case of issues with the runtime, one can enable more logging
# opt.log_severity_level = 0
# opt.log_verbosity_level = 255
# opt.enable_profiling = True
m = rt.InferenceSession(model_path, opt)
m = rt.InferenceSession(model_path, opt, providers=providers)
results = m.run(output_names, inputs)
return results

Expand Down
18 changes: 17 additions & 1 deletion tests/test_backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0



"""Unit tests using onnx backends."""

from __future__ import division
Expand Down Expand Up @@ -2230,6 +2229,23 @@ def func(x, y):
y_val = np.array(9, dtype=np.int32)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})

@check_opset_min_version(10, "Slice")
@skip_tflite("not supported in tflite")
def test_strided_slice_ellipse(self):
def func1(x):
x_ = x[..., tf.newaxis]
return tf.identity(x_, name=_TFOUTPUT)
shape = [1, 8, 64]
x_val = np.arange(np.prod(shape)).astype("float32").reshape(shape)
self._run_test_case(func1, [_OUTPUT], {_INPUT: x_val})

def func2(x):
x_ = x[:, tf.newaxis, ..., :, tf.newaxis]
return tf.identity(x_, name=_TFOUTPUT)
shape = [2, 3, 4, 5]
x_val = np.arange(np.prod(shape)).astype("float32").reshape(shape)
self._run_test_case(func2, [_OUTPUT], {_INPUT: x_val})

@check_opset_min_version(7, "batchnorm")
def test_fused_batchnorm(self):
x_shape = [1, 28, 28, 2]
Expand Down
32 changes: 22 additions & 10 deletions tf2onnx/onnx_opset/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def version_1(cls, ctx, node, **kwargs):
ctx.copy_shape(output_name, output_cast.output[0])



@tf_op("Flatten")
class Flatten:
@classmethod
Expand Down Expand Up @@ -630,6 +629,7 @@ def version_13(cls, ctx, node, **kwargs):
# Default axis is not -1 but doesn't matter since we always set it.
cls.version_1(ctx, node, **kwargs)


@tf_op("SplitV")
class SplitV:
@classmethod
Expand Down Expand Up @@ -874,15 +874,6 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
ellipsis_mask = ellipsis_mask.i if ellipsis_mask is not None else 0
shrink_axis_mask = node.get_attr("shrink_axis_mask")
shrink_axis_mask = shrink_axis_mask.i if shrink_axis_mask is not None else 0
if new_axis_mask != 0:
unqueeze_at = []
for bit in range(32):
if (new_axis_mask >> bit) & 1 == 1:
unqueeze_at.append(bit)
begin_mask |= 1 << bit
end_mask |= 1 << bit
input_x = GraphBuilder(ctx).make_unsqueeze(
{'data': input_x.output[0], 'axes': unqueeze_at}, return_node=True)

param_shape = ctx.get_shape(node.input[1]) or \
ctx.get_shape(node.input[2]) or \
Expand All @@ -892,6 +883,27 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
"StridedSlice op {} requires the shape of begin/end/strides".format(node.name)
)
param_rank = param_shape[0]

if new_axis_mask != 0:
unqueeze_at = []
ellipsis_gap = 0
num_new = 0
for bit in range(32):
if (new_axis_mask >> bit) & 1 == 1:
num_new += 1
if (ellipsis_mask >> bit) & 1:
input_shape = ctx.get_shape(input_x.output[0])
# calculate what rank for ellipsis: input rank - (being rank - all new_axis - 1)
ellipsis_gap = len(input_shape) - param_rank + num_new + 1
if (new_axis_mask >> bit) & 1 == 1:
unqueeze_at.append(bit + ellipsis_gap)
begin_mask |= 1 << bit
end_mask |= 1 << bit

input_x = GraphBuilder(ctx).make_unsqueeze(
{'data': input_x.output[0], 'axes': unqueeze_at}, return_node=True)


# use in onnx graph to mask begin
new_begin_mask = [1] * param_rank
# use in onnx graph to mask end
Expand Down