@check_opset_min_version(9, "Compress")
def test_dynamic_partition_both_vector(self):
    """DynamicPartition with rank-1 data and rank-1 partition indices."""
    data_val = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.float32)
    part_val = np.array([0, 0, 1, 1, 0, 2, 1, 0], dtype=np.int32)
    def func(data, partitions):
        # Split `data` into 3 partitions selected by `partitions`, then name
        # each output so the test harness can fetch it.
        parts = tf.dynamic_partition(data, partitions, num_partitions=3)
        names = [_TFOUTPUT, _TFOUTPUT1, _TFOUTPUT2]
        return tuple(tf.identity(p, name=n) for p, n in zip(parts, names))
    self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: data_val, _INPUT1: part_val})

@check_opset_min_version(9, "Compress")
def test_dynamic_partition_data_tensor(self):
    """DynamicPartition with rank-2 data (whole rows routed to partitions)."""
    data_val = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]], dtype=np.float32)
    part_val = np.array([0, 2, 1, 0, 1], dtype=np.int32)
    def func(data, partitions):
        parts = tf.dynamic_partition(data, partitions, num_partitions=3)
        names = [_TFOUTPUT, _TFOUTPUT1, _TFOUTPUT2]
        return tuple(tf.identity(p, name=n) for p, n in zip(parts, names))
    self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: data_val, _INPUT1: part_val})
@tf_op("DynamicPartition")
class DynamicPartition:
    @classmethod
    def version_9(cls, ctx, node, **kwargs):
        """Convert tf.raw_ops.DynamicPartition to ONNX (requires Compress, opset 9+).

        Desired behavior diagram:
        https://www.tensorflow.org/api_docs/python/tf/raw_ops/DynamicPartition
        """
        data = node.input[0]
        partitions = node.input[1]
        num_partitions = node.get_attr_value('num_partitions')
        part_shape = ctx.get_shape(partitions)
        utils.make_sure(part_shape is not None, "DynamicPartition requires known rank")
        utils.make_sure(len(part_shape) == 1, "DynamicPartition only implemented for partitions of rank 1")
        # One-hot encode the partition ids: comparing the rank-1 `partitions`
        # tensor against a [num_partitions, 1] arange column broadcasts to a
        # [num_partitions, len] bool matrix whose row i flags partition i.
        range_val = np.arange(num_partitions, dtype=np.int32).reshape([num_partitions, 1])
        range_const = ctx.make_const(utils.make_name('range_const'), range_val)
        one_hot = ctx.make_node("Equal", [partitions, range_const.output[0]])
        # ORT doesn't implement Split on bool, so carry the mask as int32.
        one_hot_int = ctx.make_node("Cast", [one_hot.output[0]], attr={"to": TensorProto.INT32})
        rows = ctx.make_node("Split", [one_hot_int.output[0]],
                             output_count=num_partitions, attr={'axis': 0})
        for idx, row in enumerate(rows.output):
            # Recover a rank-1 bool mask for partition `idx`, then gather the
            # matching slices of `data` along axis 0 with Compress.
            mask = ctx.make_node("Cast", [row], attr={"to": TensorProto.BOOL})
            mask_1d = ctx.make_node("Squeeze", [mask.output[0]], attr={'axes': [0]})
            gathered = ctx.make_node("Compress", [data, mask_1d.output[0]], attr={'axis': 0})
            ctx.replace_all_inputs(node.output[idx], gathered.output[0])
            ctx.copy_dtype(node.output[idx], gathered.output[0])
            ctx.copy_shape(node.output[idx], gathered.output[0])
        ctx.remove_node(node.name)