@@ -375,7 +375,7 @@ def run_test(self, name, backend="onnxruntime", onnx_file=None, opset=None, extr
375
375
initialized_tables = {}
376
376
outputs = self .output_names
377
377
tflite_path = None
378
- to_rename = None
378
+ to_rename = {}
379
379
if self .model_type in ["checkpoint" ]:
380
380
graph_def , input_names , outputs = tf_loader .from_checkpoint (model_path , input_names , outputs )
381
381
elif self .model_type in ["saved_model" ]:
@@ -400,6 +400,7 @@ def run_test(self, name, backend="onnxruntime", onnx_file=None, opset=None, extr
400
400
if utils .is_debug_mode ():
401
401
utils .save_protobuf (os .path .join (TEMP_DIR , name + "_after_tf_optimize.pb" ), graph_def )
402
402
403
+ logger .info ("Input names %s" , input_names )
403
404
if tflite_path is not None :
404
405
inputs = {}
405
406
for k in input_names :
@@ -438,7 +439,7 @@ def run_tflite():
438
439
inputs = {}
439
440
for k in input_names :
440
441
v = self .input_names [k ]
441
- inputs [to_rename [ k ] ] = tf .constant (self .make_input (v ))
442
+ inputs [to_rename . get ( k , k ) ] = tf .constant (self .make_input (v ))
442
443
tf_func = tf .function (concrete_func )
443
444
logger .info ("Running TF" )
444
445
tf_results_d = tf_func (** inputs )
@@ -507,6 +508,7 @@ def run_tflite():
507
508
elif self .run_tf_frozen :
508
509
if self .tf_profile is not None :
509
510
tf .profiler .experimental .start (self .tf_profile )
511
+ logger .info ("TF inputs %s" , list (inputs .keys ()))
510
512
tf_results = self .run_tensorflow (sess , inputs )
511
513
if self .tf_profile is not None :
512
514
tf .profiler .experimental .stop ()
@@ -553,11 +555,9 @@ def run_tflite():
553
555
try :
554
556
onnx_results = None
555
557
if backend == "onnxruntime" :
556
- if to_rename is None :
557
- struc_outputs = self .output_names
558
- else :
559
- struc_outputs = [to_rename .get (k , k ) for k in self .output_names ]
558
+ struc_outputs = [to_rename .get (k , k ) for k in self .output_names ]
560
559
struc_inputs = {to_rename .get (k , k ): v for k , v in inputs .items ()}
560
+ logger .info ("ORT inputs %s" , list (struc_inputs .keys ()))
561
561
onnx_results = self .run_onnxruntime (
562
562
name , model_proto , struc_inputs , struc_outputs , external_tensor_storage )
563
563
else :
0 commit comments