Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3273,6 +3273,11 @@ def test_reversev2_1D_tensor(self):
# Adds an identity block.
x_val_shape = [4]
x_val = np.random.randint(0, 100, x_val_shape).astype(np.float32)
def func(x):
x_ = reverse_v2(x, axis=[0])
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

def func(x):
x_ = reverse_v2(x, axis=[])
return tf.identity(x_, name=_TFOUTPUT)
Expand Down
23 changes: 18 additions & 5 deletions tf2onnx/onnx_opset/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2112,20 +2112,29 @@ def version_10(cls, ctx, node, **kwargs):
axes = axes.tolist()
len_axes = len(axes)

input_rank = ctx.get_rank(node.input[0])
utils.make_sure(input_rank is not None, "rank of {} is unknown".format(node.input[0]))
needs_squeeze = False
if input_rank == 1 and len_axes != 0:
# ReverseSequence node requires rank >= 2
utils.make_sure(axes in [[-1], [0]], "Invalid value %s for axes of ReverseV2 of 1d tensor", axes)
axes = [1]
new_inp = GraphBuilder(ctx).make_unsqueeze({'data': node.input[0], 'axes': [0]})
ctx.replace_input(node, node.input[0], new_inp, 0)
input_rank = 2
needs_squeeze = True

# Store input and output parameters of the ReverseV2 node.
rv2_in_names = [node.input[0]]

input_shape = ctx.get_shape(node.input[0])
input_rank = len(input_shape)
input_shape_node = ctx.make_node("Shape", [node.input[0]], op_name_scope=node.name)

# Make sure input shape is not None
utils.make_sure(input_shape is not None, "shape of {} is None".format(node.input[0]))

rv2_node_name = node.name
# ReverseV2 has a single output.
rv2_output_dtypes = node.output_dtypes
rv2_output_shapes = node.output_shapes
if needs_squeeze and rv2_output_shapes is not None:
rv2_output_shapes[0] = [1] + rv2_output_shapes[0]

# Remove ReverseV2 node from graph.
ctx.remove_node(rv2_node_name)
Expand Down Expand Up @@ -2243,6 +2252,10 @@ def version_10(cls, ctx, node, **kwargs):
attr={"perm": curr_perm}
)

if needs_squeeze:
sq_node = GraphBuilder(ctx).make_squeeze({"data": node.output[0], "axes": [0]}, return_node=True)
ctx.insert_node_on_output(sq_node)


@tf_op("Unique", onnx_op="Unique")
class Unique:
Expand Down