diff --git a/tests/common.py b/tests/common.py index 24cab7052..dec0b6e06 100644 --- a/tests/common.py +++ b/tests/common.py @@ -207,7 +207,7 @@ def requires_custom_ops(message=""): """ Skip until custom ops framework is on PyPI. """ reason = _append_message("test needs custom ops framework", message) try: - import ortcustomops #pylint: disable=import-outside-toplevel,unused-import + import onnxruntime_customops #pylint: disable=import-outside-toplevel,unused-import can_import = True except ModuleNotFoundError: can_import = False diff --git a/tests/test_backend.py b/tests/test_backend.py index 80f950288..f7127c6b4 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -2834,6 +2834,36 @@ def func(x): return tf.identity(res, name=_TFOUTPUT) self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) + @check_tf_min_version("1.14", "tf.strings.lower") + @check_opset_min_version(10, "StringNormalizer") + def test_string_lower(self): + text_val1 = np.array([["a", "Test 1 2 3", "♠♣"], ["Hi there", "test test", "♥♦"]], dtype=np.str) + def func(text1): + x = tf.strings.lower(text1) + x_ = tf.identity(x, name=_TFOUTPUT) + return x_ + self._run_test_case(func, [_OUTPUT], {_INPUT: text_val1}) + + @check_tf_min_version("1.14", "tf.strings.lower") + @check_opset_min_version(10, "StringNormalizer") + def test_string_lower_flat(self): + text_val1 = np.array(["a", "Test 1 2 3", "♠♣", "Hi there", "test test", "♥♦"], dtype=np.str) + def func(text1): + x = tf.strings.lower(text1) + x_ = tf.identity(x, name=_TFOUTPUT) + return x_ + self._run_test_case(func, [_OUTPUT], {_INPUT: text_val1}) + + @check_tf_min_version("1.14", "tf.strings.lower") + @check_opset_min_version(10, "StringNormalizer") + def test_string_upper(self): + text_val1 = np.array([["a", "Test 1 2 3", "♠♣"], ["Hi there", "test test", "♥♦"]], dtype=np.str) + def func(text1): + x = tf.strings.upper(text1) + x_ = tf.identity(x, name=_TFOUTPUT) + return x_ + self._run_test_case(func, [_OUTPUT], {_INPUT: text_val1}) + @check_opset_min_version(6, "cast") def test_shape_int32(self): x_val = np.array([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]], dtype=np.float32) diff --git a/tests/test_string_ops.py b/tests/test_string_ops.py index d4f83e714..096a5ea6d 100644 --- a/tests/test_string_ops.py +++ b/tests/test_string_ops.py @@ -117,7 +117,7 @@ def _run_test_case(self, func, output_names_with_port, feed_dict, **kwargs): def run_onnxruntime(self, model_path, inputs, output_names): """Run test against onnxruntime backend.""" - from ortcustomops import get_library_path + from onnxruntime_customops import get_library_path import onnxruntime as rt opt = rt.SessionOptions() opt.register_custom_ops_library(get_library_path()) diff --git a/tf2onnx/custom_opsets/string_ops.py b/tf2onnx/custom_opsets/string_ops.py index 34d7cc661..341920383 100644 --- a/tf2onnx/custom_opsets/string_ops.py +++ b/tf2onnx/custom_opsets/string_ops.py @@ -121,6 +121,28 @@ def version_1(cls, ctx, node, **kwargs): ctx.copy_shape(output_name, not_node.output[0]) ctx.copy_dtype(output_name, not_node.output[0]) +@tf_op(["StringLower", "StringUpper"]) +class StringLower: + @classmethod + def version_10(cls, ctx, node, **kwargs): + if node.type == "StringLower": + case_action = "LOWER" + else: + case_action = "UPPER" + node.type = "StringNormalizer" + str_input = node.input[0] + rank = ctx.get_rank(node.input[0]) + shape = ctx.get_shape(node.input[0]) + if rank != 1: + ctx.insert_new_node_on_input(node, "Flatten", node.input[0], axis=0) + node.set_attr("case_change_action", case_action) + if rank != 1: + if shape is None or -1 in shape: + new_shape = ctx.make_node("Shape", [str_input]).output[0] + else: + new_shape = ctx.make_const(utils.make_name("shape"), np.array(shape, np.int64)).output[0] + ctx.insert_new_node_on_output("Reshape", node.output[0], inputs=[node.output[0], new_shape]) + @tf_op("SentencepieceOp", domain=constants.CONTRIB_OPS_DOMAIN) class SentencepieceOp: @classmethod