diff --git a/tests/test_backend.py b/tests/test_backend.py index fe50e0591..4e4d08969 100755 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -5893,5 +5893,23 @@ def func(x): x_val = make_xval([3, 4]) self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) + @check_opset_min_version(10, "Slice") + def test_addition_two_newaxis_simultaneously(self): + def func(x): + op = x[..., tf.newaxis, tf.newaxis] + return tf.identity(op, name=_TFOUTPUT) + + x_val = make_xval([2, 3]) + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) + + @check_opset_min_version(10, "Slice") + def test_addition_three_newaxis_simultaneously(self): + def func(x): + op = x[..., tf.newaxis, tf.newaxis, tf.newaxis] + return tf.identity(op, name=_TFOUTPUT) + + x_val = make_xval([2, 3]) + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) + if __name__ == '__main__': unittest_main() diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py index 9ad2b513e..b08188b88 100644 --- a/tf2onnx/onnx_opset/tensor.py +++ b/tf2onnx/onnx_opset/tensor.py @@ -974,6 +974,29 @@ def any_version_after10(cls, opset, ctx, node, **kwargs): begin_mask |= 1 << bit end_mask |= 1 << bit + if ellipsis_mask: + unqueeze_at = [] + ellipsis_gap = 0 + num_new = 0 + end_mask = node.get_attr("end_mask") + end_mask = end_mask.i if end_mask is not None else 0 + begin_mask = node.get_attr("begin_mask") + begin_mask = begin_mask.i if begin_mask is not None else 0 + + for bit in range(32): + new_axis_flag = (new_axis_mask >> bit) & 1 + ellipsis_flag = (ellipsis_mask >> bit) & 1 + num_new += not ellipsis_flag and new_axis_flag + + for bit in range(32): + if (ellipsis_mask >> bit) & 1: + ellipsis_gap = len(ctx.get_shape(input_x)) - param_rank + num_new + 1 + elif (new_axis_mask >> bit) & 1: + effective_bit = bit if not ellipsis_gap else bit + ellipsis_gap - 1 + unqueeze_at.append(effective_bit) + begin_mask |= 1 << bit + end_mask |= 1 << bit + input_x = GraphBuilder(ctx).make_unsqueeze( {'data': input_x, 'axes': unqueeze_at})