Commit cdabe5d

Merge branch 'master' into tom/RaggedGather
2 parents: f37a78a + 5fb9194

File tree: 2 files changed (+139, -24 lines)

tests/test_backend.py

Lines changed: 42 additions & 0 deletions
@@ -871,6 +871,26 @@ def func(x):
             return tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
 
+    @check_tf_min_version("1.14")
+    @check_opset_min_version(12, "GatherND with batch_dims")
+    def test_gather_batch_dims_no_trans(self):
+        x_val = np.arange(2 * 2 * 3 * 5 * 4, dtype=np.float32).reshape((2, 2, 3, 5, 4))
+        idx_val = np.array([[[1, 0, 2, 0], [1, 1, 1, 0]], [[0, 0, 0, 0], [2, 1, 1, 0]]], dtype=np.int32)
+        def func(x, idx):
+            x_ = tf.gather(x, idx, batch_dims=2, axis=2)
+            return tf.identity(x_, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: idx_val})
+
+    @check_tf_min_version("1.14")
+    @check_opset_min_version(12, "GatherND with batch_dims")
+    def test_gather_batch_dims(self):
+        x_val = np.arange(2 * 2 * 3 * 5 * 4, dtype=np.float32).reshape((2, 2, 3, 5, 4))
+        idx_val = np.array([[[1, 0, 2, 0], [1, 1, 1, 0]], [[0, 0, 0, 0], [2, 1, 1, 0]]], dtype=np.int32)
+        def func(x, idx):
+            x_ = tf.gather(x, idx, batch_dims=2, axis=3)
+            return tf.identity(x_, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: idx_val})
+
     @check_opset_min_version(10, "Slice")
     def test_roll_axis_scalar(self):
         x_val = np.arange(4 * 3 * 5 * 2, dtype=np.float32).reshape((4, 3, 5, 2))
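For reference, a minimal sketch (assumed, not part of the commit) of the tf.gather semantics these two tests exercise: with batch_dims=2, each batch position applies its own index row along the gather axis. The data and indices below match the test values.

# Sketch only: tf.gather with batch_dims, checked against plain NumPy indexing.
import numpy as np
import tensorflow as tf

x = np.arange(2 * 2 * 3 * 5 * 4, dtype=np.float32).reshape((2, 2, 3, 5, 4))
idx = np.array([[[1, 0, 2, 0], [1, 1, 1, 0]], [[0, 0, 0, 0], [2, 1, 1, 0]]], dtype=np.int32)

# batch_dims=2, axis=2: for each batch position (i, j), gather along the third
# dimension of x using that position's own index row idx[i, j].
out = tf.gather(x, idx, batch_dims=2, axis=2)

# Output shape = batch dims (2, 2) + index dims past the batch (4,)
# + data dims past the axis (5, 4) -> (2, 2, 4, 5, 4).
expected = np.stack([np.stack([x[i, j, idx[i, j]] for j in range(2)])
                     for i in range(2)])
np.testing.assert_array_equal(out.numpy(), expected)

The second test uses axis=3 (gathering along the dimension of size 5), the case that forces the Transpose round-trip in the converter change below.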
@@ -3852,6 +3872,28 @@ def func(splits, rt_dense_values, indices):
         self._run_test_case(func, [_OUTPUT, _OUTPUT1],
                             {_INPUT: splits_val, _INPUT1: dense_vals_val, _INPUT2: indices_val})
 
+    def test_ragged_tensor_to_tensor(self):
+        splits_val1 = np.array([0, 1, 1, 5], dtype=np.int32)
+        splits_val2 = np.array([0, 3, 3, 5, 9, 10], dtype=np.int32)
+        dense_vals_val = np.array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.float32)
+        def func(splits1, splits2, rt_dense_values):
+            x = tf.RaggedTensor.from_nested_row_splits(rt_dense_values, [splits1, splits2], validate=True)
+            y = x.to_tensor(default_value=7)
+            return tf.identity(y, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: splits_val1, _INPUT1: splits_val2, _INPUT2: dense_vals_val})
+
+    @check_tf_min_version("2.2", "ragged to_tensor with constrained shape")
+    @check_opset_min_version(11, "CumSum")
+    def test_ragged_tensor_to_tensor_constrain_shape(self):
+        splits_val1 = np.array([0, 1, 1, 5], dtype=np.int32)
+        splits_val2 = np.array([0, 3, 3, 5, 9, 10], dtype=np.int32)
+        dense_vals_val = np.array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.float32)
+        def func(splits1, splits2, rt_dense_values):
+            x = tf.RaggedTensor.from_nested_row_splits(rt_dense_values, [splits1, splits2], validate=True)
+            y = x.to_tensor(default_value=7, shape=[20, None, 2])
+            return tf.identity(y, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: splits_val1, _INPUT1: splits_val2, _INPUT2: dense_vals_val})
+
     @check_tf_min_version("1.14", "ragged needs tf 1.14")
     @check_opset_min_version(11, "Range")
     def test_ragged_range_float(self):
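As a worked example (assumed, not part of the commit), this is the ragged structure the splits above describe and what to_tensor produces:

# Sketch only: the nested row splits used by both tests, decoded by hand.
import numpy as np
import tensorflow as tf

splits1 = np.array([0, 1, 1, 5], dtype=np.int32)          # outer row splits
splits2 = np.array([0, 3, 3, 5, 9, 10], dtype=np.int32)   # inner row splits
values = np.array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.float32)

rt = tf.RaggedTensor.from_nested_row_splits(values, [splits1, splits2])
# rt == [[[10, 11, 12]], [], [[], [13, 14], [15, 16, 17, 18], [19]]]

dense = rt.to_tensor(default_value=7)
# dense.shape == (3, 4, 4); positions not covered by rt are filled with 7, e.g.
# dense[2] == [[7, 7, 7, 7], [13, 14, 7, 7], [15, 16, 17, 18], [19, 7, 7, 7]]

# The constrained-shape test passes shape=[20, None, 2]: the outer dimension is
# padded out to 20 rows, None keeps the natural middle dimension (4), and the
# last dimension is truncated to 2, dropping the values 12, 17 and 18.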

tf2onnx/onnx_opset/tensor.py

Lines changed: 97 additions & 24 deletions
@@ -423,7 +423,10 @@ class GatherV2:
     @classmethod
     def version_1(cls, ctx, node, **kwargs):
         # for GatherV2 axis come as input
+        err_msg = "Opset 12 required for batch_dims attribute of GatherV2"
+        utils.make_sure(node.get_attr_value("batch_dims", 0) == 0, err_msg)
         node.type = "Gather"
+        utils.make_sure(node.inputs[2].is_const(), "Axis of GatherV2 node must be constant")
         axis = node.inputs[2].get_tensor_value()
         ctx.remove_input(node, node.input[2], 2)
         node.set_attr("axis", axis)
@@ -433,6 +436,42 @@ def version_11(cls, ctx, node, **kwargs):
         # no change
         cls.version_1(ctx, node, **kwargs)
 
+    @classmethod
+    def version_12(cls, ctx, node, **kwargs):
+        batch_dims = node.get_attr_value("batch_dims", 0)
+        if batch_dims == 0:
+            cls.version_1(ctx, node, **kwargs)
+            return
+        # If batch_dims is not zero, use GatherND to simulate Gather with batch dims.
+        data_inp, indices_inp, axis_inp = node.input
+        utils.make_sure(node.inputs[2].is_const(), "Axis of GatherV2 node must be constant")
+        axis = node.inputs[2].get_tensor_value()
+        ctx.remove_input(node, axis_inp, 2)
+        if ctx.get_dtype(indices_inp) != TensorProto.INT64:
+            indices_inp = ctx.make_node("Cast", [indices_inp], attr={'to': TensorProto.INT64}).output[0]
+        unperm = None
+        # GatherND has no axis attribute, so transpose the gather axis next to the batch dims.
+        if axis != batch_dims:
+            data_rank = ctx.get_rank(data_inp)
+            indices_rank = ctx.get_rank(indices_inp)
+            err_msg = "Cannot convert GatherV2 with batch dims since inputs have unknown ranks."
+            utils.make_sure(data_rank is not None and indices_rank is not None, err_msg)
+            result_rank = data_rank + indices_rank - 1 - batch_dims
+            shift_amt = axis - batch_dims
+            perm = list(range(data_rank))
+            perm = perm[:batch_dims] + perm[axis:axis + 1] + perm[batch_dims:axis] + perm[axis + 1:]
+            data_inp = ctx.make_node("Transpose", [data_inp], attr={'perm': perm}).output[0]
+            ctx.replace_input(node, node.input[0], data_inp, 0)
+            unperm = list(range(result_rank))
+            j = indices_rank + shift_amt
+            unperm = unperm[:batch_dims] + unperm[indices_rank:j] + unperm[batch_dims:indices_rank] + unperm[j:]
+        node.type = "GatherND"
+        unsqueeze_node = GraphBuilder(ctx).make_unsqueeze({'data': indices_inp, 'axes': [-1]})
+        ctx.replace_input(node, node.input[1], unsqueeze_node, 1)
+        if unperm is not None:
+            ctx.update_node_shape_dtype(node, override=True)
+            ctx.insert_new_node_on_output("Transpose", node.output[0], perm=unperm)
+
 
 def _make_gathernd_inner_loop(ctx, params, index, dtype):
     """create the inner loop for GatherNd."""
@@ -2077,43 +2116,77 @@ def ragged_lengths_to_sparse_indices(ctx, ragged_lens):
     return num_rows, num_cols, row_indices, col_indices
 
 
+def ragged_nested_splits_to_sparse_indices(ctx, nested_splits, op_name_scope):
+    sparse_indices = None
+    dense_shape_dims = []
+    for split in nested_splits:
+        if ctx.get_dtype(split) != TensorProto.INT64:
+            split = ctx.make_node("Cast", [split], attr={'to': TensorProto.INT64}).output[0]
+        max_int64 = int(utils.get_max_value(np.int64))
+        slice1 = GraphBuilder(ctx).make_slice(
+            {"data": split, "ends": [max_int64], "starts": [1], "axes": [0]})
+        slice2 = GraphBuilder(ctx).make_slice(
+            {"data": split, "ends": [-1], "starts": [0], "axes": [0]})
+        ragged_lens = ctx.make_node("Sub", [slice1, slice2]).output[0]
+        num_rows, num_cols, row_indices, col_indices = ragged_lengths_to_sparse_indices(ctx, ragged_lens)
+        if not dense_shape_dims:
+            dense_shape_dims.append(num_rows)
+        dense_shape_dims.append(num_cols)
+        if sparse_indices is None:
+            row_indices = GraphBuilder(ctx).make_unsqueeze({"data": row_indices, "axes": [1]})
+        else:
+            row_indices = ctx.make_node("Gather", [sparse_indices, row_indices]).output[0]
+        col_indices = GraphBuilder(ctx).make_unsqueeze({"data": col_indices, "axes": [1]})
+        sparse_indices = ctx.make_node("Concat", [row_indices, col_indices], attr={'axis': 1},
+                                       op_name_scope=op_name_scope).output[0]
+    dense_shape = ctx.make_node("Concat", dense_shape_dims, attr={'axis': 0}, op_name_scope=op_name_scope).output[0]
+    return sparse_indices, dense_shape
+
+
 @tf_op("RaggedTensorToSparse")
 class RaggedTensorToSparse:
     @classmethod
     def version_11(cls, ctx, node, **kwargs):
         # https://www.tensorflow.org/guide/ragged_tensor#multiple_ragged_dimensions
         dense_values = node.input[-1]
         nested_splits = node.input[:-1]
-        sparse_indices = None
-        dense_shape_dims = []
-        for split in nested_splits:
-            if ctx.get_dtype(split) != TensorProto.INT64:
-                split = ctx.make_node("Cast", [split], attr={'to': TensorProto.INT64}).output[0]
-            max_int64 = int(utils.get_max_value(np.int64))
-            slice1 = GraphBuilder(ctx).make_slice(
-                {"data": split, "ends": [max_int64], "starts": [1], "axes": [0]})
-            slice2 = GraphBuilder(ctx).make_slice(
-                {"data": split, "ends": [-1], "starts": [0], "axes": [0]})
-            ragged_lens = ctx.make_node("Sub", [slice1, slice2]).output[0]
-            num_rows, num_cols, row_indices, col_indices = ragged_lengths_to_sparse_indices(ctx, ragged_lens)
-            if not dense_shape_dims:
-                dense_shape_dims.append(num_rows)
-            dense_shape_dims.append(num_cols)
-            if sparse_indices is None:
-                row_indices = GraphBuilder(ctx).make_unsqueeze({"data": row_indices, "axes": [1]})
-            else:
-                row_indices = ctx.make_node("Gather", [sparse_indices, row_indices]).output[0]
-            col_indices = GraphBuilder(ctx).make_unsqueeze({"data": col_indices, "axes": [1]})
-            sparse_indices = ctx.make_node("Concat", [row_indices, col_indices], attr={'axis': 1},
-                                           op_name_scope=node.name).output[0]
-        dense_shape = ctx.make_node("Concat", dense_shape_dims, attr={'axis': 0}, op_name_scope=node.name).output[0]
-
+        sparse_indices, dense_shape = ragged_nested_splits_to_sparse_indices(ctx, nested_splits, node.name)
         ctx.replace_all_inputs(node.output[0], sparse_indices)
         ctx.replace_all_inputs(node.output[1], dense_values)
         ctx.replace_all_inputs(node.output[2], dense_shape)
         ctx.remove_node(node.name)
 
 
+@tf_op("RaggedTensorToTensor")
+class RaggedTensorToTensor:
+    @classmethod
+    def version_11(cls, ctx, node, **kwargs):
+        shape, values, default_value, *row_partition_tensors = node.input
+        partition_types = node.get_attr_value("row_partition_types")
+        error_msg = "Only ROW_SPLITS partition type is supported for RaggedTensorToTensor. types: %r"
+        utils.make_sure(all(t == b'ROW_SPLITS' for t in partition_types), error_msg, partition_types)
+        nested_splits = row_partition_tensors
+        sparse_indices, dense_shape = ragged_nested_splits_to_sparse_indices(ctx, nested_splits, node.name)
+        # A shape of rank 0 means the natural shape should be used.
+        if ctx.get_rank(shape) != 0:
+            if ctx.get_dtype(shape) != TensorProto.INT64:
+                shape = ctx.make_node("Cast", [shape], attr={'to': TensorProto.INT64}).output[0]
+            const_zero_int64 = ctx.make_const(utils.make_name("const_zero"), np.array(0, dtype=np.int64)).output[0]
+            unspec_dims = ctx.make_node("Less", [shape, const_zero_int64]).output[0]
+            out_shape = ctx.make_node("Where", [unspec_dims, dense_shape, shape]).output[0]
+            out_shape_unsq = GraphBuilder(ctx).make_unsqueeze({'data': out_shape, 'axes': [0]})
+            amt_idx_in_bounds = ctx.make_node("Sub", [out_shape_unsq, sparse_indices]).output[0]
+            amt_in_bounds_flat = ctx.make_node("ReduceMin", [amt_idx_in_bounds], attr={'axes': [1], 'keepdims': False})
+            idx_in_bounds = ctx.make_node("Greater", [amt_in_bounds_flat.output[0], const_zero_int64]).output[0]
+            sparse_indices = ctx.make_node("Compress", [sparse_indices, idx_in_bounds], attr={'axis': 0}).output[0]
+            values = ctx.make_node("Compress", [values, idx_in_bounds], attr={'axis': 0}).output[0]
+        else:
+            out_shape = dense_shape
+        expand_node = ctx.make_node("Expand", [default_value, out_shape])
+        node.type = "ScatterND"
+        ctx.replace_inputs(node, [expand_node.output[0], sparse_indices, values])
+
+
 @tf_op("RaggedRange")
 class RaggedRange:
     @classmethod
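To make the new helper concrete, a plain-NumPy sketch (assumed, not part of the commit) of what one level of ragged_nested_splits_to_sparse_indices computes, using the inner splits from the tests:

# Sketch only: row_splits -> row lengths -> (row, col) coordinate pairs.
import numpy as np

split = np.array([0, 3, 3, 5, 9, 10], dtype=np.int64)
ragged_lens = split[1:] - split[:-1]   # the Slice/Slice/Sub nodes: [3, 0, 2, 4, 1]
num_rows, num_cols = len(ragged_lens), ragged_lens.max()   # 5 rows, max width 4

row_indices = np.repeat(np.arange(num_rows), ragged_lens)          # [0 0 0 2 2 3 3 3 3 4]
col_indices = np.concatenate([np.arange(n) for n in ragged_lens])  # [0 1 2 0 1 0 1 2 3 0]
sparse_indices = np.stack([row_indices, col_indices], axis=1)
# Every ragged value now has a coordinate. For each additional nesting level the
# loop gathers these coordinates by the next level's row indices and appends one
# more column, so n split levels yield (n + 1)-column indices.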
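And the end result of the RaggedTensorToTensor conversion, sketched in NumPy terms (assumed; 1-D values only, names hypothetical): expand the default value to the output shape, drop out-of-bounds coordinates the way the Compress nodes do, then scatter.

# Sketch only: the Expand + Compress + ScatterND pipeline.
import numpy as np

def ragged_to_tensor_sketch(sparse_indices, values, out_shape, default_value):
    # Compress: keep entries whose coordinates fit inside out_shape
    # (ReduceMin(out_shape - indices, axis=1) > 0 in the ONNX graph).
    in_bounds = (np.asarray(out_shape) - sparse_indices).min(axis=1) > 0
    sparse_indices, values = sparse_indices[in_bounds], values[in_bounds]
    out = np.full(out_shape, default_value, dtype=values.dtype)   # Expand
    out[tuple(sparse_indices.T)] = values                         # ScatterND
    return out

With the natural shape nothing can be out of bounds, so the converter skips the Compress path entirely; with the constrained shape=[20, None, 2] of the second test, entries whose last coordinate is 2 or more are removed before scattering.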
