From 297352a32d581681da5fadee9c28e354594a8741 Mon Sep 17 00:00:00 2001
From: Tom Wildenhain
Date: Wed, 2 Sep 2020 15:05:01 -0400
Subject: [PATCH] Added support for converting large models

---
 README.md                  | 10 ++++++++-
 tests/backend_test_base.py | 30 ++++++++++++++++++--------
 tests/test_backend.py      |  9 ++++++++
 tests/test_convert.py      | 14 ++++++++++++-
 tf2onnx/convert.py         | 24 ++++++++++++++++-----
 tf2onnx/tf_loader.py       | 43 +++++++++++++++++++++++++++++++++-----
 tf2onnx/tfonnx.py          |  9 +++++---
 tf2onnx/utils.py           | 22 ++++++++++++++++---
 8 files changed, 134 insertions(+), 27 deletions(-)

diff --git a/README.md b/README.md
index a7fdbe041..4ff81e687 100644
--- a/README.md
+++ b/README.md
@@ -193,6 +193,12 @@ Only valid with parameter `--saved_model`. Specifies which signature to use with
 Only valid with parameter `--saved_model`. If a model contains a list of concrete functions, under the function name `__call__` (as can be viewed using the command `saved_model_cli show --all`), this parameter is a 0-based integer specifying which function in that list should be converted. This parameter takes priority over `--signature_def`, which will be ignored.

+#### --large_model
+
+(This is experimental, valid only for TF2.x models)
+
+Only valid with parameter `--saved_model`. When set, creates a zip file containing the ONNX protobuf model and large tensor values stored externally. This allows for converting models that exceed the 2 GB protobuf limit.
+
 #### --target

 Some models require special handling to run on some runtimes. In particular, the model may use unsupported data types. Workarounds are activated with ```--target TARGET```. Currently supported values are listed on this [wiki](https://github.com/onnx/tensorflow-onnx/wiki/target). If your model will be run on Windows ML, you should specify the appropriate target value.

@@ -274,7 +280,8 @@ tf2onnx.tfonnx.process_tf_graph(tf_graph,
                opset=None, custom_op_handlers=None,
                custom_rewriter=None, extra_opset=None,
                shape_override=None, inputs_as_nchw=None,
-               input_names=None, output_names=None):
+               input_names=None, output_names=None,
+               const_node_values=None):
     """Convert tensorflow graph to onnx graph.
     Args:
         tf_graph: tensorflow graph
@@ -289,6 +296,7 @@ tf2onnx.tfonnx.process_tf_graph(tf_graph,
         inputs_as_nchw: transpose inputs in list from nchw to nchw
         input_names: list of input node names in graph, input name format as node_name:port_id
         output_names: list of output node names in graph, output name format as node_name:port_id
+        const_node_values: an optional dict mapping node names to tensor values
     Return:
         onnx graph
     """
diff --git a/tests/backend_test_base.py b/tests/backend_test_base.py
index 4ed90053a..afe226665 100644
--- a/tests/backend_test_base.py
+++ b/tests/backend_test_base.py
@@ -26,6 +26,8 @@
 from tf2onnx import optimizer
 from tf2onnx.tf_loader import tf_reset_default_graph, tf_session, tf_placeholder, from_function, freeze_session
 from tf2onnx.tf_loader import tf_optimize, is_tf2
+from tf2onnx.tf_utils import compress_graph_def
+from tf2onnx.graph import ExternalTensorStorage


 class Tf2OnnxBackendTestBase(unittest.TestCase):
@@ -72,9 +74,10 @@ def run_onnxruntime(self, model_path, inputs, output_names):
         results = m.run(output_names, inputs)
         return results

-    def run_backend(self, g, outputs, input_dict):
-        model_proto = g.make_model("test")
-        model_path = self.save_onnx_model(model_proto, input_dict)
+    def run_backend(self, g, outputs, input_dict, large_model=False):
+        tensor_storage = ExternalTensorStorage() if large_model else None
+        model_proto = g.make_model("test", external_tensor_storage=tensor_storage)
+        model_path = self.save_onnx_model(model_proto, input_dict, external_tensor_storage=tensor_storage)

         if self.config.backend == "onnxruntime":
             y = self.run_onnxruntime(model_path, input_dict, outputs)
@@ -86,7 +89,8 @@ def run_backend(self, g, outputs, input_dict):

     def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07, atol=1e-5,
                       convert_var_to_const=True, constant_fold=True, check_value=True, check_shape=True,
-                      check_dtype=True, process_args=None, onnx_feed_dict=None, graph_validator=None, as_session=False):
+                      check_dtype=True, process_args=None, onnx_feed_dict=None, graph_validator=None, as_session=False,
+                      large_model=False):
         # optional - passed to process_tf_graph
         if process_args is None:
             process_args = {}
@@ -121,7 +125,9 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
                 concrete_func = tf.function(func, input_signature=tuple(input_tensors))
                 concrete_func = concrete_func.get_concrete_function()
                 graph_def = from_function(concrete_func,
-                                          input_names=list(feed_dict.keys()), output_names=output_names_with_port)
+                                          input_names=list(feed_dict.keys()),
+                                          output_names=output_names_with_port,
+                                          large_model=large_model)
             else:
                 #
                 # use graph to execute the tensorflow func
@@ -151,6 +157,9 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit

         tf_reset_default_graph()
         with tf_session() as sess:
+            const_node_values = None
+            if large_model:
+                const_node_values = compress_graph_def(graph_def)
             tf.import_graph_def(graph_def, name='')

         if self.config.is_debug_mode:
@@ -161,9 +170,11 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
             g = process_tf_graph(sess.graph, opset=self.config.opset,
                                  input_names=list(feed_dict.keys()),
                                  output_names=output_names_with_port,
-                                 target=self.config.target, **process_args)
+                                 target=self.config.target,
+                                 const_node_values=const_node_values,
+                                 **process_args)
             g = optimizer.optimize_graph(g)
-            actual = self.run_backend(g, output_names_with_port, onnx_feed_dict)
+            actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model)

             for expected_val, actual_val in zip(expected, actual):
                 if check_value:
@@ -180,10 +191,11 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit

         return g

-    def save_onnx_model(self, model_proto, feed_dict, postfix=""):
+    def save_onnx_model(self, model_proto, feed_dict, postfix="", external_tensor_storage=None):
         target_path = utils.save_onnx_model(self.test_data_directory, self._testMethodName + postfix, feed_dict,
                                             model_proto, include_test_data=self.config.is_debug_mode,
-                                            as_text=self.config.is_debug_mode)
+                                            as_text=self.config.is_debug_mode,
+                                            external_tensor_storage=external_tensor_storage)
         self.logger.debug("create model file: %s", target_path)
         return target_path

diff --git a/tests/test_backend.py b/tests/test_backend.py
index 44bf95704..59bc935b2 100644
--- a/tests/test_backend.py
+++ b/tests/test_backend.py
@@ -819,6 +819,15 @@ def func(x):
             return tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

+    @check_tf_min_version("2.2")
+    def test_large_model_format(self):
+        x_val = np.array([2.0], dtype=np.float32)
+        y_const = np.arange(2000, dtype=np.float32)
+        def func(x):
+            x_ = tf.multiply(x, tf.constant(y_const))
+            return tf.identity(x_, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, large_model=True)
+
     @check_target('rs6', 'GatherNd')
     def test_gathernd(self):
         x_val = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
diff --git a/tests/test_convert.py b/tests/test_convert.py
index a05b5db65..e3e0aeef6 100644
--- a/tests/test_convert.py
+++ b/tests/test_convert.py
@@ -8,7 +8,7 @@
 import unittest

 from tf2onnx import convert
-
+from common import check_tf_min_version

 def run_test_case(args):
     """ run case and clean up """
@@ -33,6 +33,18 @@ def test_convert_saved_model(self):
                                        '--output',
                                        'converted_saved_model.onnx']))

+    @check_tf_min_version("2.2")
+    def test_convert_large_model(self):
+        """ convert saved model to onnx large model format """
+        self.assertTrue(run_test_case(['',
+                                       '--large_model',
+                                       '--saved-model',
+                                       'tests/models/regression/saved_model',
+                                       '--tag',
+                                       'serve',
+                                       '--output',
+                                       'converted_saved_model.zip']))
+
     def test_convert_graphdef(self):
         """ convert graphdef """
         self.assertTrue(run_test_case(['',
diff --git a/tf2onnx/convert.py b/tf2onnx/convert.py
index 2e4c7884f..463dabbb5 100644
--- a/tf2onnx/convert.py
+++ b/tf2onnx/convert.py
@@ -22,6 +22,8 @@
 from tf2onnx.tfonnx import process_tf_graph
 from tf2onnx import constants, logging, utils, optimizer
 from tf2onnx import tf_loader
+from tf2onnx.graph import ExternalTensorStorage
+from tf2onnx.tf_utils import compress_graph_def

 # pylint: disable=unused-argument

@@ -53,6 +55,7 @@ def get_args():
                         help="For TF2.x saved_model, index of func signature in __call__ (--signature_def is ignored)")
     parser.add_argument("--checkpoint", help="input from checkpoint")
     parser.add_argument("--keras", help="input from keras model")
+    parser.add_argument("--large_model", help="use the large model format (for models > 2GB)", action="store_true")
     parser.add_argument("--output", help="output model file")
     parser.add_argument("--inputs", help="model input_names")
     parser.add_argument("--outputs", help="model output_names")
@@ -129,7 +132,8 @@ def main():
     model_path = args.checkpoint
     if args.saved_model:
         graph_def, inputs, outputs = tf_loader.from_saved_model(
-            args.saved_model, args.inputs, args.outputs, args.tag, args.signature_def, args.concrete_function)
+            args.saved_model, args.inputs, args.outputs, args.tag,
+            args.signature_def, args.concrete_function, args.large_model)
         model_path = args.saved_model
     if args.keras:
         graph_def, inputs, outputs = tf_loader.from_keras(
@@ -141,6 +145,9 @@ def main():
     logger.info("outputs: %s", outputs)

     with tf.Graph().as_default() as tf_graph:
+        const_node_values = None
+        if args.large_model:
+            const_node_values = compress_graph_def(graph_def)
         tf.import_graph_def(graph_def, name='')
     with tf_loader.tf_session(graph=tf_graph):
         g = process_tf_graph(tf_graph,
@@ -152,17 +159,24 @@ def main():
                              shape_override=args.shape_override,
                              input_names=inputs,
                              output_names=outputs,
-                             inputs_as_nchw=args.inputs_as_nchw)
+                             inputs_as_nchw=args.inputs_as_nchw,
+                             const_node_values=const_node_values)

         onnx_graph = optimizer.optimize_graph(g)
-        model_proto = onnx_graph.make_model("converted from {}".format(model_path))
+
+        tensor_storage = ExternalTensorStorage() if args.large_model else None
+        model_proto = onnx_graph.make_model("converted from {}".format(model_path), external_tensor_storage=tensor_storage)

     # write onnx graph
     logger.info("")
     logger.info("Successfully converted TensorFlow model %s to ONNX", model_path)
     if args.output:
-        utils.save_protobuf(args.output, model_proto)
-        logger.info("ONNX model is saved at %s", args.output)
+        if args.large_model:
+            utils.save_onnx_zip(args.output, model_proto, tensor_storage)
+            logger.info("Zipped ONNX model is saved at %s. Unzip before opening in onnxruntime.", args.output)
+        else:
+            utils.save_protobuf(args.output, model_proto)
+            logger.info("ONNX model is saved at %s", args.output)
     else:
         logger.info("To export ONNX model to file, please run with `--output` option")

diff --git a/tf2onnx/tf_loader.py b/tf2onnx/tf_loader.py
index f062db46c..9a0ad6eb1 100644
--- a/tf2onnx/tf_loader.py
+++ b/tf2onnx/tf_loader.py
@@ -95,8 +95,32 @@ def inputs_without_resource(sess, input_names):
             pass
     return input_names

+def convert_variables_to_constants_large_model(func):
+    # For large models we use internal tf methods as a hack
+
+    if tf.__version__.startswith("2.2."):
+        try:
+            from tensorflow.python.framework.convert_to_constants import \
+                _convert_variables_to_constants_v2_impl  # pylint: disable=protected-access
+        except ImportError:
+            _not_implemented_tf_placeholder("_convert_variables_to_constants_v2_impl")()
+        frozen_graph_def, _ = \
+            _convert_variables_to_constants_v2_impl(func, lower_control_flow=False, aggressive_inlining=False)
+        return frozen_graph_def
+
+    try:
+        from tensorflow.python.framework.convert_to_constants import \
+            _FunctionConverterData, _replace_variables_by_constants  # pylint: disable=protected-access
+    except ImportError:
+        _not_implemented_tf_placeholder("_replace_variables_by_constants")()
+    converter_data = _FunctionConverterData(func=func, lower_control_flow=False, aggressive_inlining=False)
+    frozen_graph_def, _ = _replace_variables_by_constants(converter_data=converter_data)
+    return frozen_graph_def
+
+def from_function(func, input_names, output_names, large_model=False):
+    if large_model:
+        return convert_variables_to_constants_large_model(func)

-def from_function(func, input_names, output_names):
     frozen_func = convert_variables_to_constants_v2(func, lower_control_flow=False)
     graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
     # output_names = [i.name for i in frozen_func.outputs]
@@ -223,7 +247,8 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signa

     return frozen_graph, input_names, output_names

-def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_def, concrete_function_index):
+def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_def,
+                         concrete_function_index, large_model):
     """Load tensorflow graph from saved_model."""

     wrn_no_tag = "'--tag' not specified for saved_model. Using --tag serve"
@@ -234,6 +259,7 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
     err_index = "Invalid concrete_function value: %i. Valid values are [0 to %i]"
     err_no_sig = "No signatures found in model. Try --concrete_function instead."
     err_sig_nomatch = "Specified signature not in model %s"
+    err_large_model = "model exceeds maximum protobuf size of 2GB. Try running with --large_model flag."

     if tag is None:
         tag = ['serve']
@@ -274,18 +300,25 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
     if output_names:
         outputs = list(set(output_names) & set(outputs))

-    frozen_graph = from_function(concrete_func, inputs, outputs)
+    try:
+        frozen_graph = from_function(concrete_func, inputs, outputs, large_model)
+    except ValueError as e:
+        if "exceeds maximum protobuf size of 2GB" in str(e):
+            raise ValueError(err_large_model)
+        raise e
+
     return frozen_graph, inputs, outputs


-def from_saved_model(model_path, input_names, output_names, tag=None, signatures=None, concrete_function=None):
+def from_saved_model(model_path, input_names, output_names, tag=None,
+                     signatures=None, concrete_function=None, large_model=False):
     """Load tensorflow graph from saved_model."""
     if signatures is None:
         signatures = []
     tf_reset_default_graph()
     if is_tf2():
         frozen_graph, input_names, output_names = \
-            _from_saved_model_v2(model_path, input_names, output_names, tag, signatures, concrete_function)
+            _from_saved_model_v2(model_path, input_names, output_names, tag, signatures, concrete_function, large_model)
     else:
         with tf_session() as sess:
             frozen_graph, input_names, output_names = \
diff --git a/tf2onnx/tfonnx.py b/tf2onnx/tfonnx.py
index a2ed3e6b8..42e0819c8 100644
--- a/tf2onnx/tfonnx.py
+++ b/tf2onnx/tfonnx.py
@@ -334,7 +334,7 @@ def run_rewriters(g, funcs, continue_on_error):
 def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None,
                      opset=None, custom_op_handlers=None, custom_rewriter=None,
                      extra_opset=None, shape_override=None, inputs_as_nchw=None,
-                     input_names=None, output_names=None, is_subgraph=False):
+                     input_names=None, output_names=None, is_subgraph=False, const_node_values=None):
     """Convert tensorflow graph to onnx graph.
     Args:
         tf_graph: tensorflow graph
@@ -349,6 +349,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
         inputs_as_nchw: transpose inputs in list from nchw to nchw
         input_names: list of input node names in graph, input name format as node_name:port_id
         output_names: list of output node names in graph, output name format as node_name:port_id
+        const_node_values: a dict returned by compress_graph_def mapping node names to tensor values
     Return:
         onnx graph
     """
@@ -377,7 +378,8 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
     if target is None:
         target = constants.DEFAULT_TARGET

-    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, _ = tensorflow_to_onnx(tf_graph, shape_override)
+    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, _ = \
+        tensorflow_to_onnx(tf_graph, shape_override, const_node_values)
     if not is_subgraph:
         # make tf2onnx internal subgraphs from the tensorflow subgraphs
         ordered_func = resolve_functions(tf_graph)
@@ -387,7 +389,8 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
             fg = process_tf_graph(func, continue_on_error, False, target, opset,
                                   custom_op_handlers, custom_rewriter,
                                   extra_opset, shape_override, inputs_as_nchw,
-                                  f_inputs_names, f_output_names, is_subgraph=True)
+                                  f_inputs_names, f_output_names, is_subgraph=True,
+                                  const_node_values=const_node_values)
             fg.graph_name = func.name
             fg.func_inputs = f_inputs_names
             set_function(func.name, fg)
diff --git a/tf2onnx/utils.py b/tf2onnx/utils.py
index 85df4087c..24c6e4ce2 100644
--- a/tf2onnx/utils.py
+++ b/tf2onnx/utils.py
@@ -13,6 +13,7 @@
 import re
 import shutil
 import tempfile
+import zipfile

 import requests
 from requests.adapters import HTTPAdapter
@@ -161,7 +162,8 @@ def find_opset(opset):
     return opset


-def save_onnx_model(save_path_root, onnx_file_name, feed_dict, model_proto, include_test_data=False, as_text=False):
+def save_onnx_model(save_path_root, onnx_file_name, feed_dict, model_proto, include_test_data=False, as_text=False,
+                    external_tensor_storage=None):
     """Save onnx model as file. Save a pbtxt file as well if as_text is True"""
     save_path = save_path_root
     if not os.path.exists(save_path):
@@ -181,12 +183,26 @@ def save_onnx_model(save_path_root, onnx_file_name, feed_dict, model_proto, incl
             save_protobuf(data_full_path, t)
             i += 1

-    target_path = os.path.join(save_path, onnx_file_name + ".onnx")
-    save_protobuf(target_path, model_proto)
+    if external_tensor_storage is None:
+        target_path = os.path.join(save_path, onnx_file_name + ".onnx")
+        save_protobuf(target_path, model_proto)
+    else:
+        zip_path = os.path.join(save_path, onnx_file_name + ".zip")
+        save_onnx_zip(zip_path, model_proto, external_tensor_storage)
+        with zipfile.ZipFile(zip_path, 'r') as z:
+            z.extractall(save_path)
+        target_path = os.path.join(save_path, "__MODEL_PROTO.onnx")
+
     if as_text:
         save_protobuf(target_path + ".pbtxt", model_proto, as_text=True)
+
     return target_path

+def save_onnx_zip(target_path, model_proto, external_tensor_storage):
+    with zipfile.ZipFile(target_path, 'w') as z:
+        z.writestr("__MODEL_PROTO.onnx", model_proto.SerializeToString())
+        for k, v in external_tensor_storage.name_to_tensor_data.items():
+            z.writestr(k, v)

 def make_sure(bool_val, error_msg, *args):
     if not bool_val:
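
Usage sketch (not part of the patch): the snippet below mirrors the large-model flow that the new code path in `tf2onnx/convert.py` implements, but driven from the Python API. The module paths, function signatures, and keyword arguments (`large_model`, `const_node_values`, `external_tensor_storage`) are the ones introduced by this patch; the saved-model path, the `None` input/output names, and the output file name are illustrative assumptions only.

```python
import tensorflow as tf
from tf2onnx import tf_loader, optimizer, utils
from tf2onnx.tfonnx import process_tf_graph
from tf2onnx.tf_utils import compress_graph_def
from tf2onnx.graph import ExternalTensorStorage

# Freeze the saved model; large_model=True routes through
# convert_variables_to_constants_large_model, which avoids the 2 GB protobuf check.
graph_def, inputs, outputs = tf_loader.from_saved_model(
    "tests/models/regression/saved_model", None, None, large_model=True)

with tf.Graph().as_default() as tf_graph:
    # compress_graph_def pulls large constant values out of the GraphDef into a
    # dict so the graph can be imported and converted without hitting protobuf limits.
    const_node_values = compress_graph_def(graph_def)
    tf.import_graph_def(graph_def, name="")
with tf_loader.tf_session(graph=tf_graph):
    g = process_tf_graph(tf_graph, input_names=inputs, output_names=outputs,
                         const_node_values=const_node_values)
    onnx_graph = optimizer.optimize_graph(g)

    # Collect the large tensors in external storage instead of embedding them in
    # the model proto, then write the zip described in the README section above.
    tensor_storage = ExternalTensorStorage()
    model_proto = onnx_graph.make_model("converted from saved_model",
                                        external_tensor_storage=tensor_storage)
    utils.save_onnx_zip("converted_saved_model.zip", model_proto, tensor_storage)
```

The CLI equivalent, per the README addition and the new test, would be `python -m tf2onnx.convert --saved-model <dir> --tag serve --large_model --output converted_saved_model.zip`.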
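A companion sketch for consuming the converter output: the new log message in `convert.py` notes that the zip has to be unzipped before onnxruntime can open it, and `save_onnx_zip` stores the model proto under the fixed name `__MODEL_PROTO.onnx` alongside one entry per external tensor. The archive and directory names below are assumptions taken from the new `test_convert_large_model` test.

```python
import os
import zipfile
import onnxruntime as ort

zip_path = "converted_saved_model.zip"   # produced by --large_model --output ...
extract_dir = "converted_saved_model"

# Extract everything so the external tensor files sit next to the model proto,
# where onnxruntime can resolve them when loading the model.
with zipfile.ZipFile(zip_path, "r") as z:
    z.extractall(extract_dir)

sess = ort.InferenceSession(os.path.join(extract_dir, "__MODEL_PROTO.onnx"))
print([i.name for i in sess.get_inputs()])
```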