Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def requires_custom_ops(message=""):
""" Skip until custom ops framework is on PyPI. """
reason = _append_message("test needs custom ops framework", message)
try:
import ortcustomops #pylint: disable=import-outside-toplevel,unused-import
import onnxruntime_customops #pylint: disable=import-outside-toplevel,unused-import
can_import = True
except ModuleNotFoundError:
can_import = False
Expand Down
30 changes: 30 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2834,6 +2834,36 @@ def func(x):
return tf.identity(res, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

@check_tf_min_version("1.14", "tf.strings.lower")
@check_opset_min_version(10, "StringNormalizer")
def test_string_lower(self):
text_val1 = np.array([["a", "Test 1 2 3", "♠♣"], ["Hi there", "test test", "♥♦"]], dtype=np.str)
def func(text1):
x = tf.strings.lower(text1)
x_ = tf.identity(x, name=_TFOUTPUT)
return x_
self._run_test_case(func, [_OUTPUT], {_INPUT: text_val1})

@check_tf_min_version("1.14", "tf.strings.lower")
@check_opset_min_version(10, "StringNormalizer")
def test_string_lower_flat(self):
text_val1 = np.array(["a", "Test 1 2 3", "♠♣", "Hi there", "test test", "♥♦"], dtype=np.str)
def func(text1):
x = tf.strings.lower(text1)
x_ = tf.identity(x, name=_TFOUTPUT)
return x_
self._run_test_case(func, [_OUTPUT], {_INPUT: text_val1})

@check_tf_min_version("1.14", "tf.strings.lower")
@check_opset_min_version(10, "StringNormalizer")
def test_string_upper(self):
text_val1 = np.array([["a", "Test 1 2 3", "♠♣"], ["Hi there", "test test", "♥♦"]], dtype=np.str)
def func(text1):
x = tf.strings.upper(text1)
x_ = tf.identity(x, name=_TFOUTPUT)
return x_
self._run_test_case(func, [_OUTPUT], {_INPUT: text_val1})

@check_opset_min_version(6, "cast")
def test_shape_int32(self):
x_val = np.array([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]], dtype=np.float32)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_string_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _run_test_case(self, func, output_names_with_port, feed_dict, **kwargs):

def run_onnxruntime(self, model_path, inputs, output_names):
"""Run test against onnxruntime backend."""
from ortcustomops import get_library_path
from onnxruntime_customops import get_library_path
import onnxruntime as rt
opt = rt.SessionOptions()
opt.register_custom_ops_library(get_library_path())
Expand Down
22 changes: 22 additions & 0 deletions tf2onnx/custom_opsets/string_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,28 @@ def version_1(cls, ctx, node, **kwargs):
ctx.copy_shape(output_name, not_node.output[0])
ctx.copy_dtype(output_name, not_node.output[0])

@tf_op(["StringLower", "StringUpper"])
class StringLower:
@classmethod
def version_10(cls, ctx, node, **kwargs):
if node.type == "StringLower":
case_action = "LOWER"
else:
case_action = "UPPER"
node.type = "StringNormalizer"
str_input = node.input[0]
rank = ctx.get_rank(node.input[0])
shape = ctx.get_shape(node.input[0])
if rank != 1:
ctx.insert_new_node_on_input(node, "Flatten", node.input[0], axis=0)
node.set_attr("case_change_action", case_action)
if rank != 1:
if shape is None or -1 in shape:
new_shape = ctx.make_node("Shape", [str_input]).output[0]
else:
new_shape = ctx.make_const(utils.make_name("shape"), np.array(shape, np.int64)).output[0]
ctx.insert_new_node_on_output("Reshape", node.output[0], inputs=[node.output[0], new_shape])

@tf_op("SentencepieceOp", domain=constants.CONTRIB_OPS_DOMAIN)
class SentencepieceOp:
@classmethod
Expand Down