Skip to content

Commit fb8055b

Browse files
committed
add tests for outputs as nchw
Signed-off-by: Deyu Huang <[email protected]>
1 parent c7b11ad commit fb8055b

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

tests/backend_test_base.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import onnx
2121
from common import get_test_config
2222
from tfjs_runner import run_tfjs
23+
from tf2onnx import constants
2324
from tf2onnx import utils
2425
from tf2onnx.tfonnx import process_tf_graph
2526
from tf2onnx import optimizer
@@ -366,6 +367,7 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
366367
graph_def_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb")
367368
utils.save_protobuf(graph_def_path, graph_def)
368369
self.logger.debug("created file %s", graph_def_path)
370+
tfl_process_args = process_args.copy()
369371

370372
if test_tfjs:
371373
tfjs_path = self.convert_to_tfjs(graph_def_path, output_names_with_port)
@@ -395,6 +397,10 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
395397
g = optimizer.optimize_graph(g, catch_errors=False)
396398
actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model,
397399
use_custom_ops=use_custom_ops)
400+
if 'outputs_as_nchw' in tfl_process_args:
401+
for output_name in tfl_process_args['outputs_as_nchw']:
402+
i = output_names_with_port.index(output_name)
403+
actual[i] = np.transpose(actual[i], constants.NCHW_TO_NHWC)
398404

399405
self.assert_results_equal(expected, actual, rtol, atol, mtol, check_value, check_shape, check_dtype)
400406
self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)
@@ -410,12 +416,14 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
410416
if run_tfl_consistency_test:
411417
self.assert_results_equal(expected, tfl_res, rtol, atol, mtol, check_value, check_shape, check_dtype)
412418

413-
tfl_process_args = process_args.copy()
414419
if 'inputs_as_nchw' in tfl_process_args:
415420
nchw_inps_with_port = tfl_process_args['inputs_as_nchw']
416421
tfl_process_args['inputs_as_nchw'] = [i.split(':')[0] for i in nchw_inps_with_port]
417422
input_names_without_port = [inp.split(':')[0] for inp in feed_dict.keys()]
418-
423+
if 'outputs_as_nchw' in tfl_process_args:
424+
nchw_outps_with_port = tfl_process_args['outputs_as_nchw']
425+
tfl_process_args['outputs_as_nchw'] = [i.split(':')[0] for i in nchw_outps_with_port]
426+
output_names_with_port = [i.split(':')[0] for i in nchw_outps_with_port]
419427
g = process_tf_graph(None, opset=self.config.opset,
420428
input_names=input_names_without_port,
421429
output_names=tfl_outputs,
@@ -427,6 +435,10 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
427435
onnx_feed_dict_without_port = {k.split(':')[0]: v for k, v in onnx_feed_dict.items()}
428436
onnx_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port,
429437
postfix="_from_tflite", use_custom_ops=use_custom_ops)
438+
if 'outputs_as_nchw' in tfl_process_args:
439+
for output_name in tfl_process_args['outputs_as_nchw']:
440+
i = output_names_with_port.index(output_name)
441+
onnx_tfl_res[i] = np.transpose(onnx_tfl_res[i], constants.NCHW_TO_NHWC)
430442

431443
self.assert_results_equal(tfl_res, onnx_tfl_res, rtol, atol, mtol, check_value, check_shape, check_dtype)
432444
self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)
@@ -456,6 +468,10 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
456468
g = optimizer.optimize_graph(g)
457469
onnx_tfjs_res = self.run_backend(g, None, onnx_feed_dict, large_model,
458470
postfix="_from_tfjs", use_custom_ops=use_custom_ops)
471+
if 'outputs_as_nchw' in tfl_process_args:
472+
for output_name in tfl_process_args['outputs_as_nchw']:
473+
i = output_names_with_port.index(output_name)
474+
onnx_tfjs_res[i] = np.transpose(onnx_tfjs_res[i], constants.NCHW_TO_NHWC)
459475

460476
self.assert_results_equal(tfjs_res, onnx_tfjs_res, rtol, atol, mtol, check_value, check_shape,
461477
check_dtype=False)

tests/test_backend.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -738,7 +738,8 @@ def func(x):
738738
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05,
739739
process_args={"inputs_as_nchw": [_INPUT],
740740
"outputs_as_nchw": [_OUTPUT]},
741-
onnx_feed_dict={_INPUT: x_val_for_onnx})
741+
onnx_feed_dict={_INPUT: x_val_for_onnx},
742+
graph_validator=lambda g: check_op_count(g, "Transpose", 0, disabled=False))
742743

743744
@skip_tflite("TFlite adds ops that obscure pattern")
744745
@check_tf_min_version("1.15")
@@ -4584,7 +4585,7 @@ def test_add2(self):
45844585
def func(x):
45854586
x_ = tf.add(x, x)
45864587
return tf.identity(x_, name=_TFOUTPUT)
4587-
self._run_test_case(func, [_2TPUT], {_INPUT: x_val})
4588+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
45884589

45894590
@check_opset_min_version(11, "CumSum")
45904591
def test_cumsum(self):

0 commit comments

Comments
 (0)