From fcc0cd516cbcaa5007c94737c4b24376b08c1eb8 Mon Sep 17 00:00:00 2001 From: Tom Wildenhain Date: Mon, 30 Aug 2021 12:54:35 -0700 Subject: [PATCH] Fix bug in reverseV2 for 1D tensors Signed-off-by: Tom Wildenhain --- tests/test_backend.py | 5 +++++ tf2onnx/onnx_opset/tensor.py | 23 ++++++++++++++++++----- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/tests/test_backend.py b/tests/test_backend.py index 2a92ec03f..df61dec3c 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -3273,6 +3273,11 @@ def test_reversev2_1D_tensor(self): # Adds an identity block. x_val_shape = [4] x_val = np.random.randint(0, 100, x_val_shape).astype(np.float32) + def func(x): + x_ = reverse_v2(x, axis=[0]) + return tf.identity(x_, name=_TFOUTPUT) + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) + def func(x): x_ = reverse_v2(x, axis=[]) return tf.identity(x_, name=_TFOUTPUT) diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py index 439d15fd0..10f2d50cf 100644 --- a/tf2onnx/onnx_opset/tensor.py +++ b/tf2onnx/onnx_opset/tensor.py @@ -2112,20 +2112,29 @@ def version_10(cls, ctx, node, **kwargs): axes = axes.tolist() len_axes = len(axes) + input_rank = ctx.get_rank(node.input[0]) + utils.make_sure(input_rank is not None, "rank of {} is unknown".format(node.input[0])) + needs_squeeze = False + if input_rank == 1 and len_axes != 0: + # ReverseSequence node requires rank >= 2 + utils.make_sure(axes in [[-1], [0]], "Invalid value %s for axes of ReverseV2 of 1d tensor", axes) + axes = [1] + new_inp = GraphBuilder(ctx).make_unsqueeze({'data': node.input[0], 'axes': [0]}) + ctx.replace_input(node, node.input[0], new_inp, 0) + input_rank = 2 + needs_squeeze = True + # Store input and output parameters of the ReverseV2 node. rv2_in_names = [node.input[0]] - input_shape = ctx.get_shape(node.input[0]) - input_rank = len(input_shape) input_shape_node = ctx.make_node("Shape", [node.input[0]], op_name_scope=node.name) - # Make sure input shape is not None - utils.make_sure(input_shape is not None, "shape of {} is None".format(node.input[0])) - rv2_node_name = node.name # ReverseV2 has a single output. rv2_output_dtypes = node.output_dtypes rv2_output_shapes = node.output_shapes + if needs_squeeze and rv2_output_shapes is not None: + rv2_output_shapes[0] = [1] + rv2_output_shapes[0] # Remove ReverseV2 node from graph. ctx.remove_node(rv2_node_name) @@ -2243,6 +2252,10 @@ def version_10(cls, ctx, node, **kwargs): attr={"perm": curr_perm} ) + if needs_squeeze: + sq_node = GraphBuilder(ctx).make_squeeze({"data": node.output[0], "axes": [0]}, return_node=True) + ctx.insert_node_on_output(sq_node) + @tf_op("Unique", onnx_op="Unique") class Unique: