From c76c0b2f33dd391e7da56f6f532362cae4566114 Mon Sep 17 00:00:00 2001
From: Tom Wildenhain
Date: Tue, 11 May 2021 23:53:41 -0400
Subject: [PATCH 1/2] Add hash table support and frozen graph repairing for
 from_keras

Signed-off-by: Tom Wildenhain
---
 tests/backend_test_base.py |   5 +-
 tests/test_api.py          |  42 +++++++++++++-
 tf2onnx/convert.py         |   4 +-
 tf2onnx/tf_loader.py       | 109 ++++++++++++++++++++++---------------
 4 files changed, 113 insertions(+), 47 deletions(-)

diff --git a/tests/backend_test_base.py b/tests/backend_test_base.py
index aaa9e29b6..d5b525412 100644
--- a/tests/backend_test_base.py
+++ b/tests/backend_test_base.py
@@ -75,7 +75,7 @@ def run_onnxcaffe2(self, onnx_graph, inputs):
         results = prepared_backend.run(inputs)
         return results
 
-    def run_onnxruntime(self, model_path, inputs, output_names):
+    def run_onnxruntime(self, model_path, inputs, output_names, use_custom_ops=False):
         """Run test against onnxruntime backend."""
         import onnxruntime as rt
         providers = ['CPUExecutionProvider']
@@ -84,6 +84,9 @@ def run_onnxruntime(self, model_path, inputs, output_names):
             if gpus is None or len(gpus) > 1:
                 providers = ['CUDAExecutionProvider']
         opt = rt.SessionOptions()
+        if use_custom_ops:
+            from onnxruntime_customops import get_library_path
+            opt.register_custom_ops_library(get_library_path())
         # in case of issues with the runtime, one can enable more logging
         # opt.log_severity_level = 0
         # opt.log_verbosity_level = 255

diff --git a/tests/test_api.py b/tests/test_api.py
index 82420e0f8..931ad43cf 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -15,7 +15,7 @@
 import numpy as np
 import tensorflow as tf
 
-from common import check_tf_min_version, unittest_main
+from common import check_tf_min_version, unittest_main, requires_custom_ops
 from backend_test_base import Tf2OnnxBackendTestBase
 import tf2onnx
 
@@ -78,6 +78,46 @@ def test_keras_api(self):
     def test_keras_api_large(self):
         self._test_keras_api(large_model=True)
 
+    @requires_custom_ops()
+    @check_tf_min_version("2.0")
+    def test_keras_hashtable(self):
+
+        featCols = [tf.feature_column.numeric_column("f_inp", dtype=tf.float32),
+            tf.feature_column.indicator_column(
+                tf.feature_column.categorical_column_with_vocabulary_list("s_inp", ["a", "b", "z"], num_oov_buckets=1)
+            )]
+        featureLayer = tf.keras.layers.DenseFeatures(featCols)
+
+        inputDict = {}
+        inputDict["f_inp"] = tf.keras.Input(name="f_inp", shape=(1,), dtype=tf.float32)
+        inputDict["s_inp"] = tf.keras.Input(name="s_inp", shape=(1,), dtype=tf.string)
+
+        inputs = [input for input in inputDict.values()]
+        standardFeatures = featureLayer(inputDict)
+        hidden1 = tf.keras.layers.Dense(512, activation='relu')(standardFeatures)
+        output = tf.keras.layers.Dense(10, activation='softmax')(hidden1)
+        model = tf.keras.Model(inputs=inputs, outputs=output)
+        model.compile(optimizer='adam', loss=tf.keras.losses.mean_squared_error)
+
+        inp1 = np.array([[2.], [3.]], dtype=np.float32)
+        inp2 = np.array([["a"], ["b"]], dtype=np.str)
+        k_res = model.predict([inp1, inp2])
+        spec = (tf.TensorSpec((None, 1), dtype=tf.float32, name="f_inp"),
+                tf.TensorSpec((None, 1), tf.string, name="s_inp"))
+        output_path = os.path.join(self.test_data_directory, "model.onnx")
+
+        from onnx import helper
+        model_proto, _ = tf2onnx.convert.from_keras(
+            model, input_signature=spec, opset=self.config.opset, output_path=output_path,
+            extra_opset=[helper.make_opsetid("ai.onnx.contrib", 1)])
+        output_names = [n.name for n in model_proto.graph.output]
+
+        o_res = self.run_onnxruntime(output_path, {"f_inp": inp1, "s_inp": inp2}, output_names, use_custom_ops=True)
+        self.assertAllClose(k_res, o_res[0], rtol=0.3, atol=0.1)
+        # make sure the original keras model wasn't trashed
+        k_res2 = model.predict([inp1, inp2])
+        self.assertAllClose(k_res2, o_res[0], rtol=0.3, atol=0.1)
+
     @check_tf_min_version("2.0")
     def test_function(self):
         def func(x, y):

diff --git a/tf2onnx/convert.py b/tf2onnx/convert.py
index 6772b1cba..81f44c3fd 100644
--- a/tf2onnx/convert.py
+++ b/tf2onnx/convert.py
@@ -333,7 +333,6 @@ def wrap_call(*args, training=False, **kwargs):
     output_names = [output_tensor.name for output_tensor in concrete_func.outputs
                     if output_tensor.dtype != tf.dtypes.resource]
-    initialized_tables = None
 
     tensors_to_rename = tensor_names_from_structed(concrete_func, input_names, output_names)
     reverse_lookup = {v: k for k, v in tensors_to_rename.items()}
@@ -345,7 +344,8 @@ def wrap_call(*args, training=False, **kwargs):
         output_names = [reverse_lookup[out] for out in concrete_func.structured_outputs.keys()]
 
     with tf.device("/cpu:0"):
-        frozen_graph = tf_loader.from_function(concrete_func, input_names, output_names, large_model=large_model)
+        frozen_graph, initialized_tables = \
+            tf_loader.from_trackable(model, concrete_func, input_names, output_names, large_model)
 
     model_proto, external_tensor_storage = _convert_common(
         frozen_graph,
         name=model.name,

diff --git a/tf2onnx/tf_loader.py b/tf2onnx/tf_loader.py
index 36e798309..522ea3b81 100644
--- a/tf2onnx/tf_loader.py
+++ b/tf2onnx/tf_loader.py
@@ -52,13 +52,12 @@ def not_implemented_tf_placeholder(*args, **kwargs):
 try:
     # pylint: disable=protected-access
     from tensorflow.python.saved_model.load import _RestoredResource as TfRestoredResourceType
+    from tensorflow.python.ops.lookup_ops import StaticHashTable as TfStaticHashTableType
+    from tensorflow.python.training.tracking.base import Trackable as TfTrackableType
 except ImportError:
     TfRestoredResourceType = tuple()  # isinstance(x, tuple()) is always false
-
-try:
-    from tensorflow.python.training.tracking.tracking import AutoTrackable as TfAutoTrackableType
-except ImportError:
-    TfAutoTrackableType = tuple()
+    TfStaticHashTableType = tuple()
+    TfTrackableType = tuple()
 
 if is_tf2():
     convert_variables_to_constants = tf.compat.v1.graph_util.convert_variables_to_constants
@@ -152,6 +151,46 @@ def fix_freezing_errors(graph_def):
     return graph_def
 
 
+def from_trackable(trackable, concrete_func, inputs, outputs, large_model):
+    err_large_model = "model exceeds maximum protobuf size of 2GB. Try setting large_model."
+
+    # Avoid errors due to bug in TF freezing
+    removed_resource_to_placeholder, graph_captures_copy, func_captures_copy = \
+        _remove_non_variable_resources_from_captures(concrete_func)
+
+    try:
+        frozen_graph = from_function(concrete_func, inputs, outputs, large_model)
+    except ValueError as e:
+        if any(msg in str(e) for msg in ["exceeds maximum protobuf size of 2GB", "string too long"]):
+            raise ValueError(err_large_model)
+        raise e
+
+    # We might be returning the concrete_func so let's put it back in working order
+    _restore_captured_resources(concrete_func, graph_captures_copy, func_captures_copy)
+
+    table_names, key_dtypes, value_dtypes = get_hash_table_info(frozen_graph)
+    placeholder_to_table_info = {}
+    _get_hash_table_info_from_trackable(trackable, table_names, key_dtypes, value_dtypes,
+                                        removed_resource_to_placeholder, placeholder_to_table_info)
+
+    initialized_tables = {}
+    for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
+        h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
+        try:
+            k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
+            initialized_tables[n] = (k.numpy(), v.numpy())
+        except Exception:  # pylint: disable=broad-except
+            logger.warning("Could not initialize table with shared_name = %r", n)
+
+    for placeholder in removed_resource_to_placeholder.values():
+        if placeholder not in placeholder_to_table_info:
+            logger.error("Could not find table resource to replace placeholder %s", placeholder)
+
+    replace_placeholders_with_tables(frozen_graph, placeholder_to_table_info)
+
+    return frozen_graph, initialized_tables
+
+
 def from_function(func, input_names, output_names, large_model=False):
     if large_model:
         return convert_variables_to_constants_large_model(func)
@@ -332,7 +371,27 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signa
 def _get_hash_table_info_from_trackable(trackable, table_names, key_dtypes, value_dtypes,
                                         removed_resource_to_placeholder, placeholder_to_table_info):
     # pylint: disable=protected-access
-    for r in trackable.__dict__.values():
+    stack = [trackable]
+    visited = set()
+    while stack:
+        r = stack.pop()
+        visited.add(id(r))
+        try:
+            for trackable_ref in r._checkpoint_dependencies:
+                if id(trackable_ref.ref) not in visited:
+                    if isinstance(trackable_ref.ref, TfTrackableType):
+                        stack.append(trackable_ref.ref)
+        except Exception:  # pylint: disable=broad-except
+            continue
+        for t in r.__dict__.values() if hasattr(r, '__dict__') else []:
+            if isinstance(t, TfStaticHashTableType) and hasattr(t, '_shared_name'):
+                table_names.append(t._shared_name.encode())
+                key_dtypes.append(t.key_dtype.as_datatype_enum)
+                value_dtypes.append(t.value_dtype.as_datatype_enum)
+                table_handle = id(t.resource_handle)
+                if table_handle in removed_resource_to_placeholder:
+                    table_info = (table_names[-1], key_dtypes[-1], value_dtypes[-1])
+                    placeholder_to_table_info[removed_resource_to_placeholder[table_handle]] = table_info
         if isinstance(r, TfRestoredResourceType) and hasattr(r, '_create_resource'):
             try:
                 table_handle = id(r.resource_handle)
@@ -346,9 +405,6 @@ def _get_hash_table_info_from_trackable(trackable, table_names, key_dtypes, valu
             if table_handle in removed_resource_to_placeholder and len(new_names) == 1:
                 table_info = (new_names[0], new_k_dtypes[0], new_v_dtypes[0])
                 placeholder_to_table_info[removed_resource_to_placeholder[table_handle]] = table_info
-        if isinstance(r, TfAutoTrackableType):
-            _get_hash_table_info_from_trackable(r, table_names, key_dtypes, value_dtypes,
-                                                removed_resource_to_placeholder, placeholder_to_table_info)
 
 
 def _remove_non_variable_resources_from_captures(concrete_func):
@@ -404,7 +460,6 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
     err_index = "Invalid concrete_function value: %i. Valid values are [0 to %i]"
     err_no_sig = "No signatures found in model. Try --concrete_function instead."
     err_sig_nomatch = "Specified signature not in model %s"
-    err_large_model = "model exceeds maximum protobuf size of 2GB. Try running with --large_model flag."
 
     if tag is None:
         tag = ['serve']
@@ -461,39 +516,7 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
         outputs = output_names
         logger.info("Outputs not left as None; will use provided names not structured output names.")
 
-    # Avoid errors due to bug in TF freezing
-    removed_resource_to_placeholder, graph_captures_copy, func_captures_copy = \
-        _remove_non_variable_resources_from_captures(concrete_func)
-
-    try:
-        frozen_graph = from_function(concrete_func, inputs, outputs, large_model)
-    except ValueError as e:
-        if any(msg in str(e) for msg in ["exceeds maximum protobuf size of 2GB", "string too long"]):
-            raise ValueError(err_large_model)
-        raise e
-
-    # We might be returning the concrete_func so let's put it back in working order
-    _restore_captured_resources(concrete_func, graph_captures_copy, func_captures_copy)
-
-    table_names, key_dtypes, value_dtypes = get_hash_table_info(frozen_graph)
-    placeholder_to_table_info = {}
-    _get_hash_table_info_from_trackable(imported, table_names, key_dtypes, value_dtypes,
-                                        removed_resource_to_placeholder, placeholder_to_table_info)
-
-    initialized_tables = {}
-    for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
-        h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
-        try:
-            k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
-            initialized_tables[n] = (k.numpy(), v.numpy())
-        except Exception:  # pylint: disable=broad-except
-            logger.warning("Could not initialize table with shared_name = %r", n)
-
-    for placeholder in removed_resource_to_placeholder.values():
-        if placeholder not in placeholder_to_table_info:
-            logger.error("Could not find table resource to replace placeholder %s", placeholder)
-
-    replace_placeholders_with_tables(frozen_graph, placeholder_to_table_info)
+    frozen_graph, initialized_tables = from_trackable(imported, concrete_func, inputs, outputs, large_model)
 
     return frozen_graph, inputs, outputs, concrete_func, imported, initialized_tables, tensors_to_rename

From 37eb0683835dcf484891e10b8fff012d8abf3b47 Mon Sep 17 00:00:00 2001
From: Tom Wildenhain
Date: Wed, 12 May 2021 00:06:35 -0400
Subject: [PATCH 2/2] Pylint

Signed-off-by: Tom Wildenhain
---
 tests/backend_test_base.py | 13 ++++++++-----
 tests/test_api.py          | 25 ++++++++++++++-----------
 tests/test_string_ops.py   | 13 ++-----------
 3 files changed, 24 insertions(+), 27 deletions(-)

diff --git a/tests/backend_test_base.py b/tests/backend_test_base.py
index d5b525412..95bf4d72b 100644
--- a/tests/backend_test_base.py
+++ b/tests/backend_test_base.py
@@ -95,14 +95,14 @@ def run_onnxruntime(self, model_path, inputs, output_names, use_custom_ops=False
         results = m.run(output_names, inputs)
         return results
 
-    def run_backend(self, g, outputs, input_dict, large_model=False, postfix=""):
+    def run_backend(self, g, outputs, input_dict, large_model=False, postfix="", use_custom_ops=False):
         tensor_storage = ExternalTensorStorage() if large_model else None
         model_proto = g.make_model("test", external_tensor_storage=tensor_storage)
         model_path = self.save_onnx_model(model_proto, input_dict, external_tensor_storage=tensor_storage,
                                           postfix=postfix)
         if self.config.backend == "onnxruntime":
-            y = self.run_onnxruntime(model_path, input_dict, outputs)
+            y = self.run_onnxruntime(model_path, input_dict, outputs, use_custom_ops)
         elif self.config.backend == "caffe2":
             y = self.run_onnxcaffe2(model_proto, input_dict)
         else:
@@ -307,7 +307,8 @@ def get_dtype(info):
     def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port,
                       rtol=1e-07, atol=1e-5, mtol=None, convert_var_to_const=True, constant_fold=True,
                       check_value=True, check_shape=True, check_dtype=True, process_args=None, onnx_feed_dict=None,
-                      graph_validator=None, as_session=False, large_model=False, premade_placeholders=False):
+                      graph_validator=None, as_session=False, large_model=False, premade_placeholders=False,
+                      use_custom_ops=False):
         test_tf = not self.config.skip_tf_tests
         test_tflite = not self.config.skip_tflite_tests
         run_tfl_consistency_test = test_tf and test_tflite and self.config.run_tfl_consistency_test
@@ -347,7 +348,8 @@
                           initialized_tables=initialized_tables, **process_args)
             g = optimizer.optimize_graph(g, catch_errors=False)
-            actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model)
+            actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model,
+                                      use_custom_ops=use_custom_ops)
 
             self.assert_results_equal(expected, actual, rtol, atol, mtol, check_value, check_shape, check_dtype)
             self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)
@@ -377,7 +379,8 @@
                           **tfl_process_args)
             g = optimizer.optimize_graph(g)
             onnx_feed_dict_without_port = {k.split(':')[0]: v for k, v in onnx_feed_dict.items()}
-            onnx_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port, postfix="_from_tflite")
+            onnx_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port,
+                                            postfix="_from_tflite", use_custom_ops=use_custom_ops)
 
             self.assert_results_equal(tfl_res, onnx_tfl_res, rtol, atol, mtol, check_value, check_shape, check_dtype)
             self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)

diff --git a/tests/test_api.py b/tests/test_api.py
index 931ad43cf..770dbbb69 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -14,8 +14,9 @@
 
 import numpy as np
 import tensorflow as tf
+from onnx import helper
 
-from common import check_tf_min_version, unittest_main, requires_custom_ops
+from common import check_tf_min_version, unittest_main, requires_custom_ops, check_opset_min_version
 from backend_test_base import Tf2OnnxBackendTestBase
 import tf2onnx
 
@@ -80,21 +81,24 @@ def test_keras_api_large(self):
 
     @requires_custom_ops()
     @check_tf_min_version("2.0")
+    @check_opset_min_version(11, "SparseToDense")
     def test_keras_hashtable(self):
 
-        featCols = [tf.feature_column.numeric_column("f_inp", dtype=tf.float32),
+        feature_cols = [
+            tf.feature_column.numeric_column("f_inp", dtype=tf.float32),
             tf.feature_column.indicator_column(
                 tf.feature_column.categorical_column_with_vocabulary_list("s_inp", ["a", "b", "z"], num_oov_buckets=1)
-            )]
-        featureLayer = tf.keras.layers.DenseFeatures(featCols)
+            )
+        ]
+        feature_layer = tf.keras.layers.DenseFeatures(feature_cols)
 
-        inputDict = {}
-        inputDict["f_inp"] = tf.keras.Input(name="f_inp", shape=(1,), dtype=tf.float32)
-        inputDict["s_inp"] = tf.keras.Input(name="s_inp", shape=(1,), dtype=tf.string)
+        input_dict = {}
+        input_dict["f_inp"] = tf.keras.Input(name="f_inp", shape=(1,), dtype=tf.float32)
+        input_dict["s_inp"] = tf.keras.Input(name="s_inp", shape=(1,), dtype=tf.string)
 
-        inputs = [input for input in inputDict.values()]
-        standardFeatures = featureLayer(inputDict)
-        hidden1 = tf.keras.layers.Dense(512, activation='relu')(standardFeatures)
+        inputs = list(input_dict.values())
+        standard_features = feature_layer(input_dict)
+        hidden1 = tf.keras.layers.Dense(512, activation='relu')(standard_features)
         output = tf.keras.layers.Dense(10, activation='softmax')(hidden1)
         model = tf.keras.Model(inputs=inputs, outputs=output)
         model.compile(optimizer='adam', loss=tf.keras.losses.mean_squared_error)
@@ -106,7 +110,6 @@ def test_keras_hashtable(self):
                 tf.TensorSpec((None, 1), tf.string, name="s_inp"))
         output_path = os.path.join(self.test_data_directory, "model.onnx")
 
-        from onnx import helper
         model_proto, _ = tf2onnx.convert.from_keras(
             model, input_signature=spec, opset=self.config.opset, output_path=output_path,
             extra_opset=[helper.make_opsetid("ai.onnx.contrib", 1)])

diff --git a/tests/test_string_ops.py b/tests/test_string_ops.py
index 335850416..cd85d4d1c 100644
--- a/tests/test_string_ops.py
+++ b/tests/test_string_ops.py
@@ -135,17 +135,8 @@ def func(text):
     def _run_test_case(self, func, output_names_with_port, feed_dict, **kwargs):
         extra_opset = [utils.make_opsetid(constants.CONTRIB_OPS_DOMAIN, 1)]
         process_args = {"extra_opset": extra_opset}
-        return self.run_test_case(func, feed_dict, [], output_names_with_port, process_args=process_args, **kwargs)
-
-    def run_onnxruntime(self, model_path, inputs, output_names):
-        """Run test against onnxruntime backend."""
-        from onnxruntime_customops import get_library_path
-        import onnxruntime as rt
-        opt = rt.SessionOptions()
-        opt.register_custom_ops_library(get_library_path())
-        m = rt.InferenceSession(model_path, opt)
-        results = m.run(output_names, inputs)
-        return results
+        return self.run_test_case(func, feed_dict, [], output_names_with_port,
+                                  use_custom_ops=True, process_args=process_args, **kwargs)
 
     @requires_custom_ops("WordpieceTokenizer")
     @check_tf_min_version("2.0", "tensorflow_text")
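
For reference, below is a minimal end-to-end sketch of the conversion path these two patches enable. It mirrors the new test_keras_hashtable test above; the model, file name, and inputs are illustrative, and it assumes a tf2onnx build containing these patches plus the onnxruntime_customops package.

import numpy as np
import tensorflow as tf
import tf2onnx
import onnxruntime as rt
from onnx import helper
from onnxruntime_customops import get_library_path

# A vocabulary-list feature column builds a StaticHashTable under the hood;
# this is the kind of resource the new from_trackable path recovers.
feature_cols = [
    tf.feature_column.numeric_column("f_inp", dtype=tf.float32),
    tf.feature_column.indicator_column(
        tf.feature_column.categorical_column_with_vocabulary_list(
            "s_inp", ["a", "b", "z"], num_oov_buckets=1)),
]
inputs = {
    "f_inp": tf.keras.Input(name="f_inp", shape=(1,), dtype=tf.float32),
    "s_inp": tf.keras.Input(name="s_inp", shape=(1,), dtype=tf.string),
}
features = tf.keras.layers.DenseFeatures(feature_cols)(inputs)
output = tf.keras.layers.Dense(10, activation="softmax")(features)
model = tf.keras.Model(inputs=list(inputs.values()), outputs=output)

spec = (tf.TensorSpec((None, 1), tf.float32, name="f_inp"),
        tf.TensorSpec((None, 1), tf.string, name="s_inp"))
# from_keras now freezes via tf_loader.from_trackable, which walks the model's
# trackable dependency graph, exports each hash table's contents, and patches
# the frozen graph so the lookup survives conversion. The string lookup is
# emitted as a contrib op, hence the ai.onnx.contrib extra_opset entry.
model_proto, _ = tf2onnx.convert.from_keras(
    model, input_signature=spec, output_path="model.onnx",
    extra_opset=[helper.make_opsetid("ai.onnx.contrib", 1)])

# At inference time the custom-ops library must be registered with the session,
# the same way the updated run_onnxruntime helper does when use_custom_ops=True.
opt = rt.SessionOptions()
opt.register_custom_ops_library(get_library_path())
sess = rt.InferenceSession("model.onnx", opt)
res = sess.run(None, {"f_inp": np.array([[2.0], [3.0]], dtype=np.float32),
                      "s_inp": np.array([["a"], ["b"]], dtype=object)})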