Skip to content
8 changes: 8 additions & 0 deletions tests/test_string_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ def func(text1, text2, text3):
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: text_val1, _INPUT1: text_val2, _INPUT2: text_val3})

@requires_custom_ops("ReduceJoin")
def test_reduce_join(self):
text_val = np.array([["a", "Test 1 2 3"], ["b", "test test"], ["c", "Hi there Test"]], dtype=np.str)
def func(text):
x_ = tf.strings.reduce_join(text, axis=1, separator="±")
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: text_val})

@requires_custom_ops("StringSplit")
@check_tf_min_version("2.0", "result is sparse not ragged in tf1")
def test_string_split(self):
Expand Down
25 changes: 25 additions & 0 deletions tf2onnx/custom_opsets/string_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import logging
import numpy as np
from onnx.numpy_helper import to_array
from onnx.onnx_pb import TensorProto
from onnx.helper import make_attribute
from tf2onnx import constants, handler
Expand Down Expand Up @@ -86,6 +87,30 @@ def version_1(cls, ctx, node, **kwargs):
stack_node = ctx.make_node("Concat", unsqueezes, attr={'axis': 0})
ctx.replace_inputs(node, [stack_node.output[0], separator_node.output[0], axis_node.output[0]])

@tf_op("ReduceJoin", domain=constants.CONTRIB_OPS_DOMAIN)
class ReduceJoin:
@classmethod
def version_1(cls, ctx, node, **kwargs):
node.domain = constants.CONTRIB_OPS_DOMAIN
node.type = "StringJoin"
axis_node = ctx.get_node_by_output(node.input[1])
axis = axis_node.get_attr_value('value')
utils.make_sure(axis.dims in [[], [1]], "Only a single axis is supported for ReduceJoin node")
axis = to_array(axis)
new_axis_node = ctx.make_const(utils.make_name("axis"), np.array(axis, np.int64).reshape((1)))
separator = node.get_attr_value("separator")
if isinstance(separator, bytes):
separator = separator.decode()
separator_node = ctx.make_const(utils.make_name("separator"), np.array([separator], object))
ctx.replace_inputs(node, [node.input[0], separator_node.output[0], new_axis_node.output[0]])
keep_dims = node.get_attr_value("keep_dims")
if keep_dims:
unsqueeze_node = GraphBuilder(ctx).make_unsqueeze(
{'data': node.output[0], 'axes': [-1]},
name=node.name + '/Unsqueeze'
)
ctx.insert_node_on_output(ctx.get_node_by_output(unsqueeze_node))

@tf_op(["Equal", "NotEqual"], domain=constants.CONTRIB_OPS_DOMAIN)
class StringEqual:
@classmethod
Expand Down