18 changes: 12 additions & 6 deletions tests/backend_test_base.py
@@ -70,7 +70,7 @@ def run_onnxcaffe2(self, onnx_graph, inputs):
        results = prepared_backend.run(inputs)
        return results

-    def run_onnxruntime(self, model_path, inputs, output_names):
+    def run_onnxruntime(self, model_path, inputs, output_names, use_custom_ops=False):
        """Run test against onnxruntime backend."""
        import onnxruntime as rt
        providers = ['CPUExecutionProvider']
@@ -79,6 +79,9 @@ def run_onnxruntime(self, model_path, inputs, output_names):
            if gpus is None or len(gpus) > 1:
                providers = ['CUDAExecutionProvider']
        opt = rt.SessionOptions()
+        if use_custom_ops:
+            from onnxruntime_customops import get_library_path
+            opt.register_custom_ops_library(get_library_path())
        # in case of issues with the runtime, one can enable more logging
        # opt.log_severity_level = 0
        # opt.log_verbosity_level = 255
@@ -87,14 +90,14 @@ def run_onnxruntime(self, model_path, inputs, output_names):
        results = m.run(output_names, inputs)
        return results

-    def run_backend(self, g, outputs, input_dict, large_model=False, postfix=""):
+    def run_backend(self, g, outputs, input_dict, large_model=False, postfix="", use_custom_ops=False):
        tensor_storage = ExternalTensorStorage() if large_model else None
        model_proto = g.make_model("test", external_tensor_storage=tensor_storage)
        model_path = self.save_onnx_model(model_proto, input_dict, external_tensor_storage=tensor_storage,
                                          postfix=postfix)

        if self.config.backend == "onnxruntime":
-            y = self.run_onnxruntime(model_path, input_dict, outputs)
+            y = self.run_onnxruntime(model_path, input_dict, outputs, use_custom_ops)
        elif self.config.backend == "caffe2":
            y = self.run_onnxcaffe2(model_proto, input_dict)
        else:
@@ -299,7 +302,8 @@ def get_dtype(info):
    def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port,
                      rtol=1e-07, atol=1e-5, mtol=None, convert_var_to_const=True, constant_fold=True,
                      check_value=True, check_shape=True, check_dtype=True, process_args=None, onnx_feed_dict=None,
-                      graph_validator=None, as_session=False, large_model=False, premade_placeholders=False):
+                      graph_validator=None, as_session=False, large_model=False, premade_placeholders=False,
+                      use_custom_ops=False):
        test_tf = not self.config.skip_tf_tests
        test_tflite = not self.config.skip_tflite_tests
        run_tfl_consistency_test = test_tf and test_tflite and self.config.run_tfl_consistency_test
@@ -339,7 +343,8 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port,
                                 initialized_tables=initialized_tables,
                                 **process_args)
            g = optimizer.optimize_graph(g, catch_errors=False)
-            actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model)
+            actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model,
+                                      use_custom_ops=use_custom_ops)

            self.assert_results_equal(expected, actual, rtol, atol, mtol, check_value, check_shape, check_dtype)
            self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)
@@ -369,7 +374,8 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port,
                                 **tfl_process_args)
            g = optimizer.optimize_graph(g)
            onnx_feed_dict_without_port = {k.split(':')[0]: v for k, v in onnx_feed_dict.items()}
-            onnx_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port, postfix="_from_tflite")
+            onnx_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port,
+                                            postfix="_from_tflite", use_custom_ops=use_custom_ops)

            self.assert_results_equal(tfl_res, onnx_tfl_res, rtol, atol, mtol, check_value, check_shape, check_dtype)
            self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)
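Note: the new use_custom_ops flag consolidates what tests/test_string_ops.py previously did with a private run_onnxruntime override (removed below). A minimal sketch of what the flag does at session-creation time, assuming onnxruntime and the onnxruntime_customops package are installed and that model.onnx uses ops from the ai.onnx.contrib domain:

```python
import onnxruntime as rt
from onnxruntime_customops import get_library_path

opt = rt.SessionOptions()
# Load the custom-op kernels (ai.onnx.contrib domain) into this session.
opt.register_custom_ops_library(get_library_path())
sess = rt.InferenceSession("model.onnx", opt, providers=["CPUExecutionProvider"])
print([i.name for i in sess.get_inputs()])
```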
45 changes: 44 additions & 1 deletion tests/test_api.py
@@ -10,8 +10,9 @@

import numpy as np
import tensorflow as tf
+from onnx import helper

-from common import check_tf_min_version, unittest_main
+from common import check_tf_min_version, unittest_main, requires_custom_ops, check_opset_min_version
from backend_test_base import Tf2OnnxBackendTestBase
import tf2onnx

@@ -74,6 +75,48 @@ def test_keras_api(self):
    def test_keras_api_large(self):
        self._test_keras_api(large_model=True)

+    @requires_custom_ops()
+    @check_tf_min_version("2.0")
+    @check_opset_min_version(11, "SparseToDense")
+    def test_keras_hashtable(self):
+
+        feature_cols = [
+            tf.feature_column.numeric_column("f_inp", dtype=tf.float32),
+            tf.feature_column.indicator_column(
+                tf.feature_column.categorical_column_with_vocabulary_list("s_inp", ["a", "b", "z"], num_oov_buckets=1)
+            )
+        ]
+        feature_layer = tf.keras.layers.DenseFeatures(feature_cols)
+
+        input_dict = {}
+        input_dict["f_inp"] = tf.keras.Input(name="f_inp", shape=(1,), dtype=tf.float32)
+        input_dict["s_inp"] = tf.keras.Input(name="s_inp", shape=(1,), dtype=tf.string)
+
+        inputs = list(input_dict.values())
+        standard_features = feature_layer(input_dict)
+        hidden1 = tf.keras.layers.Dense(512, activation='relu')(standard_features)
+        output = tf.keras.layers.Dense(10, activation='softmax')(hidden1)
+        model = tf.keras.Model(inputs=inputs, outputs=output)
+        model.compile(optimizer='adam', loss=tf.keras.losses.mean_squared_error)
+
+        inp1 = np.array([[2.], [3.]], dtype=np.float32)
+        inp2 = np.array([["a"], ["b"]], dtype=str)
+        k_res = model.predict([inp1, inp2])
+        spec = (tf.TensorSpec((None, 1), dtype=tf.float32, name="f_inp"),
+                tf.TensorSpec((None, 1), tf.string, name="s_inp"))
+        output_path = os.path.join(self.test_data_directory, "model.onnx")
+
+        model_proto, _ = tf2onnx.convert.from_keras(
+            model, input_signature=spec, opset=self.config.opset, output_path=output_path,
+            extra_opset=[helper.make_opsetid("ai.onnx.contrib", 1)])
+        output_names = [n.name for n in model_proto.graph.output]
+
+        o_res = self.run_onnxruntime(output_path, {"f_inp": inp1, "s_inp": inp2}, output_names, use_custom_ops=True)
+        self.assertAllClose(k_res, o_res[0], rtol=0.3, atol=0.1)
+        # make sure the original keras model wasn't trashed
+        k_res2 = model.predict([inp1, inp2])
+        self.assertAllClose(k_res2, o_res[0], rtol=0.3, atol=0.1)
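The vocabulary lookup behind the indicator column above is implemented as a TensorFlow hash table, which is exactly what this PR teaches the converter to capture. A hypothetical illustration (not part of the test) of what the column produces, including the out-of-vocabulary bucket:

```python
import numpy as np
import tensorflow as tf

# The string -> index lookup is backed by a TF StaticHashTable; with the
# vocabulary ["a", "b", "z"] plus one OOV bucket the one-hot width is 4.
col = tf.feature_column.categorical_column_with_vocabulary_list(
    "s_inp", ["a", "b", "z"], num_oov_buckets=1)
layer = tf.keras.layers.DenseFeatures([tf.feature_column.indicator_column(col)])
print(layer({"s_inp": np.array([["a"], ["b"], ["q"]])}))
# rows: [1,0,0,0], [0,1,0,0], [0,0,0,1]  ("q" falls into the OOV bucket)
```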

@check_tf_min_version("2.0")
def test_function(self):
def func(x, y):
13 changes: 2 additions & 11 deletions tests/test_string_ops.py
@@ -130,17 +130,8 @@ def func(text):
    def _run_test_case(self, func, output_names_with_port, feed_dict, **kwargs):
        extra_opset = [utils.make_opsetid(constants.CONTRIB_OPS_DOMAIN, 1)]
        process_args = {"extra_opset": extra_opset}
-        return self.run_test_case(func, feed_dict, [], output_names_with_port, process_args=process_args, **kwargs)
-
-    def run_onnxruntime(self, model_path, inputs, output_names):
-        """Run test against onnxruntime backend."""
-        from onnxruntime_customops import get_library_path
-        import onnxruntime as rt
-        opt = rt.SessionOptions()
-        opt.register_custom_ops_library(get_library_path())
-        m = rt.InferenceSession(model_path, opt)
-        results = m.run(output_names, inputs)
-        return results
+        return self.run_test_case(func, feed_dict, [], output_names_with_port,
+                                  use_custom_ops=True, process_args=process_args, **kwargs)

    @requires_custom_ops("WordpieceTokenizer")
    @check_tf_min_version("2.0", "tensorflow_text")
4 changes: 2 additions & 2 deletions tf2onnx/convert.py
@@ -329,7 +329,6 @@ def wrap_call(*args, training=False, **kwargs):
    output_names = [output_tensor.name for output_tensor in concrete_func.outputs
                    if output_tensor.dtype != tf.dtypes.resource]

-    initialized_tables = None
    tensors_to_rename = tensor_names_from_structed(concrete_func, input_names, output_names)
    reverse_lookup = {v: k for k, v in tensors_to_rename.items()}

@@ -341,7 +340,8 @@ def wrap_call(*args, training=False, **kwargs):
        output_names = [reverse_lookup[out] for out in concrete_func.structured_outputs.keys()]

    with tf.device("/cpu:0"):
-        frozen_graph = tf_loader.from_function(concrete_func, input_names, output_names, large_model=large_model)
+        frozen_graph, initialized_tables = \
+            tf_loader.from_trackable(model, concrete_func, input_names, output_names, large_model)
    model_proto, external_tensor_storage = _convert_common(
        frozen_graph,
        name=model.name,
109 changes: 66 additions & 43 deletions tf2onnx/tf_loader.py
@@ -47,13 +47,12 @@ def not_implemented_tf_placeholder(*args, **kwargs):
try:
    # pylint: disable=protected-access
    from tensorflow.python.saved_model.load import _RestoredResource as TfRestoredResourceType
+    from tensorflow.python.ops.lookup_ops import StaticHashTable as TfStaticHashTableType
+    from tensorflow.python.training.tracking.base import Trackable as TfTrackableType
except ImportError:
    TfRestoredResourceType = tuple()  # isinstance(x, tuple()) is always false
-
-try:
-    from tensorflow.python.training.tracking.tracking import AutoTrackable as TfAutoTrackableType
-except ImportError:
-    TfAutoTrackableType = tuple()
+    TfStaticHashTableType = tuple()
+    TfTrackableType = tuple()

if is_tf2():
    convert_variables_to_constants = tf.compat.v1.graph_util.convert_variables_to_constants
@@ -147,6 +146,46 @@ def fix_freezing_errors(graph_def):
    return graph_def


+def from_trackable(trackable, concrete_func, inputs, outputs, large_model):
Review comment (author): This is all from _from_saved_model_v2

+    err_large_model = "model exceeds maximum protobuf size of 2GB. Try setting large_model."
+
+    # Avoid errors due to bug in TF freezing
+    removed_resource_to_placeholder, graph_captures_copy, func_captures_copy = \
+        _remove_non_variable_resources_from_captures(concrete_func)
+
+    try:
+        frozen_graph = from_function(concrete_func, inputs, outputs, large_model)
+    except ValueError as e:
+        if any(msg in str(e) for msg in ["exceeds maximum protobuf size of 2GB", "string too long"]):
+            raise ValueError(err_large_model)
+        raise e
+
+    # We might be returning the concrete_func so let's put it back in working order
+    _restore_captured_resources(concrete_func, graph_captures_copy, func_captures_copy)
+
+    table_names, key_dtypes, value_dtypes = get_hash_table_info(frozen_graph)
+    placeholder_to_table_info = {}
+    _get_hash_table_info_from_trackable(trackable, table_names, key_dtypes, value_dtypes,
+                                        removed_resource_to_placeholder, placeholder_to_table_info)
+
+    initialized_tables = {}
+    for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
+        h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
+        try:
+            k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
+            initialized_tables[n] = (k.numpy(), v.numpy())
+        except Exception:  # pylint: disable=broad-except
+            logger.warning("Could not initialize table with shared_name = %r", n)
+
+    for placeholder in removed_resource_to_placeholder.values():
+        if placeholder not in placeholder_to_table_info:
+            logger.error("Could not find table resource to replace placeholder %s", placeholder)
+
+    replace_placeholders_with_tables(frozen_graph, placeholder_to_table_info)
+
+    return frozen_graph, initialized_tables
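A note on the export trick in from_trackable: creating a second handle via hash_table_v2 with an existing table's shared_name aliases the live table, so lookup_table_export_v2 can dump its contents eagerly. A minimal sketch, assuming TF2 eager execution (which assigns tables a generated shared name, as the loader code relies on):

```python
import tensorflow as tf
from tensorflow.python.ops import lookup_ops

# An eagerly-created table; eager mode gives it a shared_name.
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(["a", "b"], [0, 1], tf.string, tf.int64),
    default_value=-1)

# A second handle with the same shared_name points at the same resource,
# so the table contents can be exported as plain tensors.
name = table._shared_name  # pylint: disable=protected-access
handle = lookup_ops.hash_table_v2(tf.string, tf.int64, shared_name=name)
keys, values = lookup_ops.lookup_table_export_v2(handle, tf.string, tf.int64)
print(keys.numpy(), values.numpy())  # e.g. [b'a' b'b'] [0 1]; export order is not guaranteed
```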


def from_function(func, input_names, output_names, large_model=False):
    if large_model:
        return convert_variables_to_constants_large_model(func)
@@ -327,7 +366,27 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signa
def _get_hash_table_info_from_trackable(trackable, table_names, key_dtypes, value_dtypes,
                                        removed_resource_to_placeholder, placeholder_to_table_info):
    # pylint: disable=protected-access
-    for r in trackable.__dict__.values():
+    stack = [trackable]
+    visited = set()
+    while stack:
+        r = stack.pop()
+        visited.add(id(r))
+        try:
+            for trackable_ref in r._checkpoint_dependencies:
+                if id(trackable_ref.ref) not in visited:
+                    if isinstance(trackable_ref.ref, TfTrackableType):
+                        stack.append(trackable_ref.ref)
+        except Exception:  # pylint: disable=broad-except
+            continue
+        for t in r.__dict__.values() if hasattr(r, '__dict__') else []:
+            if isinstance(t, TfStaticHashTableType) and hasattr(t, '_shared_name'):
+                table_names.append(t._shared_name.encode())
+                key_dtypes.append(t.key_dtype.as_datatype_enum)
+                value_dtypes.append(t.value_dtype.as_datatype_enum)
+                table_handle = id(t.resource_handle)
+                if table_handle in removed_resource_to_placeholder:
+                    table_info = (table_names[-1], key_dtypes[-1], value_dtypes[-1])
+                    placeholder_to_table_info[removed_resource_to_placeholder[table_handle]] = table_info
        if isinstance(r, TfRestoredResourceType) and hasattr(r, '_create_resource'):
            try:
                table_handle = id(r.resource_handle)
@@ -341,9 +400,6 @@ def _get_hash_table_info_from_trackable(trackable, table_names, key_dtypes, value_dtypes,
            if table_handle in removed_resource_to_placeholder and len(new_names) == 1:
                table_info = (new_names[0], new_k_dtypes[0], new_v_dtypes[0])
                placeholder_to_table_info[removed_resource_to_placeholder[table_handle]] = table_info
-        if isinstance(r, TfAutoTrackableType):
-            _get_hash_table_info_from_trackable(r, table_names, key_dtypes, value_dtypes,
-                                                removed_resource_to_placeholder, placeholder_to_table_info)
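Context for the traversal change above: the old version recursed only through AutoTrackable attributes found in __dict__, which misses trackables held in tracked containers (Keras layers keep sublayers in ListWrapper objects, for example). Walking _checkpoint_dependencies covers the whole tracked object graph. A hypothetical illustration of the difference:

```python
import tensorflow as tf

# A trackable stored in a list attribute: the list is wrapped in a ListWrapper,
# which is Trackable but not AutoTrackable, so __dict__-based recursion used to
# stop there; its _checkpoint_dependencies still expose the nested module.
class Holder(tf.Module):
    def __init__(self):
        super().__init__()
        self.children = [tf.Module(name="child")]

holder = Holder()
wrapper = holder.__dict__["children"]
print(type(wrapper).__name__)  # ListWrapper
print([dep.name for dep in wrapper._checkpoint_dependencies])  # ['0'] -> the child module
```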


def _remove_non_variable_resources_from_captures(concrete_func):
@@ -399,7 +455,6 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
    err_index = "Invalid concrete_function value: %i. Valid values are [0 to %i]"
    err_no_sig = "No signatures found in model. Try --concrete_function instead."
    err_sig_nomatch = "Specified signature not in model %s"
-    err_large_model = "model exceeds maximum protobuf size of 2GB. Try running with --large_model flag."

    if tag is None:
        tag = ['serve']
@@ -456,39 +511,7 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
        outputs = output_names
        logger.info("Outputs not left as None; will use provided names not structured output names.")

-    # Avoid errors due to bug in TF freezing
-    removed_resource_to_placeholder, graph_captures_copy, func_captures_copy = \
-        _remove_non_variable_resources_from_captures(concrete_func)
-
-    try:
-        frozen_graph = from_function(concrete_func, inputs, outputs, large_model)
-    except ValueError as e:
-        if any(msg in str(e) for msg in ["exceeds maximum protobuf size of 2GB", "string too long"]):
-            raise ValueError(err_large_model)
-        raise e
-
-    # We might be returning the concrete_func so let's put it back in working order
-    _restore_captured_resources(concrete_func, graph_captures_copy, func_captures_copy)
-
-    table_names, key_dtypes, value_dtypes = get_hash_table_info(frozen_graph)
-    placeholder_to_table_info = {}
-    _get_hash_table_info_from_trackable(imported, table_names, key_dtypes, value_dtypes,
-                                        removed_resource_to_placeholder, placeholder_to_table_info)
-
-    initialized_tables = {}
-    for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
-        h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
-        try:
-            k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
-            initialized_tables[n] = (k.numpy(), v.numpy())
-        except Exception:  # pylint: disable=broad-except
-            logger.warning("Could not initialize table with shared_name = %r", n)
-
-    for placeholder in removed_resource_to_placeholder.values():
-        if placeholder not in placeholder_to_table_info:
-            logger.error("Could not find table resource to replace placeholder %s", placeholder)
-
-    replace_placeholders_with_tables(frozen_graph, placeholder_to_table_info)
+    frozen_graph, initialized_tables = from_trackable(imported, concrete_func, inputs, outputs, large_model)

    return frozen_graph, inputs, outputs, concrete_func, imported, initialized_tables, tensors_to_rename
