@@ -296,7 +296,7 @@ def run_tensorflow(self, sess, inputs):
296
296
return result
297
297
298
298
def to_onnx (self , tf_graph , opset = None , extra_opset = None , shape_override = None , input_names = None ,
299
- const_node_values = None , initialized_tables = None , tflite_path = None ):
299
+ const_node_values = None , initialized_tables = None , tflite_path = None , tensors_to_rename = None ):
300
300
"""Convert graph to tensorflow."""
301
301
if extra_opset is None :
302
302
extra_opset = []
@@ -306,7 +306,8 @@ def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, i
306
306
extra_opset = extra_opset , target = Test .target , shape_override = shape_override ,
307
307
input_names = input_names , output_names = self .output_names ,
308
308
const_node_values = const_node_values , initialized_tables = initialized_tables ,
309
- tflite_path = tflite_path , dequantize = self .dequantize )
309
+ tflite_path = tflite_path , dequantize = self .dequantize ,
310
+ tensors_to_rename = tensors_to_rename )
310
311
311
312
def run_caffe2 (self , name , model_proto , inputs ):
312
313
"""Run test again caffe2 backend."""
@@ -320,7 +321,7 @@ def run_caffe2(self, name, model_proto, inputs):
320
321
self .onnx_runtime = time .time () - start
321
322
return results
322
323
323
- def run_onnxruntime (self , name , model_proto , inputs , external_tensor_storage = None ):
324
+ def run_onnxruntime (self , name , model_proto , inputs , outputs , external_tensor_storage = None ):
324
325
"""Run test against onnxruntime backend."""
325
326
import onnxruntime as rt
326
327
model_path = utils .save_onnx_model (TEMP_DIR , name , inputs , model_proto , include_test_data = True ,
@@ -334,11 +335,11 @@ def run_onnxruntime(self, name, model_proto, inputs, external_tensor_storage=Non
334
335
m = rt .InferenceSession (model_path , opt )
335
336
else :
336
337
m = rt .InferenceSession (model_path )
337
- results = m .run (self . output_names , inputs )
338
+ results = m .run (outputs , inputs )
338
339
if self .perf :
339
340
start = time .time ()
340
341
for _ in range (PERFITER ):
341
- _ = m .run (self . output_names , inputs )
342
+ _ = m .run (outputs , inputs )
342
343
self .onnx_runtime = time .time () - start
343
344
return results
344
345
@@ -371,19 +372,20 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
371
372
initialized_tables = {}
372
373
outputs = self .output_names
373
374
tflite_path = None
375
+ to_rename = None
374
376
if self .model_type in ["checkpoint" ]:
375
377
graph_def , input_names , outputs = tf_loader .from_checkpoint (model_path , input_names , outputs )
376
378
elif self .model_type in ["saved_model" ]:
377
- loaded = tf_loader .from_saved_model (model_path , input_names , outputs , self .tag , self .signatures ,
379
+ loaded = tf_loader .from_saved_model (model_path , None , None , self .tag , self .signatures ,
378
380
self .concrete_function , self .large_model ,
379
381
return_concrete_func = not self .run_tf_frozen ,
380
- return_initialized_tables = True )
382
+ return_initialized_tables = True , return_tensors_to_rename = True )
381
383
if not self .run_tf_frozen :
382
384
# Must maintain ref to imported since concrete_func uses weak refs
383
385
# pylint: disable=unused-variable
384
- graph_def , input_names , outputs , concrete_func , imported , initialized_tables = loaded
386
+ graph_def , input_names , outputs , concrete_func , imported , initialized_tables , to_rename = loaded
385
387
else :
386
- graph_def , input_names , outputs , initialized_tables = loaded
388
+ graph_def , input_names , outputs , initialized_tables , to_rename = loaded
387
389
elif self .model_type in ["keras" ]:
388
390
graph_def , input_names , outputs = tf_loader .from_keras (model_path , input_names , outputs )
389
391
elif self .model_type in ["tflite" ]:
@@ -434,10 +436,8 @@ def run_tflite():
434
436
# If there is only a single output a dict might not be returned
435
437
if isinstance (tf_results_d , tf .Tensor ):
436
438
tf_results = [tf_results_d ]
437
- elif self .structured_outputs is None :
438
- tf_results = list (tf_results_d .values ())
439
439
else :
440
- tf_results = [tf_results_d [output ] for output in self . structured_outputs ]
440
+ tf_results = [tf_results_d [k ] for k in sorted ( tf_results_d . keys ()) ]
441
441
tf_results = [tf_res .numpy () for tf_res in tf_results ]
442
442
if self .perf :
443
443
logger .info ("Running TF perf" )
@@ -507,7 +507,8 @@ def run_tflite():
507
507
onnx_graph = self .to_onnx (tf_graph , opset = opset , extra_opset = extra_opset ,
508
508
shape_override = shape_override , input_names = inputs .keys (),
509
509
const_node_values = const_node_values ,
510
- initialized_tables = initialized_tables , tflite_path = tflite_path )
510
+ initialized_tables = initialized_tables , tflite_path = tflite_path ,
511
+ tensors_to_rename = to_rename )
511
512
onnx_graph = optimizer .optimize_graph (onnx_graph )
512
513
print ("ONNX" , onnx_graph .dump_node_statistics ())
513
514
external_tensor_storage = ExternalTensorStorage () if self .large_model else None
@@ -532,7 +533,11 @@ def run_tflite():
532
533
if backend == "caffe2" :
533
534
onnx_results = self .run_caffe2 (name , model_proto , inputs )
534
535
elif backend == "onnxruntime" :
535
- onnx_results = self .run_onnxruntime (name , model_proto , inputs , external_tensor_storage )
536
+ if to_rename is None :
537
+ struc_outputs = outputs
538
+ else :
539
+ struc_outputs = [to_rename .get (k , k ) for k in outputs ]
540
+ onnx_results = self .run_onnxruntime (name , model_proto , inputs , struc_outputs , external_tensor_storage )
536
541
else :
537
542
raise ValueError ("unknown backend" )
538
543
logger .info ("Run_ONNX OK" )
0 commit comments