Skip to content

Commit fbaa6f4

Browse files
Added support for converting RaggedRange
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 1c9c02d commit fbaa6f4

File tree

3 files changed

+126
-0
lines changed

3 files changed

+126
-0
lines changed

tests/test_backend.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3719,6 +3719,45 @@ def func(indices, dense_shape, new_shape, shape_pad):
37193719
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: indices_val, _INPUT1: dense_shape_val,
37203720
_INPUT2: new_shape_val, _INPUT3: shape_pad_val})
37213721

3722+
@check_opset_min_version(11, "Range")
def test_ragged_range_float(self):
    """RaggedRange conversion with float32 inputs, including negative and fractional deltas."""
    starts = np.array([0, 0, 1, 10, 0.5, 0.5], dtype=np.float32)
    limits = np.array([-5, -2, 7, 100, 1, 1], dtype=np.float32)
    deltas = np.array([-1, 1, 2, 20, 1, 1.1], dtype=np.float32)

    def func(starts, limits, deltas):
        splits, values = tf.raw_ops.RaggedRange(starts=starts, limits=limits, deltas=deltas)
        splits_out = tf.identity(splits, name=_TFOUTPUT)
        values_out = tf.identity(values, name=_TFOUTPUT1)
        return splits_out, values_out

    feed = {_INPUT: starts, _INPUT1: limits, _INPUT2: deltas}
    self._run_test_case(func, [_OUTPUT, _OUTPUT1], feed)
3734+
3735+
@check_opset_min_version(11, "Range")
def test_ragged_range_int(self):
    """RaggedRange conversion with int32 inputs, covering empty rows and negative deltas."""
    starts = np.array([0, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.int32)
    limits = np.array([-6, -5, -4, -1, 0, 1, 4, 5, 6, 2, -2], dtype=np.int32)
    deltas = np.array([-5, -5, -5, -5, 5, 5, 5, 5, 5, 1, -1], dtype=np.int32)

    def func(starts, limits, deltas):
        splits, values = tf.raw_ops.RaggedRange(starts=starts, limits=limits, deltas=deltas)
        splits_out = tf.identity(splits, name=_TFOUTPUT)
        values_out = tf.identity(values, name=_TFOUTPUT1)
        return splits_out, values_out

    feed = {_INPUT: starts, _INPUT1: limits, _INPUT2: deltas}
    self._run_test_case(func, [_OUTPUT, _OUTPUT1], feed)
3747+
3748+
@check_opset_min_version(11, "Range")
def test_ragged_range_scalar(self):
    """RaggedRange conversion with scalar starts/deltas broadcast against vector limits."""
    starts = np.array(0, dtype=np.int32)
    limits = np.array([5, -1, -1, 2, 7, 100, 4, 5, 6], dtype=np.int32)
    deltas = np.array(1, dtype=np.int32)

    def func(starts, limits, deltas):
        splits, values = tf.raw_ops.RaggedRange(starts=starts, limits=limits, deltas=deltas)
        splits_out = tf.identity(splits, name=_TFOUTPUT)
        values_out = tf.identity(values, name=_TFOUTPUT1)
        return splits_out, values_out

    feed = {_INPUT: starts, _INPUT1: limits, _INPUT2: deltas}
    self._run_test_case(func, [_OUTPUT, _OUTPUT1], feed)
3760+
37223761
@check_opset_min_version(9, "Compress")
37233762
def test_dynamic_partition_both_vector(self):
37243763
data_val = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.float32)

tf2onnx/graph.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,13 @@ def get_shape(self, name):
951951
return shape
952952
return shape
953953

954+
def get_rank(self, name):
    """Return the rank of tensor `name`, i.e. len(get_shape(name)), or None if the shape is unknown."""
    shape = self.get_shape(name)
    return None if shape is None else len(shape)
960+
954961
def set_shape(self, name, val):
955962
"""Set new shape of node."""
956963
if isinstance(val, np.ndarray):

tf2onnx/onnx_opset/tensor.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2034,6 +2034,86 @@ def version_11(cls, ctx, node, **kwargs):
20342034
ctx.replace_inputs(node, [expand_node.output[0], sparse_indices, sparse_vals])
20352035

20362036

2037+
@tf_op("RaggedRange")
2038+
class RaggedRange:
2039+
@classmethod
2040+
def version_11(cls, ctx, node, **kwargs):
2041+
starts, limits, deltas = node.input
2042+
data_dtype = ctx.get_dtype(starts)
2043+
data_np_dtype = utils.map_onnx_to_numpy_type(data_dtype)
2044+
data_is_float = np.dtype(data_np_dtype).kind == 'f'
2045+
2046+
if data_is_float:
2047+
sub_node = ctx.make_node("Sub", [limits, starts]).output[0]
2048+
div_node = ctx.make_node("Div", [sub_node, deltas]).output[0]
2049+
ceil_node = ctx.make_node("Ceil", [div_node]).output[0]
2050+
row_lens = ctx.make_node("Cast", [ceil_node], attr={'to': TensorProto.INT64}).output[0]
2051+
2052+
else:
2053+
# compute ceil(a/b) with ints
2054+
starts_cast = ctx.make_node("Cast", [starts], attr={'to': TensorProto.INT64}).output[0]
2055+
limits_cast = ctx.make_node("Cast", [limits], attr={'to': TensorProto.INT64}).output[0]
2056+
deltas_cast = ctx.make_node("Cast", [deltas], attr={'to': TensorProto.INT64}).output[0]
2057+
sub_node = ctx.make_node("Sub", [limits_cast, starts_cast]).output[0]
2058+
div_node = ctx.make_node("Div", [sub_node, deltas_cast]).output[0]
2059+
mul_node = ctx.make_node("Mul", [div_node, deltas_cast]).output[0]
2060+
eq_node = ctx.make_node("Equal", [mul_node, sub_node]).output[0]
2061+
ne_node = ctx.make_node("Not", [eq_node]).output[0]
2062+
# we want to round up if it isn't evenly divisible
2063+
offset = ctx.make_node("Cast", [ne_node], attr={'to': TensorProto.INT64}).output[0]
2064+
row_lens = ctx.make_node("Add", [div_node, offset]).output[0]
2065+
2066+
const_zero_int64 = ctx.make_const(utils.make_name("const_zero"), np.array(0, dtype=np.int64)).output[0]
2067+
if ctx.opset <= 11:
2068+
const_zero_double = ctx.make_const(utils.make_name("const_zero"), np.array(0, dtype=np.float64)).output[0]
2069+
row_lens = ctx.make_node("Cast", [row_lens], attr={'to': TensorProto.DOUBLE}).output[0]
2070+
row_lens = ctx.make_node("Max", [row_lens, const_zero_double]).output[0]
2071+
row_lens = ctx.make_node("Cast", [row_lens], attr={'to': TensorProto.INT64}).output[0]
2072+
else:
2073+
row_lens = ctx.make_node("Max", [row_lens, const_zero_int64]).output[0]
2074+
2075+
const_zero_list = ctx.make_const(utils.make_name("const_zero_list"), np.array([0], dtype=np.int64)).output[0]
2076+
2077+
max_row_len = ctx.make_node("ReduceMax", [row_lens], attr={'axes': [0], 'keeepdims': False}).output[0]
2078+
inp_shape = ctx.make_node("Shape", [row_lens]).output[0]
2079+
range_len = ctx.make_node("Mul", [max_row_len, inp_shape]).output[0]
2080+
2081+
# ORT seems to have a shape inference bug for the Range node. Use CumSum instead.
2082+
one_tensor = helper.make_tensor("value", TensorProto.INT64, dims=[1], vals=[1])
2083+
ones_of_shape = ctx.make_node("ConstantOfShape", [range_len], attr={"value": one_tensor}).output[0]
2084+
range_node = ctx.make_node("CumSum", [ones_of_shape, const_zero_int64], attr={'exclusive': True}).output[0]
2085+
#const_one_int64 = ctx.make_const(utils.make_name("const_one"), np.array(1, dtype=np.int64)).output[0]
2086+
#range_node = ctx.make_node("Range", [const_zero_int64, range_len, const_one_int64]).output[0]
2087+
2088+
col_indices_dense = ctx.make_node("Mod", [range_node, max_row_len]).output[0]
2089+
row_indices_dense = ctx.make_node("Div", [range_node, max_row_len]).output[0]
2090+
row_lens_dense = ctx.make_node("Gather", [row_lens, row_indices_dense]).output[0]
2091+
indices_to_keep = ctx.make_node("Less", [col_indices_dense, row_lens_dense]).output[0]
2092+
col_indices = ctx.make_node("Compress", [col_indices_dense, indices_to_keep]).output[0]
2093+
row_indices = ctx.make_node("Compress", [row_indices_dense, indices_to_keep]).output[0]
2094+
2095+
2096+
split_ends = ctx.make_node("CumSum", [row_lens, const_zero_int64]).output[0]
2097+
splits_out = ctx.make_node("Concat", [const_zero_list, split_ends], attr={'axis': 0}).output[0]
2098+
col_indices_cast = ctx.make_node("Cast", [col_indices], attr={'to': data_dtype}).output[0]
2099+
2100+
if ctx.get_rank(starts) != 1:
2101+
starts = ctx.make_node("Expand", [starts, inp_shape]).output[0]
2102+
2103+
if ctx.get_rank(deltas) != 1:
2104+
deltas = ctx.make_node("Expand", [deltas, inp_shape]).output[0]
2105+
2106+
gather_starts = ctx.make_node("Gather", [starts, row_indices]).output[0]
2107+
gather_deltas = ctx.make_node("Gather", [deltas, row_indices]).output[0]
2108+
2109+
mul_node = ctx.make_node("Mul", [col_indices_cast, gather_deltas], op_name_scope=node.name).output[0]
2110+
dense_vals_out = ctx.make_node("Add", [gather_starts, mul_node], op_name_scope=node.name).output[0]
2111+
2112+
ctx.replace_all_inputs(node.output[0], splits_out)
2113+
ctx.replace_all_inputs(node.output[1], dense_vals_out)
2114+
ctx.remove_node(node.name)
2115+
2116+
20372117
@tf_op("SparseReshape")
20382118
class SparseReshape:
20392119
@classmethod

0 commit comments

Comments
 (0)