diff --git a/tests/backend_test_base.py b/tests/backend_test_base.py
index 7ce7242f8..60c69352b 100644
--- a/tests/backend_test_base.py
+++ b/tests/backend_test_base.py
@@ -70,7 +70,7 @@ def run_onnxcaffe2(self, onnx_graph, inputs):
         results = prepared_backend.run(inputs)
         return results
 
-    def run_onnxruntime(self, model_path, inputs, output_names):
+    def run_onnxruntime(self, model_path, inputs, output_names, use_custom_ops=False):
         """Run test against onnxruntime backend."""
         import onnxruntime as rt
         providers = ['CPUExecutionProvider']
@@ -79,6 +79,9 @@ def run_onnxruntime(self, model_path, inputs, output_names):
         if gpus is None or len(gpus) > 1:
             providers = ['CUDAExecutionProvider']
         opt = rt.SessionOptions()
+        if use_custom_ops:
+            from onnxruntime_customops import get_library_path
+            opt.register_custom_ops_library(get_library_path())
         # in case of issues with the runtime, one can enable more logging
         # opt.log_severity_level = 0
         # opt.log_verbosity_level = 255
@@ -87,14 +90,14 @@ def run_onnxruntime(self, model_path, inputs, output_names):
         results = m.run(output_names, inputs)
         return results
 
-    def run_backend(self, g, outputs, input_dict, large_model=False, postfix=""):
+    def run_backend(self, g, outputs, input_dict, large_model=False, postfix="", use_custom_ops=False):
         tensor_storage = ExternalTensorStorage() if large_model else None
         model_proto = g.make_model("test", external_tensor_storage=tensor_storage)
         model_path = self.save_onnx_model(model_proto, input_dict,
                                           external_tensor_storage=tensor_storage, postfix=postfix)
 
         if self.config.backend == "onnxruntime":
-            y = self.run_onnxruntime(model_path, input_dict, outputs)
+            y = self.run_onnxruntime(model_path, input_dict, outputs, use_custom_ops)
         elif self.config.backend == "caffe2":
             y = self.run_onnxcaffe2(model_proto, input_dict)
         else:
@@ -299,7 +302,8 @@ def get_dtype(info):
     def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port,
                       rtol=1e-07, atol=1e-5, mtol=None, convert_var_to_const=True, constant_fold=True,
                       check_value=True, check_shape=True, check_dtype=True, process_args=None, onnx_feed_dict=None,
-                      graph_validator=None, as_session=False, large_model=False, premade_placeholders=False):
+                      graph_validator=None, as_session=False, large_model=False, premade_placeholders=False,
+                      use_custom_ops=False):
         test_tf = not self.config.skip_tf_tests
         test_tflite = not self.config.skip_tflite_tests
         run_tfl_consistency_test = test_tf and test_tflite and self.config.run_tfl_consistency_test
@@ -339,7 +343,8 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
                                      initialized_tables=initialized_tables,
                                      **process_args)
                 g = optimizer.optimize_graph(g, catch_errors=False)
-                actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model)
+                actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model,
+                                          use_custom_ops=use_custom_ops)
                 self.assert_results_equal(expected, actual, rtol, atol, mtol, check_value, check_shape, check_dtype)
 
                 self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)
@@ -369,7 +374,8 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
                                      **tfl_process_args)
                 g = optimizer.optimize_graph(g)
                 onnx_feed_dict_without_port = {k.split(':')[0]: v for k, v in onnx_feed_dict.items()}
-                onnx_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port, postfix="_from_tflite")
+                onnx_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port,
+                                                postfix="_from_tflite", use_custom_ops=use_custom_ops)
                 self.assert_results_equal(tfl_res, onnx_tfl_res, rtol, atol, mtol, check_value, check_shape, check_dtype)
 
                 self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)
diff --git a/tests/test_api.py b/tests/test_api.py
index 9694713cc..b83d4bb39 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -10,8 +10,9 @@
 
 import numpy as np
 import tensorflow as tf
+from onnx import helper
 
-from common import check_tf_min_version, unittest_main
+from common import check_tf_min_version, unittest_main, requires_custom_ops, check_opset_min_version
 from backend_test_base import Tf2OnnxBackendTestBase
 import tf2onnx
 
@@ -74,6 +75,48 @@ def test_keras_api(self):
     def test_keras_api_large(self):
         self._test_keras_api(large_model=True)
 
+    @requires_custom_ops()
+    @check_tf_min_version("2.0")
+    @check_opset_min_version(11, "SparseToDense")
+    def test_keras_hashtable(self):
+
+        feature_cols = [
+            tf.feature_column.numeric_column("f_inp", dtype=tf.float32),
+            tf.feature_column.indicator_column(
+                tf.feature_column.categorical_column_with_vocabulary_list("s_inp", ["a", "b", "z"], num_oov_buckets=1)
+            )
+        ]
+        feature_layer = tf.keras.layers.DenseFeatures(feature_cols)
+
+        input_dict = {}
+        input_dict["f_inp"] = tf.keras.Input(name="f_inp", shape=(1,), dtype=tf.float32)
+        input_dict["s_inp"] = tf.keras.Input(name="s_inp", shape=(1,), dtype=tf.string)
+
+        inputs = list(input_dict.values())
+        standard_features = feature_layer(input_dict)
+        hidden1 = tf.keras.layers.Dense(512, activation='relu')(standard_features)
+        output = tf.keras.layers.Dense(10, activation='softmax')(hidden1)
+        model = tf.keras.Model(inputs=inputs, outputs=output)
+        model.compile(optimizer='adam', loss=tf.keras.losses.mean_squared_error)
+
+        inp1 = np.array([[2.], [3.]], dtype=np.float32)
+        inp2 = np.array([["a"], ["b"]], dtype=np.str)
+        k_res = model.predict([inp1, inp2])
+        spec = (tf.TensorSpec((None, 1), dtype=tf.float32, name="f_inp"),
+                tf.TensorSpec((None, 1), tf.string, name="s_inp"))
+        output_path = os.path.join(self.test_data_directory, "model.onnx")
+
+        model_proto, _ = tf2onnx.convert.from_keras(
+            model, input_signature=spec, opset=self.config.opset, output_path=output_path,
+            extra_opset=[helper.make_opsetid("ai.onnx.contrib", 1)])
+        output_names = [n.name for n in model_proto.graph.output]
+
+        o_res = self.run_onnxruntime(output_path, {"f_inp": inp1, "s_inp": inp2}, output_names, use_custom_ops=True)
+        self.assertAllClose(k_res, o_res[0], rtol=0.3, atol=0.1)
+        # make sure the original keras model wasn't trashed
+        k_res2 = model.predict([inp1, inp2])
+        self.assertAllClose(k_res2, o_res[0], rtol=0.3, atol=0.1)
+
     @check_tf_min_version("2.0")
     def test_function(self):
         def func(x, y):
diff --git a/tests/test_string_ops.py b/tests/test_string_ops.py
index 7f5db80d4..f28ebf924 100644
--- a/tests/test_string_ops.py
+++ b/tests/test_string_ops.py
@@ -130,17 +130,8 @@ def func(text):
     def _run_test_case(self, func, output_names_with_port, feed_dict, **kwargs):
         extra_opset = [utils.make_opsetid(constants.CONTRIB_OPS_DOMAIN, 1)]
         process_args = {"extra_opset": extra_opset}
-        return self.run_test_case(func, feed_dict, [], output_names_with_port, process_args=process_args, **kwargs)
-
-    def run_onnxruntime(self, model_path, inputs, output_names):
-        """Run test against onnxruntime backend."""
-        from onnxruntime_customops import get_library_path
-        import onnxruntime as rt
-        opt = rt.SessionOptions()
-        opt.register_custom_ops_library(get_library_path())
-        m = rt.InferenceSession(model_path, opt)
-        results = m.run(output_names, inputs)
-        return results
+        return self.run_test_case(func, feed_dict, [], output_names_with_port,
+                                  use_custom_ops=True, process_args=process_args, **kwargs)
 
     @requires_custom_ops("WordpieceTokenizer")
     @check_tf_min_version("2.0", "tensorflow_text")
diff --git a/tf2onnx/convert.py b/tf2onnx/convert.py
index 3b778d08a..450275cf1 100644
--- a/tf2onnx/convert.py
+++ b/tf2onnx/convert.py
@@ -329,7 +329,6 @@ def wrap_call(*args, training=False, **kwargs):
 
     output_names = [output_tensor.name for output_tensor in concrete_func.outputs
                     if output_tensor.dtype != tf.dtypes.resource]
-    initialized_tables = None
 
     tensors_to_rename = tensor_names_from_structed(concrete_func, input_names, output_names)
     reverse_lookup = {v: k for k, v in tensors_to_rename.items()}
@@ -341,7 +340,8 @@ def wrap_call(*args, training=False, **kwargs):
         output_names = [reverse_lookup[out] for out in concrete_func.structured_outputs.keys()]
 
     with tf.device("/cpu:0"):
-        frozen_graph = tf_loader.from_function(concrete_func, input_names, output_names, large_model=large_model)
+        frozen_graph, initialized_tables = \
+            tf_loader.from_trackable(model, concrete_func, input_names, output_names, large_model)
 
     model_proto, external_tensor_storage = _convert_common(
         frozen_graph, name=model.name,
diff --git a/tf2onnx/tf_loader.py b/tf2onnx/tf_loader.py
index 2f58fe423..c813e507e 100644
--- a/tf2onnx/tf_loader.py
+++ b/tf2onnx/tf_loader.py
@@ -47,13 +47,12 @@ def not_implemented_tf_placeholder(*args, **kwargs):
 try:
     # pylint: disable=protected-access
     from tensorflow.python.saved_model.load import _RestoredResource as TfRestoredResourceType
+    from tensorflow.python.ops.lookup_ops import StaticHashTable as TfStaticHashTableType
+    from tensorflow.python.training.tracking.base import Trackable as TfTrackableType
 except ImportError:
     TfRestoredResourceType = tuple()  # isinstance(x, tuple()) is always false
-
-try:
-    from tensorflow.python.training.tracking.tracking import AutoTrackable as TfAutoTrackableType
-except ImportError:
-    TfAutoTrackableType = tuple()
+    TfStaticHashTableType = tuple()
+    TfTrackableType = tuple()
 
 if is_tf2():
     convert_variables_to_constants = tf.compat.v1.graph_util.convert_variables_to_constants
@@ -147,6 +146,46 @@ def fix_freezing_errors(graph_def):
     return graph_def
 
 
+def from_trackable(trackable, concrete_func, inputs, outputs, large_model):
+    err_large_model = "model exceeds maximum protobuf size of 2GB. Try setting large_model."
+
+    # Avoid errors due to bug in TF freezing
+    removed_resource_to_placeholder, graph_captures_copy, func_captures_copy = \
+        _remove_non_variable_resources_from_captures(concrete_func)
+
+    try:
+        frozen_graph = from_function(concrete_func, inputs, outputs, large_model)
+    except ValueError as e:
+        if any(msg in str(e) for msg in ["exceeds maximum protobuf size of 2GB", "string too long"]):
+            raise ValueError(err_large_model)
+        raise e
+
+    # We might be returning the concrete_func so let's put it back in working order
+    _restore_captured_resources(concrete_func, graph_captures_copy, func_captures_copy)
+
+    table_names, key_dtypes, value_dtypes = get_hash_table_info(frozen_graph)
+    placeholder_to_table_info = {}
+    _get_hash_table_info_from_trackable(trackable, table_names, key_dtypes, value_dtypes,
+                                        removed_resource_to_placeholder, placeholder_to_table_info)
+
+    initialized_tables = {}
+    for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
+        h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
+        try:
+            k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
+            initialized_tables[n] = (k.numpy(), v.numpy())
+        except Exception:  # pylint: disable=broad-except
+            logger.warning("Could not initialize table with shared_name = %r", n)
+
+    for placeholder in removed_resource_to_placeholder.values():
+        if placeholder not in placeholder_to_table_info:
+            logger.error("Could not find table resource to replace placeholder %s", placeholder)
+
+    replace_placeholders_with_tables(frozen_graph, placeholder_to_table_info)
+
+    return frozen_graph, initialized_tables
+
+
 def from_function(func, input_names, output_names, large_model=False):
     if large_model:
         return convert_variables_to_constants_large_model(func)
@@ -327,7 +366,27 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signa
 def _get_hash_table_info_from_trackable(trackable, table_names, key_dtypes, value_dtypes,
                                         removed_resource_to_placeholder, placeholder_to_table_info):
     # pylint: disable=protected-access
-    for r in trackable.__dict__.values():
+    stack = [trackable]
+    visited = set()
+    while stack:
+        r = stack.pop()
+        visited.add(id(r))
+        try:
+            for trackable_ref in r._checkpoint_dependencies:
+                if id(trackable_ref.ref) not in visited:
+                    if isinstance(trackable_ref.ref, TfTrackableType):
+                        stack.append(trackable_ref.ref)
+        except Exception:  # pylint: disable=broad-except
+            continue
+        for t in r.__dict__.values() if hasattr(r, '__dict__') else []:
+            if isinstance(t, TfStaticHashTableType) and hasattr(t, '_shared_name'):
+                table_names.append(t._shared_name.encode())
+                key_dtypes.append(t.key_dtype.as_datatype_enum)
+                value_dtypes.append(t.value_dtype.as_datatype_enum)
+                table_handle = id(t.resource_handle)
+                if table_handle in removed_resource_to_placeholder:
+                    table_info = (table_names[-1], key_dtypes[-1], value_dtypes[-1])
+                    placeholder_to_table_info[removed_resource_to_placeholder[table_handle]] = table_info
         if isinstance(r, TfRestoredResourceType) and hasattr(r, '_create_resource'):
             try:
                 table_handle = id(r.resource_handle)
@@ -341,9 +400,6 @@ def _get_hash_table_info_from_trackable(trackable, table_names, key_dtypes, valu
             if table_handle in removed_resource_to_placeholder and len(new_names) == 1:
                 table_info = (new_names[0], new_k_dtypes[0], new_v_dtypes[0])
                 placeholder_to_table_info[removed_resource_to_placeholder[table_handle]] = table_info
-        if isinstance(r, TfAutoTrackableType):
-            _get_hash_table_info_from_trackable(r, table_names, key_dtypes, value_dtypes,
-                                                removed_resource_to_placeholder, placeholder_to_table_info)
 
 
 def _remove_non_variable_resources_from_captures(concrete_func):
@@ -399,7 +455,6 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
     err_index = "Invalid concrete_function value: %i. Valid values are [0 to %i]"
     err_no_sig = "No signatures found in model. Try --concrete_function instead."
     err_sig_nomatch = "Specified signature not in model %s"
-    err_large_model = "model exceeds maximum protobuf size of 2GB. Try running with --large_model flag."
 
     if tag is None:
         tag = ['serve']
@@ -456,39 +511,7 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
         outputs = output_names
         logger.info("Outputs not left as None; will use provided names not structured output names.")
 
-    # Avoid errors due to bug in TF freezing
-    removed_resource_to_placeholder, graph_captures_copy, func_captures_copy = \
-        _remove_non_variable_resources_from_captures(concrete_func)
-
-    try:
-        frozen_graph = from_function(concrete_func, inputs, outputs, large_model)
-    except ValueError as e:
-        if any(msg in str(e) for msg in ["exceeds maximum protobuf size of 2GB", "string too long"]):
-            raise ValueError(err_large_model)
-        raise e
-
-    # We might be returning the concrete_func so let's put it back in working order
-    _restore_captured_resources(concrete_func, graph_captures_copy, func_captures_copy)
-
-    table_names, key_dtypes, value_dtypes = get_hash_table_info(frozen_graph)
-    placeholder_to_table_info = {}
-    _get_hash_table_info_from_trackable(imported, table_names, key_dtypes, value_dtypes,
-                                        removed_resource_to_placeholder, placeholder_to_table_info)
-
-    initialized_tables = {}
-    for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
-        h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
-        try:
-            k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
-            initialized_tables[n] = (k.numpy(), v.numpy())
-        except Exception:  # pylint: disable=broad-except
-            logger.warning("Could not initialize table with shared_name = %r", n)
-
-    for placeholder in removed_resource_to_placeholder.values():
-        if placeholder not in placeholder_to_table_info:
-            logger.error("Could not find table resource to replace placeholder %s", placeholder)
-
-    replace_placeholders_with_tables(frozen_graph, placeholder_to_table_info)
+    frozen_graph, initialized_tables = from_trackable(imported, concrete_func, inputs, outputs, large_model)
 
     return frozen_graph, inputs, outputs, concrete_func, imported, initialized_tables, tensors_to_rename
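Note on the new use_custom_ops flag: when it is set, run_onnxruntime registers the
onnxruntime_customops kernel library on the session options before creating the
InferenceSession, which is what lets ops in the ai.onnx.contrib domain (such as the
converted hash-table lookups) execute. A minimal standalone sketch of that setup,
assuming onnxruntime and onnxruntime_customops are installed and "model.onnx" is a
hypothetical converted model that uses contrib-domain ops:

    import numpy as np
    import onnxruntime as rt
    from onnxruntime_customops import get_library_path

    opt = rt.SessionOptions()
    opt.register_custom_ops_library(get_library_path())  # load the contrib-op kernels
    sess = rt.InferenceSession("model.onnx", opt, providers=["CPUExecutionProvider"])
    # Input names here mirror the test_keras_hashtable test above; adjust to the model's inputs.
    feed = {"f_inp": np.array([[2.0]], dtype=np.float32),
            "s_inp": np.array([["a"]], dtype=object)}
    print(sess.run(None, feed))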
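Note on the convert.py change: from_keras now freezes through the new
tf_loader.from_trackable, which also walks the Keras model's trackable dependencies for
StaticHashTable objects and returns the exported table contents alongside the frozen
graph. A rough sketch of that call site, assuming a concrete function has already been
built for the model (freeze_with_tables is a hypothetical helper, not part of the API):

    import tensorflow as tf
    from tf2onnx import tf_loader

    def freeze_with_tables(model, concrete_func):
        # Hypothetical helper mirroring the call added in tf2onnx/convert.py above.
        input_names = [t.name for t in concrete_func.inputs if t.dtype != tf.dtypes.resource]
        output_names = [t.name for t in concrete_func.outputs if t.dtype != tf.dtypes.resource]
        with tf.device("/cpu:0"):
            # Returns the frozen graph plus {shared_name: (keys, values)} for each
            # hash table recovered from the trackable object.
            frozen_graph, initialized_tables = tf_loader.from_trackable(
                model, concrete_func, input_names, output_names, large_model=False)
        return frozen_graph, initialized_tables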