
Commit 4bf08f4 (parent b2b81a3)

Added support for converting large models

7 files changed: 98 additions, 22 deletions

README.md

Lines changed: 9 additions & 1 deletion

@@ -193,6 +193,12 @@ Only valid with parameter `--saved_model`. Specifies which signature to use with
 
 Only valid with parameter `--saved_model`. If a model contains a list of concrete functions, under the function name `__call__` (as can be viewed using the command `saved_model_cli show --all`), this parameter is a 0-based integer specifying which function in that list should be converted. This parameter takes priority over `--signature_def`, which will be ignored.
 
+#### --large_model
+
+(This is experimental, valid only for TF2.x models)
+
+Only valid with parameter `--saved_model`. When set, creates a zip file containing the ONNX protobuf model and large tensor values stored externally. This allows converting models that exceed the 2 GB protobuf limit.
+
 #### --target
 
 Some models require special handling to run on some runtimes. In particular, the model may use unsupported data types. Workarounds are activated with ```--target TARGET```. Currently supported values are listed on this [wiki](https://github.com/onnx/tensorflow-onnx/wiki/target). If your model will be run on Windows ML, you should specify the appropriate target value.

@@ -274,7 +280,8 @@ tf2onnx.tfonnx.process_tf_graph(tf_graph,
                opset=None, custom_op_handlers=None,
                custom_rewriter=None, extra_opset=None,
                shape_override=None, inputs_as_nchw=None,
-               input_names=None, output_names=None):
+               input_names=None, output_names=None,
+               const_node_values=None):
     """Convert tensorflow graph to onnx graph.
     Args:
         tf_graph: tensorflow graph

@@ -289,6 +296,7 @@ tf2onnx.tfonnx.process_tf_graph(tf_graph,
         inputs_as_nchw: transpose inputs in list from nchw to nhwc
         input_names: list of input node names in graph, input name format as node_name:port_id
         output_names: list of output node names in graph, output name format as node_name:port_id
+        const_node_values: an optional dict mapping node names to tensor values
     Return:
         onnx graph
     """

tests/backend_test_base.py

Lines changed: 17 additions & 6 deletions

@@ -26,6 +26,8 @@
 from tf2onnx import optimizer
 from tf2onnx.tf_loader import tf_reset_default_graph, tf_session, tf_placeholder, from_function, freeze_session
 from tf2onnx.tf_loader import tf_optimize, is_tf2
+from tf2onnx.tf_utils import compress_graph_def
+from tf2onnx.graph import ExternalTensorStorage
 
 
 class Tf2OnnxBackendTestBase(unittest.TestCase):

@@ -72,8 +74,9 @@ def run_onnxruntime(self, model_path, inputs, output_names):
         results = m.run(output_names, inputs)
         return results
 
-    def run_backend(self, g, outputs, input_dict):
-        model_proto = g.make_model("test")
+    def run_backend(self, g, outputs, input_dict, large_model=False):
+        tensor_storage = ExternalTensorStorage() if large_model else None
+        model_proto = g.make_model("test", external_tensor_storage=tensor_storage)
         model_path = self.save_onnx_model(model_proto, input_dict)
 
         if self.config.backend == "onnxruntime":

@@ -86,7 +89,8 @@ def run_backend(self, g, outputs, input_dict):
 
     def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07, atol=1e-5,
                       convert_var_to_const=True, constant_fold=True, check_value=True, check_shape=True,
-                      check_dtype=True, process_args=None, onnx_feed_dict=None, graph_validator=None, as_session=False):
+                      check_dtype=True, process_args=None, onnx_feed_dict=None, graph_validator=None, as_session=False,
+                      large_model=False):
         # optional - passed to process_tf_graph
         if process_args is None:
             process_args = {}

@@ -121,7 +125,9 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
             concrete_func = tf.function(func, input_signature=tuple(input_tensors))
             concrete_func = concrete_func.get_concrete_function()
             graph_def = from_function(concrete_func,
-                                      input_names=list(feed_dict.keys()), output_names=output_names_with_port)
+                                      input_names=list(feed_dict.keys()),
+                                      output_names=output_names_with_port,
+                                      large_model=large_model)
         else:
             #
             # use graph to execute the tensorflow func

@@ -151,6 +157,9 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
 
         tf_reset_default_graph()
         with tf_session() as sess:
+            const_node_values = None
+            if large_model:
+                const_node_values = compress_graph_def(graph_def)
             tf.import_graph_def(graph_def, name='')
 
             if self.config.is_debug_mode:

@@ -161,9 +170,11 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
             g = process_tf_graph(sess.graph, opset=self.config.opset,
                                  input_names=list(feed_dict.keys()),
                                  output_names=output_names_with_port,
-                                 target=self.config.target, **process_args)
+                                 target=self.config.target,
+                                 const_node_values=const_node_values,
+                                 **process_args)
             g = optimizer.optimize_graph(g)
-            actual = self.run_backend(g, output_names_with_port, onnx_feed_dict)
+            actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model)
 
             for expected_val, actual_val in zip(expected, actual):
                 if check_value:
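
The tests lean on compress_graph_def from tf2onnx.tf_utils, whose implementation is not part of this diff. As a rough sketch of the idea (an assumption, not the real code): it strips large constant values out of the GraphDef so the proto stays under the 2 GB limit, and returns them keyed by node name, producing the const_node_values dict that process_tf_graph later consumes.

    def compress_graph_def_sketch(graph_def, min_size=1024):
        """Illustrative only: blank out big Const tensors, remember their bytes."""
        const_node_values = {}
        for node in graph_def.node:
            if node.op == "Const":
                tensor = node.attr["value"].tensor
                if len(tensor.tensor_content) > min_size:
                    const_node_values[node.name] = tensor.tensor_content
                    tensor.tensor_content = b""  # shrink the proto in place
        return const_node_values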

tests/test_convert.py

Lines changed: 13 additions & 1 deletion

@@ -8,7 +8,7 @@
 import unittest
 
 from tf2onnx import convert
-
+from common import check_tf_min_version
 
 def run_test_case(args):
     """ run case and clean up """

@@ -33,6 +33,18 @@ def test_convert_saved_model(self):
                                        '--output',
                                        'converted_saved_model.onnx']))
 
+    @check_tf_min_version("2.1")
+    def test_convert_large_model(self):
+        """ convert saved model to onnx large model format """
+        self.assertTrue(run_test_case(['',
+                                       '--large_model',
+                                       '--saved-model',
+                                       'tests/models/regression/saved_model',
+                                       '--tag',
+                                       'serve',
+                                       '--output',
+                                       'converted_saved_model.zip']))
+
     def test_convert_graphdef(self):
         """ convert graphdef """
         self.assertTrue(run_test_case(['',

tf2onnx/convert.py

Lines changed: 19 additions & 5 deletions

@@ -22,6 +22,8 @@
 from tf2onnx.tfonnx import process_tf_graph
 from tf2onnx import constants, logging, utils, optimizer
 from tf2onnx import tf_loader
+from tf2onnx.graph import ExternalTensorStorage
+from tf2onnx.tf_utils import compress_graph_def
 
 # pylint: disable=unused-argument
 

@@ -53,6 +55,7 @@ def get_args():
                         help="For TF2.x saved_model, index of func signature in __call__ (--signature_def is ignored)")
     parser.add_argument("--checkpoint", help="input from checkpoint")
     parser.add_argument("--keras", help="input from keras model")
+    parser.add_argument("--large_model", help="use the large model format (for models > 2GB)", action="store_true")
     parser.add_argument("--output", help="output model file")
     parser.add_argument("--inputs", help="model input_names")
     parser.add_argument("--outputs", help="model output_names")

@@ -129,7 +132,8 @@ def main():
         model_path = args.checkpoint
     if args.saved_model:
         graph_def, inputs, outputs = tf_loader.from_saved_model(
-            args.saved_model, args.inputs, args.outputs, args.tag, args.signature_def, args.concrete_function)
+            args.saved_model, args.inputs, args.outputs, args.tag,
+            args.signature_def, args.concrete_function, args.large_model)
         model_path = args.saved_model
     if args.keras:
         graph_def, inputs, outputs = tf_loader.from_keras(

@@ -141,6 +145,9 @@ def main():
     logger.info("outputs: %s", outputs)
 
     with tf.Graph().as_default() as tf_graph:
+        const_node_values = None
+        if args.large_model:
+            const_node_values = compress_graph_def(graph_def)
         tf.import_graph_def(graph_def, name='')
         with tf_loader.tf_session(graph=tf_graph):
             g = process_tf_graph(tf_graph,

@@ -152,17 +159,24 @@ def main():
                                  shape_override=args.shape_override,
                                  input_names=inputs,
                                  output_names=outputs,
-                                 inputs_as_nchw=args.inputs_as_nchw)
+                                 inputs_as_nchw=args.inputs_as_nchw,
+                                 const_node_values=const_node_values)
 
     onnx_graph = optimizer.optimize_graph(g)
-    model_proto = onnx_graph.make_model("converted from {}".format(model_path))
+
+    tensor_storage = ExternalTensorStorage() if args.large_model else None
+    model_proto = onnx_graph.make_model("converted from {}".format(model_path), external_tensor_storage=tensor_storage)
 
     # write onnx graph
     logger.info("")
     logger.info("Successfully converted TensorFlow model %s to ONNX", model_path)
     if args.output:
-        utils.save_protobuf(args.output, model_proto)
-        logger.info("ONNX model is saved at %s", args.output)
+        if args.large_model:
+            utils.save_onnx_zip(args.output, model_proto, tensor_storage)
+            logger.info("Zipped ONNX model is saved at %s. Unzip before opening in onnxruntime.", args.output)
+        else:
+            utils.save_protobuf(args.output, model_proto)
+            logger.info("ONNX model is saved at %s", args.output)
     else:
         logger.info("To export ONNX model to file, please run with `--output` option")
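
Putting the pieces of main() together, a hedged sketch of the equivalent Python API flow for a TF2.x SavedModel; the path "big_model" and the output file name are placeholders:

    import tensorflow as tf

    from tf2onnx import optimizer, tf_loader, utils
    from tf2onnx.graph import ExternalTensorStorage
    from tf2onnx.tf_utils import compress_graph_def
    from tf2onnx.tfonnx import process_tf_graph

    graph_def, inputs, outputs = tf_loader.from_saved_model(
        "big_model", None, None, tag=["serve"], large_model=True)

    with tf.Graph().as_default() as tf_graph:
        const_node_values = compress_graph_def(graph_def)  # strip big constants
        tf.import_graph_def(graph_def, name="")
        with tf_loader.tf_session(graph=tf_graph):
            g = process_tf_graph(tf_graph, input_names=inputs, output_names=outputs,
                                 const_node_values=const_node_values)

    onnx_graph = optimizer.optimize_graph(g)
    tensor_storage = ExternalTensorStorage()
    model_proto = onnx_graph.make_model("converted from big_model",
                                        external_tensor_storage=tensor_storage)
    utils.save_onnx_zip("big_model.zip", model_proto, tensor_storage)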

tf2onnx/tf_loader.py

Lines changed: 28 additions & 6 deletions

@@ -96,7 +96,19 @@ def inputs_without_resource(sess, input_names):
     return input_names
 
 
-def from_function(func, input_names, output_names):
+def from_function(func, input_names, output_names, large_model):
+    if large_model:
+        try:
+            # For large models we use _convert_variables_to_constants_v2_impl as a hack
+            from tensorflow.python.framework.convert_to_constants import \
+                _convert_variables_to_constants_v2_impl  # pylint: disable=protected-access
+        except ImportError:
+            # This internal method could disappear in different tf versions
+            _not_implemented_tf_placeholder("_convert_variables_to_constants_v2_impl")()
+        graph_def, _ = _convert_variables_to_constants_v2_impl(func, lower_control_flow=False,
+                                                               aggressive_inlining=False)
+        return graph_def
+
     frozen_func = convert_variables_to_constants_v2(func, lower_control_flow=False)
     graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
     # output_names = [i.name for i in frozen_func.outputs]

@@ -223,7 +235,8 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signa
     return frozen_graph, input_names, output_names
 
 
-def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_def, concrete_function_index):
+def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_def,
+                         concrete_function_index, large_model):
     """Load tensorflow graph from saved_model."""
 
     wrn_no_tag = "'--tag' not specified for saved_model. Using --tag serve"

@@ -234,6 +247,7 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
     err_index = "Invalid concrete_function value: %i. Valid values are [0 to %i]"
     err_no_sig = "No signatures found in model. Try --concrete_function instead."
     err_sig_nomatch = "Specified signature not in model %s"
+    err_large_model = "model exceeds maximum protobuf size of 2GB. Try running with --large_model flag."
 
     if tag is None:
         tag = ['serve']

@@ -274,18 +288,26 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
     if output_names:
         outputs = list(set(output_names) & set(outputs))
 
-    frozen_graph = from_function(concrete_func, inputs, outputs)
+    try:
+        frozen_graph = from_function(concrete_func, inputs, outputs, large_model)
+    except ValueError as e:
+        if "exceeds maximum protobuf size of 2GB" in str(e):
+            raise ValueError(err_large_model)
+        else:
+            raise e
+
     return frozen_graph, inputs, outputs
 
 
-def from_saved_model(model_path, input_names, output_names, tag=None, signatures=None, concrete_function=None):
+def from_saved_model(model_path, input_names, output_names, tag=None,
+                     signatures=None, concrete_function=None, large_model=False):
     """Load tensorflow graph from saved_model."""
     if signatures is None:
         signatures = []
     tf_reset_default_graph()
     if is_tf2():
         frozen_graph, input_names, output_names = \
-            _from_saved_model_v2(model_path, input_names, output_names, tag, signatures, concrete_function)
+            _from_saved_model_v2(model_path, input_names, output_names, tag, signatures, concrete_function, large_model)
     else:
         with tf_session() as sess:
             frozen_graph, input_names, output_names = \

@@ -316,7 +338,7 @@ def from_keras(model_path, input_names, output_names):
         output_names = [output_tensor.name for output_tensor in concrete_func.outputs
                         if output_tensor.dtype != tf.dtypes.resource]
 
-        frozen_graph = from_function(concrete_func, input_names, output_names)
+        frozen_graph = from_function(concrete_func, input_names, output_names, large_model=False)
     else:
         # Handles Keras when Eager mode is disabled.
         _keras.backend.clear_session()
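
The private _convert_variables_to_constants_v2_impl is used here, presumably, because it hands back the GraphDef directly, whereas the public convert_variables_to_constants_v2 wrapper builds a new frozen ConcreteFunction whose serialization can trip the same 2 GB limit; that reading is an inference from the "hack" comment, not something the diff states. A small sketch of calling the new entry point on a toy function:

    import tensorflow as tf

    from tf2onnx.tf_loader import from_function

    @tf.function(input_signature=[tf.TensorSpec([1, 4], tf.float32)])
    def toy(x):
        return tf.matmul(x, tf.ones([4, 2]))

    concrete = toy.get_concrete_function()
    # With large_model=True the variables are folded in without serializing
    # through a frozen ConcreteFunction. Tensor names below are illustrative;
    # query concrete.inputs / concrete.outputs for the real ones.
    graph_def = from_function(concrete, input_names=["x:0"],
                              output_names=["Identity:0"], large_model=True)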

tf2onnx/tfonnx.py

Lines changed: 6 additions & 3 deletions

@@ -334,7 +334,7 @@ def run_rewriters(g, funcs, continue_on_error):
 def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None,
                      opset=None, custom_op_handlers=None, custom_rewriter=None,
                      extra_opset=None, shape_override=None, inputs_as_nchw=None,
-                     input_names=None, output_names=None, is_subgraph=False):
+                     input_names=None, output_names=None, is_subgraph=False, const_node_values=None):
     """Convert tensorflow graph to onnx graph.
     Args:
         tf_graph: tensorflow graph

@@ -349,6 +349,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
         inputs_as_nchw: transpose inputs in list from nchw to nhwc
         input_names: list of input node names in graph, input name format as node_name:port_id
         output_names: list of output node names in graph, output name format as node_name:port_id
+        const_node_values: a dict returned by compress_graph_def mapping node names to tensor values
     Return:
         onnx graph
     """

@@ -377,7 +378,8 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
     if target is None:
         target = constants.DEFAULT_TARGET
 
-    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, _ = tensorflow_to_onnx(tf_graph, shape_override)
+    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, _ = \
+        tensorflow_to_onnx(tf_graph, shape_override, const_node_values)
     if not is_subgraph:
         # make tf2onnx internal subgraphs from the tensorflow subgraphs
         ordered_func = resolve_functions(tf_graph)

@@ -387,7 +389,8 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
             fg = process_tf_graph(func, continue_on_error, False, target, opset,
                                   custom_op_handlers, custom_rewriter,
                                   extra_opset, shape_override, inputs_as_nchw,
-                                  f_inputs_names, f_output_names, is_subgraph=True)
+                                  f_inputs_names, f_output_names, is_subgraph=True,
+                                  const_node_values=const_node_values)
             fg.graph_name = func.name
             fg.func_inputs = f_inputs_names
             set_function(func.name, fg)
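
const_node_values is threaded down to tensorflow_to_onnx and, per the second hunk, into every subgraph, so stripped constants inside functions are found too. The body of tensorflow_to_onnx is not shown; as a sketch of how a consumer could use the dict (an assumption about the mechanism, not the actual tf2onnx code):

    def const_bytes(node_def, const_node_values):
        """Illustrative only: restore bytes blanked out by compress_graph_def."""
        if const_node_values and node_def.name in const_node_values:
            return const_node_values[node_def.name]
        return node_def.attr["value"].tensor.tensor_content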

tf2onnx/utils.py

Lines changed: 6 additions & 0 deletions

@@ -13,6 +13,7 @@
 import re
 import shutil
 import tempfile
+import zipfile
 
 import requests
 from requests.adapters import HTTPAdapter

@@ -187,6 +188,11 @@ def save_onnx_model(save_path_root, onnx_file_name, feed_dict, model_proto, incl
         save_protobuf(target_path + ".pbtxt", model_proto, as_text=True)
     return target_path
 
+def save_onnx_zip(target_path, model_proto, external_tensor_storage):
+    with zipfile.ZipFile(target_path, 'w') as z:
+        z.writestr("__MODEL_PROTO.onnx", model_proto.SerializeToString())
+        for k, v in external_tensor_storage.name_to_tensor_data.items():
+            z.writestr(k, v)
 
 def make_sure(bool_val, error_msg, *args):
     if not bool_val:
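
A quick way to sanity-check an archive written by save_onnx_zip: it should contain __MODEL_PROTO.onnx plus one entry per externally stored tensor ("big_model.zip" below is a placeholder name).

    import zipfile

    with zipfile.ZipFile("big_model.zip") as z:
        for info in z.infolist():
            print(info.filename, info.file_size)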
