From 4ada292aef748da2561225065674b42d2f05ba5b Mon Sep 17 00:00:00 2001 From: Tom Wildenhain Date: Tue, 9 Feb 2021 17:04:29 -0500 Subject: [PATCH 1/2] Add support for GatherV2 batch_dims attr Signed-off-by: Tom Wildenhain --- tests/test_backend.py | 18 +++++++++++++++++ tf2onnx/onnx_opset/tensor.py | 39 ++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/tests/test_backend.py b/tests/test_backend.py index b7fca7dd6..cf97523ff 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -871,6 +871,24 @@ def func(x): return tf.identity(x_, name=_TFOUTPUT) self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) + @check_opset_min_version(12, "GatherND with batch_dims") + def test_gather_batch_dims_no_trans(self): + x_val = np.arange(2 * 2 * 3 * 5 * 4, dtype=np.float32).reshape((2, 2, 3, 5, 4)) + idx_val = np.array([[[1, 0, 2, 0], [1, 1, 1, 0]], [[0, 0, 0, 0], [2, 1, 1, 0]]], dtype=np.int32) + def func(x, idx): + x_ = tf.gather(x, idx, batch_dims=2, axis=2) + return tf.identity(x_, name=_TFOUTPUT) + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: idx_val}) + + @check_opset_min_version(12, "GatherND with batch_dims") + def test_gather_batch_dims(self): + x_val = np.arange(2 * 2 * 3 * 5 * 4, dtype=np.float32).reshape((2, 2, 3, 5, 4)) + idx_val = np.array([[[1, 0, 2, 0], [1, 1, 1, 0]], [[0, 0, 0, 0], [2, 1, 1, 0]]], dtype=np.int32) + def func(x, idx): + x_ = tf.gather(x, idx, batch_dims=2, axis=3) + return tf.identity(x_, name=_TFOUTPUT) + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: idx_val}) + @check_opset_min_version(10, "Slice") def test_roll_axis_scalar(self): x_val = np.arange(4 * 3 * 5 * 2, dtype=np.float32).reshape((4, 3, 5, 2)) diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py index 4f7dd233a..8fbfa11ae 100644 --- a/tf2onnx/onnx_opset/tensor.py +++ b/tf2onnx/onnx_opset/tensor.py @@ -423,7 +423,10 @@ class GatherV2: @classmethod def version_1(cls, ctx, node, **kwargs): # for GatherV2 axis come as input + err_msg = "Opset 12 required for batch_dims attribute of GatherV2" + utils.make_sure(node.get_attr_value("batch_dims", 0) == 0, err_msg) node.type = "Gather" + utils.make_sure(node.inputs[2].is_const(), "Axis of GatherV2 node must be constant") axis = node.inputs[2].get_tensor_value() ctx.remove_input(node, node.input[2], 2) node.set_attr("axis", axis) @@ -433,6 +436,42 @@ def version_11(cls, ctx, node, **kwargs): # no change cls.version_1(ctx, node, **kwargs) + @classmethod + def version_12(cls, ctx, node, **kwargs): + batch_dims = node.get_attr_value("batch_dims", 0) + if batch_dims == 0: + cls.version_1(ctx, node, **kwargs) + return + # If batch_dims is not zero, use GatherND to simulate Gather with batch dims. + data_inp, indices_inp, axis_inp = node.input + utils.make_sure(node.inputs[2].is_const(), "Axis of GatherV2 node must be constant") + axis = node.inputs[2].get_tensor_value() + ctx.remove_input(node, axis_inp, 2) + if ctx.get_dtype(indices_inp) != TensorProto.INT64: + indices_inp = ctx.make_node("Cast", [indices_inp], attr={'to': TensorProto.INT64}).output[0] + unperm = None + # GatherND doesn't take an axis so we have to transpose stuff around + if axis != batch_dims: + data_rank = ctx.get_rank(data_inp) + indices_rank = ctx.get_rank(indices_inp) + result_rank = data_rank + indices_rank - 1 - batch_dims + shift_amt = axis - batch_dims + err_msg = "Cannot convert GatherV2 with batch dims since inputs have unknown ranks." + utils.make_sure(data_rank is not None and indices_rank is not None, err_msg) + perm = list(range(data_rank)) + perm = perm[:batch_dims] + perm[axis:axis+1] + perm[batch_dims:axis] + perm[axis+1:] + data_inp = ctx.make_node("Transpose", [data_inp], attr={'perm': perm}).output[0] + ctx.replace_input(node, node.input[0], data_inp, 0) + unperm = list(range(result_rank)) + j = indices_rank+shift_amt + unperm = unperm[:batch_dims] + unperm[indices_rank:j] + unperm[batch_dims:indices_rank] + unperm[j:] + node.type = "GatherND" + unsqueeze_node = GraphBuilder(ctx).make_unsqueeze({'data': indices_inp, 'axes': [-1]}) + ctx.replace_input(node, node.input[1], unsqueeze_node, 1) + if unperm is not None: + ctx.update_node_shape_dtype(node, override=True) + ctx.insert_new_node_on_output("Transpose", node.output[0], perm=unperm) + def _make_gathernd_inner_loop(ctx, params, index, dtype): """create the inner loop for GatherNd.""" From 6f2b4fa7517cda0c8c1d3fe30d3d3e435322b71a Mon Sep 17 00:00:00 2001 From: Tom Wildenhain Date: Tue, 9 Feb 2021 17:24:42 -0500 Subject: [PATCH 2/2] Fix tests Signed-off-by: Tom Wildenhain --- tests/test_backend.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_backend.py b/tests/test_backend.py index cf97523ff..8b972380d 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -871,6 +871,7 @@ def func(x): return tf.identity(x_, name=_TFOUTPUT) self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) + @check_tf_min_version("1.14") @check_opset_min_version(12, "GatherND with batch_dims") def test_gather_batch_dims_no_trans(self): x_val = np.arange(2 * 2 * 3 * 5 * 4, dtype=np.float32).reshape((2, 2, 3, 5, 4)) @@ -880,6 +881,7 @@ def func(x, idx): return tf.identity(x_, name=_TFOUTPUT) self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: idx_val}) + @check_tf_min_version("1.14") @check_opset_min_version(12, "GatherND with batch_dims") def test_gather_batch_dims(self): x_val = np.arange(2 * 2 * 3 * 5 * 4, dtype=np.float32).reshape((2, 2, 3, 5, 4))