Skip to content

Commit e12d3ea

Browse files
Allow for cutting and renaming of IO. Use structured outputs by default (#1355)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 0b9c395 commit e12d3ea

File tree

6 files changed

+148
-64
lines changed

6 files changed

+148
-64
lines changed

tests/run_pretrained_models.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def run_tensorflow(self, sess, inputs):
296296
return result
297297

298298
def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, input_names=None,
299-
const_node_values=None, initialized_tables=None, tflite_path=None):
299+
const_node_values=None, initialized_tables=None, tflite_path=None, tensors_to_rename=None):
300300
"""Convert graph to ONNX."""
301301
if extra_opset is None:
302302
extra_opset = []
@@ -306,7 +306,8 @@ def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, i
306306
extra_opset=extra_opset, target=Test.target, shape_override=shape_override,
307307
input_names=input_names, output_names=self.output_names,
308308
const_node_values=const_node_values, initialized_tables=initialized_tables,
309-
tflite_path=tflite_path, dequantize=self.dequantize)
309+
tflite_path=tflite_path, dequantize=self.dequantize,
310+
tensors_to_rename=tensors_to_rename)
310311

311312
def run_caffe2(self, name, model_proto, inputs):
312313
"""Run test against caffe2 backend."""
@@ -320,7 +321,7 @@ def run_caffe2(self, name, model_proto, inputs):
320321
self.onnx_runtime = time.time() - start
321322
return results
322323

323-
def run_onnxruntime(self, name, model_proto, inputs, external_tensor_storage=None):
324+
def run_onnxruntime(self, name, model_proto, inputs, outputs, external_tensor_storage=None):
324325
"""Run test against onnxruntime backend."""
325326
import onnxruntime as rt
326327
model_path = utils.save_onnx_model(TEMP_DIR, name, inputs, model_proto, include_test_data=True,
@@ -334,11 +335,11 @@ def run_onnxruntime(self, name, model_proto, inputs, external_tensor_storage=Non
334335
m = rt.InferenceSession(model_path, opt)
335336
else:
336337
m = rt.InferenceSession(model_path)
337-
results = m.run(self.output_names, inputs)
338+
results = m.run(outputs, inputs)
338339
if self.perf:
339340
start = time.time()
340341
for _ in range(PERFITER):
341-
_ = m.run(self.output_names, inputs)
342+
_ = m.run(outputs, inputs)
342343
self.onnx_runtime = time.time() - start
343344
return results
344345

@@ -371,19 +372,20 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
371372
initialized_tables = {}
372373
outputs = self.output_names
373374
tflite_path = None
375+
to_rename = None
374376
if self.model_type in ["checkpoint"]:
375377
graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs)
376378
elif self.model_type in ["saved_model"]:
377-
loaded = tf_loader.from_saved_model(model_path, input_names, outputs, self.tag, self.signatures,
379+
loaded = tf_loader.from_saved_model(model_path, None, None, self.tag, self.signatures,
378380
self.concrete_function, self.large_model,
379381
return_concrete_func=not self.run_tf_frozen,
380-
return_initialized_tables=True)
382+
return_initialized_tables=True, return_tensors_to_rename=True)
381383
if not self.run_tf_frozen:
382384
# Must maintain ref to imported since concrete_func uses weak refs
383385
# pylint: disable=unused-variable
384-
graph_def, input_names, outputs, concrete_func, imported, initialized_tables = loaded
386+
graph_def, input_names, outputs, concrete_func, imported, initialized_tables, to_rename = loaded
385387
else:
386-
graph_def, input_names, outputs, initialized_tables = loaded
388+
graph_def, input_names, outputs, initialized_tables, to_rename = loaded
387389
elif self.model_type in ["keras"]:
388390
graph_def, input_names, outputs = tf_loader.from_keras(model_path, input_names, outputs)
389391
elif self.model_type in ["tflite"]:
@@ -434,10 +436,8 @@ def run_tflite():
434436
# If there is only a single output a dict might not be returned
435437
if isinstance(tf_results_d, tf.Tensor):
436438
tf_results = [tf_results_d]
437-
elif self.structured_outputs is None:
438-
tf_results = list(tf_results_d.values())
439439
else:
440-
tf_results = [tf_results_d[output] for output in self.structured_outputs]
440+
tf_results = [tf_results_d[k] for k in sorted(tf_results_d.keys())]
441441
tf_results = [tf_res.numpy() for tf_res in tf_results]
442442
if self.perf:
443443
logger.info("Running TF perf")
@@ -507,7 +507,8 @@ def run_tflite():
507507
onnx_graph = self.to_onnx(tf_graph, opset=opset, extra_opset=extra_opset,
508508
shape_override=shape_override, input_names=inputs.keys(),
509509
const_node_values=const_node_values,
510-
initialized_tables=initialized_tables, tflite_path=tflite_path)
510+
initialized_tables=initialized_tables, tflite_path=tflite_path,
511+
tensors_to_rename=to_rename)
511512
onnx_graph = optimizer.optimize_graph(onnx_graph)
512513
print("ONNX", onnx_graph.dump_node_statistics())
513514
external_tensor_storage = ExternalTensorStorage() if self.large_model else None
@@ -532,7 +533,11 @@ def run_tflite():
532533
if backend == "caffe2":
533534
onnx_results = self.run_caffe2(name, model_proto, inputs)
534535
elif backend == "onnxruntime":
535-
onnx_results = self.run_onnxruntime(name, model_proto, inputs, external_tensor_storage)
536+
if to_rename is None:
537+
struc_outputs = self.output_names
538+
else:
539+
struc_outputs = [to_rename.get(k, k) for k in self.output_names]
540+
onnx_results = self.run_onnxruntime(name, model_proto, inputs, struc_outputs, external_tensor_storage)
536541
else:
537542
raise ValueError("unknown backend")
538543
logger.info("Run_ONNX OK")

tests/test_convert.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,34 @@ def test_convert_checkpoint(self):
8888
'--output',
8989
'converted_checkpoint.onnx']))
9090

91+
def test_convert_graphdef_cut_input(self):
92+
""" convert graphdef, change input to start from Mul:0 and rename it """
93+
self.assertTrue(run_test_case(['',
94+
'--input',
95+
'tests/models/regression/graphdef/frozen.pb',
96+
'--inputs',
97+
'Mul:0',
98+
'--rename-inputs',
99+
'new_input',
100+
'--outputs',
101+
'pred:0',
102+
'--output',
103+
'converted_graphdef_cut_input.onnx']))
104+
105+
def test_convert_graphdef_cut_output(self):
106+
""" convert graphdef, change output to Mul:0 and rename it """
107+
self.assertTrue(run_test_case(['',
108+
'--input',
109+
'tests/models/regression/graphdef/frozen.pb',
110+
'--inputs',
111+
'X:0',
112+
'--rename-outputs',
113+
'new_output',
114+
'--outputs',
115+
'Mul:0',
116+
'--output',
117+
'converted_graphdef_cut_output.onnx']))
118+
91119

92120
if __name__ == '__main__':
93121
unittest.main()

tf2onnx/convert.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,14 @@ def get_args():
5858
parser.add_argument("--tflite", help="input from tflite model")
5959
parser.add_argument("--large_model", help="use the large model format (for models > 2GB)", action="store_true")
6060
parser.add_argument("--output", help="output model file")
61-
parser.add_argument("--inputs", help="model input_names")
62-
parser.add_argument("--outputs", help="model output_names")
61+
parser.add_argument("--inputs", help="model input_names (optional for saved_model, keras, and tflite)")
62+
parser.add_argument("--outputs", help="model output_names (optional for saved_model, keras, and tflite)")
6363
parser.add_argument("--ignore_default", help="comma-separated list of names of PlaceholderWithDefault "
6464
"ops to change into Placeholder ops")
6565
parser.add_argument("--use_default", help="comma-separated list of names of PlaceholderWithDefault ops to "
6666
"change into Identity ops using their default value")
67+
parser.add_argument("--rename-inputs", help="input names to use in final model (optional)")
68+
parser.add_argument("--rename-outputs", help="output names to use in final model (optional)")
6769
parser.add_argument("--opset", type=int, default=None, help="opset version to use for onnx domain")
6870
parser.add_argument("--dequantize", help="Remove quantization from model. Only supported for tflite currently.",
6971
action="store_true")
@@ -87,7 +89,7 @@ def get_args():
8789
# for backward compatibility
8890
args.graphdef = args.input
8991
if args.graphdef or args.checkpoint:
90-
if not args.input and not args.outputs:
92+
if not args.inputs or not args.outputs:
9193
parser.error("graphdef and checkpoint models need to provide inputs and outputs")
9294
if not any([args.graphdef, args.checkpoint, args.saved_model, args.keras, args.tflite]):
9395
parser.print_help()
@@ -100,6 +102,10 @@ def get_args():
100102
args.ignore_default = args.ignore_default.split(",")
101103
if args.use_default:
102104
args.use_default = args.use_default.split(",")
105+
if args.rename_outputs:
106+
args.rename_outputs = args.rename_outputs.split(",")
107+
if args.rename_inputs:
108+
args.rename_inputs = args.rename_inputs.split(",")
103109
if args.inputs_as_nchw:
104110
args.inputs_as_nchw = args.inputs_as_nchw.split(",")
105111
if args.target:
@@ -135,6 +141,7 @@ def main():
135141
tflite_path = None
136142
custom_ops = {}
137143
initialized_tables = None
144+
tensors_to_rename = {}
138145
if args.custom_ops:
139146
using_tf_opset = False
140147
for op in args.custom_ops.split(","):
@@ -162,9 +169,9 @@ def main():
162169
graph_def, inputs, outputs = tf_loader.from_checkpoint(args.checkpoint, args.inputs, args.outputs)
163170
model_path = args.checkpoint
164171
if args.saved_model:
165-
graph_def, inputs, outputs, initialized_tables = tf_loader.from_saved_model(
166-
args.saved_model, args.inputs, args.outputs, args.tag,
167-
args.signature_def, args.concrete_function, args.large_model, return_initialized_tables=True)
172+
graph_def, inputs, outputs, initialized_tables, tensors_to_rename = tf_loader.from_saved_model(
173+
args.saved_model, args.inputs, args.outputs, args.tag, args.signature_def, args.concrete_function,
174+
args.large_model, return_initialized_tables=True, return_tensors_to_rename=True)
168175
model_path = args.saved_model
169176
if args.keras:
170177
graph_def, inputs, outputs = tf_loader.from_keras(
@@ -181,6 +188,11 @@ def main():
181188
logger.info("inputs: %s", inputs)
182189
logger.info("outputs: %s", outputs)
183190

191+
if args.rename_inputs:
192+
tensors_to_rename.update(zip(inputs, args.rename_inputs))
193+
if args.rename_outputs:
194+
tensors_to_rename.update(zip(outputs, args.rename_outputs))
195+
184196
tf_graph = None
185197
const_node_values = None
186198
if graph_def is not None:
@@ -206,6 +218,7 @@ def main():
206218
ignore_default=args.ignore_default,
207219
use_default=args.use_default,
208220
const_node_values=const_node_values,
221+
tensors_to_rename=tensors_to_rename,
209222
initialized_tables=initialized_tables,
210223
tflite_path=tflite_path,
211224
dequantize=args.dequantize)
@@ -218,6 +231,8 @@ def main():
218231
# write onnx graph
219232
logger.info("")
220233
logger.info("Successfully converted TensorFlow model %s to ONNX", model_path)
234+
logger.info("Model inputs: %s", onnx_graph.input_names)
235+
logger.info("Model outputs: %s", onnx_graph.outputs)
221236
if args.output:
222237
if args.large_model:
223238
utils.save_onnx_zip(args.output, model_proto, tensor_storage)

tf2onnx/shape_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def infer_shape_for_op(op):
100100
op.outputs[0].set_shape(new_shape)
101101
logger.debug("set placeholder op [%s] with new shape %s", op.outputs[0].name, new_shape)
102102
return True
103-
logger.warning("Shape of placeholder '%s' is unknown, treated it as a scalar. Please use the --input flag "
103+
logger.warning("Shape of placeholder '%s' is unknown, treated it as a scalar. Please use the --inputs flag "
104104
"and append the shape to the input name if this input is not a scalar.", op.name)
105105
op.outputs[0].set_shape([])
106106
return True

0 commit comments

Comments
 (0)