20
20
import onnx
21
21
from common import get_test_config
22
22
from tfjs_runner import run_tfjs
23
+ from tf2onnx import constants
23
24
from tf2onnx import utils
24
25
from tf2onnx .tfonnx import process_tf_graph
25
26
from tf2onnx import optimizer
@@ -366,6 +367,7 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
366
367
graph_def_path = os .path .join (self .test_data_directory , self ._testMethodName + "_after_tf_optimize.pb" )
367
368
utils .save_protobuf (graph_def_path , graph_def )
368
369
self .logger .debug ("created file %s" , graph_def_path )
370
+ tfl_process_args = process_args .copy ()
369
371
370
372
if test_tfjs :
371
373
tfjs_path = self .convert_to_tfjs (graph_def_path , output_names_with_port )
@@ -395,6 +397,10 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
395
397
g = optimizer .optimize_graph (g , catch_errors = False )
396
398
actual = self .run_backend (g , output_names_with_port , onnx_feed_dict , large_model ,
397
399
use_custom_ops = use_custom_ops )
400
+ if 'outputs_as_nchw' in tfl_process_args :
401
+ for output_name in tfl_process_args ['outputs_as_nchw' ]:
402
+ i = output_names_with_port .index (output_name )
403
+ actual [i ] = np .transpose (actual [i ], constants .NCHW_TO_NHWC )
398
404
399
405
self .assert_results_equal (expected , actual , rtol , atol , mtol , check_value , check_shape , check_dtype )
400
406
self .assert_shapes_correct (g , self .config .allow_missing_shapes , not self .config .skip_onnx_checker )
@@ -410,12 +416,14 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
410
416
if run_tfl_consistency_test :
411
417
self .assert_results_equal (expected , tfl_res , rtol , atol , mtol , check_value , check_shape , check_dtype )
412
418
413
- tfl_process_args = process_args .copy ()
414
419
if 'inputs_as_nchw' in tfl_process_args :
415
420
nchw_inps_with_port = tfl_process_args ['inputs_as_nchw' ]
416
421
tfl_process_args ['inputs_as_nchw' ] = [i .split (':' )[0 ] for i in nchw_inps_with_port ]
417
422
input_names_without_port = [inp .split (':' )[0 ] for inp in feed_dict .keys ()]
418
-
423
+ if 'outputs_as_nchw' in tfl_process_args :
424
+ nchw_outps_with_port = tfl_process_args ['outputs_as_nchw' ]
425
+ tfl_process_args ['outputs_as_nchw' ] = [i .split (':' )[0 ] for i in nchw_outps_with_port ]
426
+ output_names_with_port = [i .split (':' )[0 ] for i in nchw_outps_with_port ]
419
427
g = process_tf_graph (None , opset = self .config .opset ,
420
428
input_names = input_names_without_port ,
421
429
output_names = tfl_outputs ,
@@ -427,6 +435,10 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
427
435
onnx_feed_dict_without_port = {k .split (':' )[0 ]: v for k , v in onnx_feed_dict .items ()}
428
436
onnx_tfl_res = self .run_backend (g , tfl_outputs , onnx_feed_dict_without_port ,
429
437
postfix = "_from_tflite" , use_custom_ops = use_custom_ops )
438
+ if 'outputs_as_nchw' in tfl_process_args :
439
+ for output_name in tfl_process_args ['outputs_as_nchw' ]:
440
+ i = output_names_with_port .index (output_name )
441
+ onnx_tfl_res [i ] = np .transpose (onnx_tfl_res [i ], constants .NCHW_TO_NHWC )
430
442
431
443
self .assert_results_equal (tfl_res , onnx_tfl_res , rtol , atol , mtol , check_value , check_shape , check_dtype )
432
444
self .assert_shapes_correct (g , self .config .allow_missing_shapes , not self .config .skip_onnx_checker )
@@ -456,6 +468,10 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
456
468
g = optimizer .optimize_graph (g )
457
469
onnx_tfjs_res = self .run_backend (g , None , onnx_feed_dict , large_model ,
458
470
postfix = "_from_tfjs" , use_custom_ops = use_custom_ops )
471
+ if 'outputs_as_nchw' in tfl_process_args :
472
+ for output_name in tfl_process_args ['outputs_as_nchw' ]:
473
+ i = output_names_with_port .index (output_name )
474
+ onnx_tfjs_res [i ] = np .transpose (onnx_tfjs_res [i ], constants .NCHW_TO_NHWC )
459
475
460
476
self .assert_results_equal (tfjs_res , onnx_tfjs_res , rtol , atol , mtol , check_value , check_shape ,
461
477
check_dtype = False )
0 commit comments