Skip to content

Commit f21ca20

Browse files
Added support for DynamicPartition
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 669422c commit f21ca20

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

tests/test_backend.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3258,6 +3258,30 @@ def func(x):
32583258
#self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: x_val})
32593259
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
32603260

3261+
@check_opset_min_version(9, "Compress")
3262+
def test_dynamic_partition_both_vector(self):
3263+
data_val = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.float32)
3264+
part_val = np.array([0, 0, 1, 1, 0, 2, 1, 0], dtype=np.int32)
3265+
def func(data, partitions):
3266+
p1, p2, p3 = tf.dynamic_partition(data, partitions, num_partitions=3)
3267+
p1_ = tf.identity(p1, name=_TFOUTPUT)
3268+
p2_ = tf.identity(p2, name=_TFOUTPUT1)
3269+
p3_ = tf.identity(p3, name=_TFOUTPUT2)
3270+
return p1_, p2_, p3_
3271+
self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: data_val, _INPUT1: part_val})
3272+
3273+
@check_opset_min_version(9, "Compress")
3274+
def test_dynamic_partition_data_tensor(self):
3275+
data_val = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]], dtype=np.float32)
3276+
part_val = np.array([0, 2, 1, 0, 1], dtype=np.int32)
3277+
def func(data, partitions):
3278+
p1, p2, p3 = tf.dynamic_partition(data, partitions, num_partitions=3)
3279+
p1_ = tf.identity(p1, name=_TFOUTPUT)
3280+
p2_ = tf.identity(p2, name=_TFOUTPUT1)
3281+
p3_ = tf.identity(p3, name=_TFOUTPUT2)
3282+
return p1_, p2_, p3_
3283+
self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: data_val, _INPUT1: part_val})
3284+
32613285
@check_opset_min_version(10, "Conv2DBackpropInput")
32623286
def test_Conv2DBackpropInput_const(self):
32633287
input_sizes_val_ = np.array([1, 10, 10, 3], dtype=np.int32)

tf2onnx/onnx_opset/tensor.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1790,6 +1790,35 @@ def version_11(cls, ctx, node, **kwargs):
17901790
# FIXME: the indices in onnx are not the same as in tensorflow.
17911791

17921792

1793+
@tf_op("DynamicPartition")
1794+
class DynamicPartition:
1795+
@classmethod
1796+
def version_9(cls, ctx, node, **kwargs):
1797+
# For desired behavior, see diagram: https://www.tensorflow.org/api_docs/python/tf/raw_ops/DynamicPartition
1798+
data_inp = node.input[0]
1799+
partition_inp = node.input[1]
1800+
data_shape = ctx.get_shape(data_inp)
1801+
partition_shape = ctx.get_shape(partition_inp)
1802+
num_partitions = node.get_attr_value('num_partitions')
1803+
utils.make_sure(partition_shape is not None, "DynamicPartition requires known rank")
1804+
utils.make_sure(len(partition_shape) == 1, "DynamicPartition only implemented for partitions of rank 1")
1805+
# Put partitions into OneHot format
1806+
range_val = np.arange(num_partitions, dtype=np.int32).reshape([num_partitions, 1])
1807+
range_const = ctx.make_const(utils.make_name('range_const'), range_val)
1808+
equal_node = ctx.make_node("Equal", [partition_inp, range_const.output[0]])
1809+
# Cast bool to int since ORT doesn't implement Split on bool.
1810+
equal_int32 = ctx.make_node("Cast", [equal_node.output[0]], attr={"to": TensorProto.INT32})
1811+
split_node = ctx.make_node("Split", [equal_int32.output[0]], output_count=num_partitions, attr={'axis': 0})
1812+
for i in range(num_partitions):
1813+
cond_bools = ctx.make_node("Cast", [split_node.output[i]], attr={"to": TensorProto.BOOL})
1814+
squeeze_node = ctx.make_node("Squeeze", [cond_bools.output[0]], attr={'axes': [0]})
1815+
compress_node = ctx.make_node("Compress", [data_inp, squeeze_node.output[0]], attr={'axis': 0})
1816+
ctx.replace_all_inputs(node.output[i], compress_node.output[0])
1817+
ctx.copy_dtype(node.output[i], compress_node.output[0])
1818+
ctx.copy_shape(node.output[i], compress_node.output[0])
1819+
ctx.remove_node(node.name)
1820+
1821+
17931822
@tf_op("MatrixDiagPart")
17941823
class MatrixDiagPart:
17951824
@classmethod

0 commit comments

Comments
 (0)