Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,26 @@ def func(x):
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

@check_tf_min_version("1.14")
@check_opset_min_version(12, "GatherND with batch_dims")
def test_gather_batch_dims_no_trans(self):
x_val = np.arange(2 * 2 * 3 * 5 * 4, dtype=np.float32).reshape((2, 2, 3, 5, 4))
idx_val = np.array([[[1, 0, 2, 0], [1, 1, 1, 0]], [[0, 0, 0, 0], [2, 1, 1, 0]]], dtype=np.int32)
def func(x, idx):
x_ = tf.gather(x, idx, batch_dims=2, axis=2)
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: idx_val})

@check_tf_min_version("1.14")
@check_opset_min_version(12, "GatherND with batch_dims")
def test_gather_batch_dims(self):
x_val = np.arange(2 * 2 * 3 * 5 * 4, dtype=np.float32).reshape((2, 2, 3, 5, 4))
idx_val = np.array([[[1, 0, 2, 0], [1, 1, 1, 0]], [[0, 0, 0, 0], [2, 1, 1, 0]]], dtype=np.int32)
def func(x, idx):
x_ = tf.gather(x, idx, batch_dims=2, axis=3)
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: idx_val})

@check_opset_min_version(10, "Slice")
def test_roll_axis_scalar(self):
x_val = np.arange(4 * 3 * 5 * 2, dtype=np.float32).reshape((4, 3, 5, 2))
Expand Down
39 changes: 39 additions & 0 deletions tf2onnx/onnx_opset/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,10 @@ class GatherV2:
@classmethod
def version_1(cls, ctx, node, **kwargs):
# for GatherV2 axis come as input
err_msg = "Opset 12 required for batch_dims attribute of GatherV2"
utils.make_sure(node.get_attr_value("batch_dims", 0) == 0, err_msg)
node.type = "Gather"
utils.make_sure(node.inputs[2].is_const(), "Axis of GatherV2 node must be constant")
axis = node.inputs[2].get_tensor_value()
ctx.remove_input(node, node.input[2], 2)
node.set_attr("axis", axis)
Expand All @@ -433,6 +436,42 @@ def version_11(cls, ctx, node, **kwargs):
# no change
cls.version_1(ctx, node, **kwargs)

@classmethod
def version_12(cls, ctx, node, **kwargs):
batch_dims = node.get_attr_value("batch_dims", 0)
if batch_dims == 0:
cls.version_1(ctx, node, **kwargs)
return
# If batch_dims is not zero, use GatherND to simulate Gather with batch dims.
data_inp, indices_inp, axis_inp = node.input
utils.make_sure(node.inputs[2].is_const(), "Axis of GatherV2 node must be constant")
axis = node.inputs[2].get_tensor_value()
ctx.remove_input(node, axis_inp, 2)
if ctx.get_dtype(indices_inp) != TensorProto.INT64:
indices_inp = ctx.make_node("Cast", [indices_inp], attr={'to': TensorProto.INT64}).output[0]
unperm = None
# GatherND doesn't take an axis so we have to transpose stuff around
if axis != batch_dims:
data_rank = ctx.get_rank(data_inp)
indices_rank = ctx.get_rank(indices_inp)
result_rank = data_rank + indices_rank - 1 - batch_dims
shift_amt = axis - batch_dims
err_msg = "Cannot convert GatherV2 with batch dims since inputs have unknown ranks."
utils.make_sure(data_rank is not None and indices_rank is not None, err_msg)
perm = list(range(data_rank))
perm = perm[:batch_dims] + perm[axis:axis+1] + perm[batch_dims:axis] + perm[axis+1:]
data_inp = ctx.make_node("Transpose", [data_inp], attr={'perm': perm}).output[0]
ctx.replace_input(node, node.input[0], data_inp, 0)
unperm = list(range(result_rank))
j = indices_rank+shift_amt
unperm = unperm[:batch_dims] + unperm[indices_rank:j] + unperm[batch_dims:indices_rank] + unperm[j:]
node.type = "GatherND"
unsqueeze_node = GraphBuilder(ctx).make_unsqueeze({'data': indices_inp, 'axes': [-1]})
ctx.replace_input(node, node.input[1], unsqueeze_node, 1)
if unperm is not None:
ctx.update_node_shape_dtype(node, override=True)
ctx.insert_new_node_on_output("Transpose", node.output[0], perm=unperm)


def _make_gathernd_inner_loop(ctx, params, index, dtype):
"""create the inner loop for GatherNd."""
Expand Down