diff --git a/tests/test_backend.py b/tests/test_backend.py
index c51c4e982..c01429634 100644
--- a/tests/test_backend.py
+++ b/tests/test_backend.py
@@ -3765,6 +3765,23 @@ def func(indices, dense_shape, new_shape, shape_pad):
         self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: indices_val, _INPUT1: dense_shape_val,
                                                         _INPUT2: new_shape_val, _INPUT3: shape_pad_val})
 
+    @check_tf_min_version("1.14", "ragged needs tf 1.14")
+    @check_opset_min_version(11, "CumSum")
+    def test_ragged_tensor_to_sparse(self):
+        splits_val1 = np.array([0, 1, 1, 5], dtype=np.int32)
+        splits_val2 = np.array([0, 3, 3, 5, 9, 10], dtype=np.int32)
+        dense_vals_val = np.array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.float32)
+        def func(splits1, splits2, rt_dense_values):
+            x = tf.RaggedTensor.from_nested_row_splits(rt_dense_values, [splits1, splits2], validate=True)
+            s = x.to_sparse()
+            indices, values, shape = s.indices, s.values, s.dense_shape
+            indices = tf.identity(indices, name=_TFOUTPUT)
+            values = tf.identity(values, name=_TFOUTPUT1)
+            shape = tf.identity(shape, name=_TFOUTPUT2)
+            return indices, values, shape
+        self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2],
+                            {_INPUT: splits_val1, _INPUT1: splits_val2, _INPUT2: dense_vals_val})
+
     @check_tf_min_version("1.14", "ragged needs tf 1.14")
     @check_opset_min_version(11, "Range")
     def test_ragged_range_float(self):
diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py
index ea7294152..9182f8fb5 100644
--- a/tf2onnx/onnx_opset/tensor.py
+++ b/tf2onnx/onnx_opset/tensor.py
@@ -2036,6 +2036,65 @@ def version_11(cls, ctx, node, **kwargs):
         ctx.replace_inputs(node, [expand_node.output[0], sparse_indices, sparse_vals])
 
 
+def ragged_lengths_to_sparse_indices(ctx, ragged_lens):
+    const_zero_int64 = ctx.make_const(utils.make_name("const_zero"), np.array(0, dtype=np.int64)).output[0]
+    num_cols = ctx.make_node("ReduceMax", [ragged_lens], attr={'axes': [0], 'keepdims': True}).output[0]
+    num_rows = ctx.make_node("Shape", [ragged_lens]).output[0]
+    range_len = ctx.make_node("Mul", [num_cols, num_rows]).output[0]
+
+    # ORT seems to have a shape inference bug for the Range node. Use CumSum instead.
+    one_tensor = helper.make_tensor("value", TensorProto.INT64, dims=[1], vals=[1])
+    ones_of_shape = ctx.make_node("ConstantOfShape", [range_len], attr={"value": one_tensor}).output[0]
+    range_node = ctx.make_node("CumSum", [ones_of_shape, const_zero_int64], attr={'exclusive': True}).output[0]
+    #const_one_int64 = ctx.make_const(utils.make_name("const_one"), np.array(1, dtype=np.int64)).output[0]
+    #range_node = ctx.make_node("Range", [const_zero_int64, range_len, const_one_int64]).output[0]
+
+    col_indices_dense = ctx.make_node("Mod", [range_node, num_cols]).output[0]
+    row_indices_dense = ctx.make_node("Div", [range_node, num_cols]).output[0]
+    row_lens_dense = ctx.make_node("Gather", [ragged_lens, row_indices_dense]).output[0]
+    indices_to_keep = ctx.make_node("Less", [col_indices_dense, row_lens_dense]).output[0]
+    col_indices = ctx.make_node("Compress", [col_indices_dense, indices_to_keep]).output[0]
+    row_indices = ctx.make_node("Compress", [row_indices_dense, indices_to_keep]).output[0]
+    return num_rows, num_cols, row_indices, col_indices
+
+
+@tf_op("RaggedTensorToSparse")
+class RaggedTensorToSparse:
+    @classmethod
+    def version_11(cls, ctx, node, **kwargs):
+        # https://www.tensorflow.org/guide/ragged_tensor#multiple_ragged_dimensions
+        dense_values = node.inputs[-1]
+        nested_splits = node.inputs[:-1]
+        sparse_indices = None
+        dense_shape_dims = []
+        for split in nested_splits:
+            if ctx.get_dtype(split.output[0]) != TensorProto.INT64:
+                split = ctx.make_node("Cast", [split.output[0]], attr={'to': TensorProto.INT64})
+            max_int64 = int(utils.get_max_value(np.int64))
+            slice1 = GraphBuilder(ctx).make_slice(
+                {"data": split.output[0], "ends": [max_int64], "starts": [1], "axes": [0]})
+            slice2 = GraphBuilder(ctx).make_slice(
+                {"data": split.output[0], "ends": [-1], "starts": [0], "axes": [0]})
+            ragged_lens = ctx.make_node("Sub", [slice1, slice2]).output[0]
+            num_rows, num_cols, row_indices, col_indices = ragged_lengths_to_sparse_indices(ctx, ragged_lens)
+            if not dense_shape_dims:
+                dense_shape_dims.append(num_rows)
+            dense_shape_dims.append(num_cols)
+            if sparse_indices is None:
+                row_indices = GraphBuilder(ctx).make_unsqueeze({"data": row_indices, "axes": [1]})
+            else:
+                row_indices = ctx.make_node("Gather", [sparse_indices, row_indices]).output[0]
+            col_indices = GraphBuilder(ctx).make_unsqueeze({"data": col_indices, "axes": [1]})
+            sparse_indices = ctx.make_node("Concat", [row_indices, col_indices], attr={'axis': 1},
+                                           op_name_scope=node.name).output[0]
+        dense_shape = ctx.make_node("Concat", dense_shape_dims, attr={'axis': 0}, op_name_scope=node.name).output[0]
+
+        ctx.replace_all_inputs(node.output[0], sparse_indices)
+        ctx.replace_all_inputs(node.output[1], dense_values.output[0])
+        ctx.replace_all_inputs(node.output[2], dense_shape)
+        ctx.remove_node(node.name)
+
+
 @tf_op("RaggedRange")
 class RaggedRange:
     @classmethod
@@ -2076,34 +2135,17 @@ def version_11(cls, ctx, node, **kwargs):
 
         const_zero_list = ctx.make_const(utils.make_name("const_zero_list"),
                                          np.array([0], dtype=np.int64)).output[0]
-        max_row_len = ctx.make_node("ReduceMax", [row_lens], attr={'axes': [0], 'keeepdims': False}).output[0]
-        inp_shape = ctx.make_node("Shape", [row_lens]).output[0]
-        range_len = ctx.make_node("Mul", [max_row_len, inp_shape]).output[0]
-
-        # ORT seems to have a shape inference bug for the Range node. Use CumSum instead.
-        one_tensor = helper.make_tensor("value", TensorProto.INT64, dims=[1], vals=[1])
-        ones_of_shape = ctx.make_node("ConstantOfShape", [range_len], attr={"value": one_tensor}).output[0]
-        range_node = ctx.make_node("CumSum", [ones_of_shape, const_zero_int64], attr={'exclusive': True}).output[0]
-        #const_one_int64 = ctx.make_const(utils.make_name("const_one"), np.array(1, dtype=np.int64)).output[0]
-        #range_node = ctx.make_node("Range", [const_zero_int64, range_len, const_one_int64]).output[0]
-
-        col_indices_dense = ctx.make_node("Mod", [range_node, max_row_len]).output[0]
-        row_indices_dense = ctx.make_node("Div", [range_node, max_row_len]).output[0]
-        row_lens_dense = ctx.make_node("Gather", [row_lens, row_indices_dense]).output[0]
-        indices_to_keep = ctx.make_node("Less", [col_indices_dense, row_lens_dense]).output[0]
-        col_indices = ctx.make_node("Compress", [col_indices_dense, indices_to_keep]).output[0]
-        row_indices = ctx.make_node("Compress", [row_indices_dense, indices_to_keep]).output[0]
-
+        num_rows, _, row_indices, col_indices = ragged_lengths_to_sparse_indices(ctx, row_lens)
         split_ends = ctx.make_node("CumSum", [row_lens, const_zero_int64]).output[0]
         splits_out = ctx.make_node("Concat", [const_zero_list, split_ends], attr={'axis': 0}).output[0]
         col_indices_cast = ctx.make_node("Cast", [col_indices], attr={'to': data_dtype}).output[0]
 
         if ctx.get_rank(starts) != 1:
-            starts = ctx.make_node("Expand", [starts, inp_shape]).output[0]
+            starts = ctx.make_node("Expand", [starts, num_rows]).output[0]
         if ctx.get_rank(deltas) != 1:
-            deltas = ctx.make_node("Expand", [deltas, inp_shape]).output[0]
+            deltas = ctx.make_node("Expand", [deltas, num_rows]).output[0]
 
         gather_starts = ctx.make_node("Gather", [starts, row_indices]).output[0]
         gather_deltas = ctx.make_node("Gather", [deltas, row_indices]).output[0]
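
Note: the conversion above builds the SparseTensor indices level by level from the nested row splits. The sketch below is a rough NumPy model of what the emitted ONNX graph computes, not part of the diff; the function name ragged_splits_to_sparse and all variable names are illustrative, and the op mapping (Sub of two Slices for row lengths, ConstantOfShape plus exclusive CumSum for the arange, Mod/Div/Gather/Less/Compress for index filtering, Unsqueeze/Concat for column stacking) is an interpretation of ragged_lengths_to_sparse_indices and RaggedTensorToSparse.version_11.

import numpy as np

def ragged_splits_to_sparse(nested_splits, dense_values):
    # NumPy model of the converter: one pass per ragged dimension.
    sparse_indices = None
    dense_shape = []
    for splits in nested_splits:
        splits = np.asarray(splits, dtype=np.int64)
        row_lens = splits[1:] - splits[:-1]              # Sub of two Slices
        num_rows, num_cols = len(row_lens), int(row_lens.max())
        flat = np.arange(num_rows * num_cols)            # ConstantOfShape + exclusive CumSum
        row_dense, col_dense = flat // num_cols, flat % num_cols   # Div / Mod
        keep = col_dense < row_lens[row_dense]           # Gather + Less
        row_idx, col_idx = row_dense[keep], col_dense[keep]        # Compress
        if sparse_indices is None:
            dense_shape.append(num_rows)
            sparse_indices = row_idx[:, None]            # first level: new row column
        else:
            sparse_indices = sparse_indices[row_idx]     # deeper level: reuse parent index rows
        dense_shape.append(num_cols)
        sparse_indices = np.concatenate([sparse_indices, col_idx[:, None]], axis=1)
    return sparse_indices, np.asarray(dense_values), np.array(dense_shape, dtype=np.int64)

With the values from test_ragged_tensor_to_sparse, ragged_splits_to_sparse([[0, 1, 1, 5], [0, 3, 3, 5, 9, 10]], np.arange(10, 20, dtype=np.float32)) yields dense_shape [3 4 4] and indices [[0 0 0] [0 0 1] [0 0 2] [2 1 0] [2 1 1] [2 2 0] [2 2 1] [2 2 2] [2 2 3] [2 3 0]], matching tf.RaggedTensor.to_sparse() for that input.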