From 0efd815b78ab45a12db7e4e21e22a9af8e613d0a Mon Sep 17 00:00:00 2001 From: Deyu Huang Date: Fri, 24 Jun 2022 17:02:21 +0800 Subject: [PATCH 01/10] add output_as_nchw Signed-off-by: Deyu Huang --- tests/test_backend.py | 15 +++++++++++++ tf2onnx/convert.py | 4 ++++ tf2onnx/tfonnx.py | 51 +++++++++++++++++++++++++++++++++---------- 3 files changed, 58 insertions(+), 12 deletions(-) diff --git a/tests/test_backend.py b/tests/test_backend.py index 521978870..3876da035 100755 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -725,6 +725,21 @@ def func(x): process_args={"inputs_as_nchw": [_INPUT]}, onnx_feed_dict={_INPUT: x_val_for_onnx}) + def test_conv2d_with_output_transpose(self): + x_shape = [2, 32, 32, 3] + kernel_shape = [3, 3, 3, 3] + x_val = make_xval(x_shape) + x_val_for_onnx = x_val.transpose(NHWC_TO_NCHW) + + def func(x): + kernel = tf.constant(make_xval(kernel_shape), dtype=tf.float32, name='kernel') + conv = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding="SAME") + return tf.identity(conv, name=_TFOUTPUT) + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05, + process_args={"inputs_as_nchw": [_INPUT], + "outputs_as_nchw": [_OUTPUT]}, + onnx_feed_dict={_INPUT: x_val_for_onnx}) + @skip_tflite("TFlite adds ops that obscure pattern") @check_tf_min_version("1.15") def test_conv1d_dilations_rewriter(self): diff --git a/tf2onnx/convert.py b/tf2onnx/convert.py index 0a1069496..4c1a6ef4e 100644 --- a/tf2onnx/convert.py +++ b/tf2onnx/convert.py @@ -86,6 +86,7 @@ def get_args(): # experimental parser.add_argument("--inputs-as-nchw", help="transpose inputs as from nhwc to nchw") + parser.add_argument("--outputs-as-nchw", help="transpose outputs as from nhwc to nchw") args = parser.parse_args() args.shape_override = None @@ -112,6 +113,8 @@ def get_args(): args.rename_inputs = args.rename_inputs.split(",") if args.inputs_as_nchw: args.inputs_as_nchw = args.inputs_as_nchw.split(",") + if args.outputs_as_nchw: + args.outputs_as_nchw = args.outputs_as_nchw.split(",") if args.target: args.target = args.target.split(",") if args.signature_def: @@ -275,6 +278,7 @@ def main(): input_names=inputs, output_names=outputs, inputs_as_nchw=args.inputs_as_nchw, + outputs_as_nchw=args.outputs_as_nchw, large_model=args.large_model, tensors_to_rename=tensors_to_rename, ignore_default=args.ignore_default, diff --git a/tf2onnx/tfonnx.py b/tf2onnx/tfonnx.py index 1a351cfcb..438603a9c 100644 --- a/tf2onnx/tfonnx.py +++ b/tf2onnx/tfonnx.py @@ -329,6 +329,28 @@ def transpose_inputs(ctx, inputs_as_nchw): ops.append(node) ctx.reset_nodes(ops) +def transpose_outputs(ctx, outputs_as_nchw): + """Insert a transpose from NHWC to NCHW on model input on users request.""" + ops = [] + for node in ctx.get_nodes(): + for idx, output_name in enumerate(node.output): + if output_name in outputs_as_nchw: + shape = ctx.get_shape(output_name) + if len(shape) != len(constants.NHWC_TO_NCHW): + logger.warning("transpose_input for %s: shape must be rank 4, ignored" % output_name) + ops.append(node) + continue + # insert transpose + op_name = utils.make_name(node.name) + transpose = ctx.insert_new_node_on_input("Transpose", output_name, name=op_name) + transpose.set_attr("perm", constants.NCHW_TO_NHWC) + ctx.copy_shape(output_name, transpose.output[0]) + ctx.set_shape(output_name, np.array(shape)[constants.NHWC_TO_NCHW]) + ops.append(transpose) + ops.append(node) + continue + ops.append(node) + ctx.reset_nodes(ops) def topological_sort(g, continue_on_error): ops = g.get_nodes() @@ -376,7 +398,7 @@ def run_rewriters(g, funcs, continue_on_error): def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None, opset=None, custom_op_handlers=None, custom_rewriter=None, - extra_opset=None, shape_override=None, inputs_as_nchw=None, + extra_opset=None, shape_override=None, inputs_as_nchw=None, outputs_as_nchw=None, input_names=None, output_names=None, ignore_default=None, use_default=None, is_subgraph=False, const_node_values=None, tensors_to_rename=None, initialized_tables=None, tflite_path=None, dequantize=False, tfjs_path=None): @@ -391,7 +413,8 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No custom_rewriter: list of custom graph rewriters extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow - inputs_as_nchw: transpose inputs in list from nchw to nhwc + inputs_as_nchw: transpose inputs in list from nhwc to nchw + outputs_as_nchw: transpose outputs in list from nhwc to nchw input_names: list of input node names in graph, input name format as node_name:port_id. Optional. output_names: list of output node names in graph, format is node_name:port_id. Optional for tflite. ignore_default: list of node names of PlaceholderWithDefault ops to change into Placeholder ops @@ -421,6 +444,8 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No clear_functions() if inputs_as_nchw is None: inputs_as_nchw = [] + if outputs_as_nchw is None: + outputs_as_nchw = [] is_tflite = False if tflite_path is not None: @@ -435,8 +460,8 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No for g in [main_g] + subgraphs: g.set_config(target, opset, extra_opset) - g = process_graphs(main_g, subgraphs, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, - initialized_tables, tensors_to_rename, is_tflite, dequantize) + g = process_graphs(main_g, subgraphs, custom_op_handlers, inputs_as_nchw, outputs_as_nchw, continue_on_error, + custom_rewriter, initialized_tables, tensors_to_rename, is_tflite, dequantize) return g @@ -476,24 +501,24 @@ def graphs_from_tf(tf_graph, input_names, output_names, shape_override=None, con return main_g, subgraphs -def process_graphs(main_g, subgraphs, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, - initialized_tables, tensors_to_rename, is_tflite=False, dequantize=False): +def process_graphs(main_g, subgraphs, custom_op_handlers, inputs_as_nchw, outputs_as_nchw, continue_on_error, + custom_rewriter, initialized_tables, tensors_to_rename, is_tflite=False, dequantize=False): if tensors_to_rename is not None: main_g.rename_tensors(tensors_to_rename) inputs_as_nchw = [tensors_to_rename.get(t, t) for t in inputs_as_nchw] + outputs_as_nchw = [tensors_to_rename.get(t, t) for t in outputs_as_nchw] for g in subgraphs: - fg = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, - initialized_tables, is_tflite, dequantize) + fg = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, outputs_as_nchw, continue_on_error, + custom_rewriter, initialized_tables, is_tflite, dequantize) set_function(fg.graph_name, fg) - g = process_parsed_graph(main_g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, - initialized_tables, is_tflite, - dequantize) + g = process_parsed_graph(main_g, custom_op_handlers, inputs_as_nchw, outputs_as_nchw, continue_on_error, + custom_rewriter, initialized_tables, is_tflite, dequantize) return g -def process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, +def process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, outputs_as_nchw, continue_on_error, custom_rewriter, initialized_tables, is_tflite=False, dequantize=False): op_cnt, attr_cnt = g.dump_node_statistics(include_attrs=True, include_subgraphs=False) @@ -549,6 +574,8 @@ def compat_handler(ctx, node, **kwargs): if inputs_as_nchw: transpose_inputs(g, inputs_as_nchw) + if outputs_as_nchw: + transpose_outputs(g, outputs_as_nchw) # pre-processing graph rewrites # bi-directional re-writer should be placed after single directional re-writer From 118b7b7c365a124a2b51bd0c003c9de0caaaaa04 Mon Sep 17 00:00:00 2001 From: Deyu Huang Date: Fri, 24 Jun 2022 19:43:00 +0800 Subject: [PATCH 02/10] fix typo Signed-off-by: Deyu Huang --- tf2onnx/convert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tf2onnx/convert.py b/tf2onnx/convert.py index 4c1a6ef4e..eb37a89b8 100644 --- a/tf2onnx/convert.py +++ b/tf2onnx/convert.py @@ -91,7 +91,7 @@ def get_args(): args.shape_override = None if args.input: - # for backward compativility + # for backward compatibility args.graphdef = args.input if args.graphdef or args.checkpoint: if not args.inputs or not args.outputs: From da43704b47195e0943da8befc820d46dd7f4212a Mon Sep 17 00:00:00 2001 From: Deyu Huang Date: Sun, 26 Jun 2022 21:17:53 +0800 Subject: [PATCH 03/10] fix pylint Signed-off-by: Deyu Huang --- tests/test_backend.py | 4 ++-- tf2onnx/tfonnx.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_backend.py b/tests/test_backend.py index 3876da035..5d8973467 100755 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -725,7 +725,7 @@ def func(x): process_args={"inputs_as_nchw": [_INPUT]}, onnx_feed_dict={_INPUT: x_val_for_onnx}) - def test_conv2d_with_output_transpose(self): + def test_conv2d_with_output_transpose(self): x_shape = [2, 32, 32, 3] kernel_shape = [3, 3, 3, 3] x_val = make_xval(x_shape) @@ -4584,7 +4584,7 @@ def test_add2(self): def func(x): x_ = tf.add(x, x) return tf.identity(x_, name=_TFOUTPUT) - self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) + self._run_test_case(func, [_2TPUT], {_INPUT: x_val}) @check_opset_min_version(11, "CumSum") def test_cumsum(self): diff --git a/tf2onnx/tfonnx.py b/tf2onnx/tfonnx.py index 438603a9c..d44324c13 100644 --- a/tf2onnx/tfonnx.py +++ b/tf2onnx/tfonnx.py @@ -330,14 +330,14 @@ def transpose_inputs(ctx, inputs_as_nchw): ctx.reset_nodes(ops) def transpose_outputs(ctx, outputs_as_nchw): - """Insert a transpose from NHWC to NCHW on model input on users request.""" + """Insert a transpose from NHWC to NCHW on model output on users request.""" ops = [] for node in ctx.get_nodes(): for idx, output_name in enumerate(node.output): if output_name in outputs_as_nchw: shape = ctx.get_shape(output_name) if len(shape) != len(constants.NHWC_TO_NCHW): - logger.warning("transpose_input for %s: shape must be rank 4, ignored" % output_name) + logger.warning("transpose_output for %s: shape must be rank 4, ignored" % output_name) ops.append(node) continue # insert transpose From 9fade6c3fe29cb023503285b8714d96c77f7508b Mon Sep 17 00:00:00 2001 From: Deyu Huang Date: Wed, 6 Jul 2022 17:22:46 +0800 Subject: [PATCH 04/10] fix node replace logic Signed-off-by: Deyu Huang --- tf2onnx/tfonnx.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tf2onnx/tfonnx.py b/tf2onnx/tfonnx.py index d44324c13..c2c881e77 100644 --- a/tf2onnx/tfonnx.py +++ b/tf2onnx/tfonnx.py @@ -333,7 +333,7 @@ def transpose_outputs(ctx, outputs_as_nchw): """Insert a transpose from NHWC to NCHW on model output on users request.""" ops = [] for node in ctx.get_nodes(): - for idx, output_name in enumerate(node.output): + for output_name in node.output: if output_name in outputs_as_nchw: shape = ctx.get_shape(output_name) if len(shape) != len(constants.NHWC_TO_NCHW): @@ -342,9 +342,10 @@ def transpose_outputs(ctx, outputs_as_nchw): continue # insert transpose op_name = utils.make_name(node.name) - transpose = ctx.insert_new_node_on_input("Transpose", output_name, name=op_name) - transpose.set_attr("perm", constants.NCHW_TO_NHWC) - ctx.copy_shape(output_name, transpose.output[0]) + transpose = ctx.insert_new_node_on_output("Transpose", node.input[0], name=op_name) + transpose.set_attr("perm", constants.NHWC_TO_NCHW) + ctx.copy_shape(node.output[0], transpose.output[0]) + ctx.set_shape(transpose.output[0], np.array(shape)[constants.NHWC_TO_NCHW]) ctx.set_shape(output_name, np.array(shape)[constants.NHWC_TO_NCHW]) ops.append(transpose) ops.append(node) @@ -503,7 +504,6 @@ def graphs_from_tf(tf_graph, input_names, output_names, shape_override=None, con def process_graphs(main_g, subgraphs, custom_op_handlers, inputs_as_nchw, outputs_as_nchw, continue_on_error, custom_rewriter, initialized_tables, tensors_to_rename, is_tflite=False, dequantize=False): - if tensors_to_rename is not None: main_g.rename_tensors(tensors_to_rename) inputs_as_nchw = [tensors_to_rename.get(t, t) for t in inputs_as_nchw] From 28e6c59f7972a24b3893e123615ede112f85d0f5 Mon Sep 17 00:00:00 2001 From: Deyu Huang Date: Wed, 6 Jul 2022 17:26:59 +0800 Subject: [PATCH 05/10] add tests for outputs as nchw Signed-off-by: Deyu Huang --- tests/backend_test_base.py | 20 ++++++++++++++++++-- tests/test_backend.py | 5 +++-- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/tests/backend_test_base.py b/tests/backend_test_base.py index f39c398b1..dd7bf66ce 100644 --- a/tests/backend_test_base.py +++ b/tests/backend_test_base.py @@ -20,6 +20,7 @@ import onnx from common import get_test_config from tfjs_runner import run_tfjs +from tf2onnx import constants from tf2onnx import utils from tf2onnx.tfonnx import process_tf_graph from tf2onnx import optimizer @@ -366,6 +367,7 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit graph_def_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb") utils.save_protobuf(graph_def_path, graph_def) self.logger.debug("created file %s", graph_def_path) + tfl_process_args = process_args.copy() if test_tfjs: tfjs_path = self.convert_to_tfjs(graph_def_path, output_names_with_port) @@ -395,6 +397,10 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit g = optimizer.optimize_graph(g, catch_errors=False) actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model, use_custom_ops=use_custom_ops) + if 'outputs_as_nchw' in tfl_process_args: + for output_name in tfl_process_args['outputs_as_nchw']: + i = output_names_with_port.index(output_name) + actual[i] = np.transpose(actual[i], constants.NCHW_TO_NHWC) self.assert_results_equal(expected, actual, rtol, atol, mtol, check_value, check_shape, check_dtype) self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker) @@ -410,12 +416,14 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit if run_tfl_consistency_test: self.assert_results_equal(expected, tfl_res, rtol, atol, mtol, check_value, check_shape, check_dtype) - tfl_process_args = process_args.copy() if 'inputs_as_nchw' in tfl_process_args: nchw_inps_with_port = tfl_process_args['inputs_as_nchw'] tfl_process_args['inputs_as_nchw'] = [i.split(':')[0] for i in nchw_inps_with_port] input_names_without_port = [inp.split(':')[0] for inp in feed_dict.keys()] - + if 'outputs_as_nchw' in tfl_process_args: + nchw_outps_with_port = tfl_process_args['outputs_as_nchw'] + tfl_process_args['outputs_as_nchw'] = [i.split(':')[0] for i in nchw_outps_with_port] + output_names_with_port = [i.split(':')[0] for i in nchw_outps_with_port] g = process_tf_graph(None, opset=self.config.opset, input_names=input_names_without_port, output_names=tfl_outputs, @@ -427,6 +435,10 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit onnx_feed_dict_without_port = {k.split(':')[0]: v for k, v in onnx_feed_dict.items()} onnx_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port, postfix="_from_tflite", use_custom_ops=use_custom_ops) + if 'outputs_as_nchw' in tfl_process_args: + for output_name in tfl_process_args['outputs_as_nchw']: + i = output_names_with_port.index(output_name) + onnx_tfl_res[i] = np.transpose(onnx_tfl_res[i], constants.NCHW_TO_NHWC) self.assert_results_equal(tfl_res, onnx_tfl_res, rtol, atol, mtol, check_value, check_shape, check_dtype) self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker) @@ -456,6 +468,10 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit g = optimizer.optimize_graph(g) onnx_tfjs_res = self.run_backend(g, None, onnx_feed_dict, large_model, postfix="_from_tfjs", use_custom_ops=use_custom_ops) + if 'outputs_as_nchw' in tfl_process_args: + for output_name in tfl_process_args['outputs_as_nchw']: + i = output_names_with_port.index(output_name) + onnx_tfjs_res[i] = np.transpose(onnx_tfjs_res[i], constants.NCHW_TO_NHWC) self.assert_results_equal(tfjs_res, onnx_tfjs_res, rtol, atol, mtol, check_value, check_shape, check_dtype=False) diff --git a/tests/test_backend.py b/tests/test_backend.py index 5d8973467..73f02ed38 100755 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -738,7 +738,8 @@ def func(x): self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05, process_args={"inputs_as_nchw": [_INPUT], "outputs_as_nchw": [_OUTPUT]}, - onnx_feed_dict={_INPUT: x_val_for_onnx}) + onnx_feed_dict={_INPUT: x_val_for_onnx}, + graph_validator=lambda g: check_op_count(g, "Transpose", 0, disabled=False)) @skip_tflite("TFlite adds ops that obscure pattern") @check_tf_min_version("1.15") @@ -4584,7 +4585,7 @@ def test_add2(self): def func(x): x_ = tf.add(x, x) return tf.identity(x_, name=_TFOUTPUT) - self._run_test_case(func, [_2TPUT], {_INPUT: x_val}) + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}) @check_opset_min_version(11, "CumSum") def test_cumsum(self): From e1eed5913075580d4f3ad00a41dfd74adac9cef9 Mon Sep 17 00:00:00 2001 From: Deyu Huang Date: Wed, 6 Jul 2022 17:51:16 +0800 Subject: [PATCH 06/10] add it into function and doc Signed-off-by: Deyu Huang --- README.md | 20 ++++++++++++-------- tf2onnx/convert.py | 38 ++++++++++++++++++++++++-------------- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index ad640e21e..aa6aad20b 100644 --- a/README.md +++ b/README.md @@ -292,8 +292,8 @@ import tf2onnx model_proto, external_tensor_storage = tf2onnx.convert.from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_op_handlers=None, custom_rewriter=None, - inputs_as_nchw=None, extra_opset=None shape_override=None, - target=None, large_model=False, output_path=None) + inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, + shape_override=None, target=None, large_model=False, output_path=None) Args: model: the tf.keras model we want to convert @@ -307,7 +307,8 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_keras(model, custom_rewriter: list of custom graph rewriters extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow - inputs_as_nchw: transpose inputs in list from nchw to nhwc + inputs_as_nchw: transpose inputs in list from nhwc to nchw + outputs_as_nchw: transpose inputs in list from nhwc to nchw large_model: use the ONNX external tensor storage format output_path: save model to output_path @@ -323,8 +324,8 @@ import tf2onnx model_proto, external_tensor_storage = tf2onnx.convert.from_function(function, input_signature=None, opset=None, custom_ops=None, - custom_op_handlers=None, custom_rewriter=None, - inputs_as_nchw=None, extra_opset=None, shape_override=None, + custom_op_handlers=None, custom_rewriter=None, inputs_as_nchw=None, + outputs_as_nchw=None, extra_opset=None, shape_override=None, target=None, large_model=False, output_path=None) Args: @@ -340,6 +341,7 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_function(function, extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow inputs_as_nchw: transpose inputs in list from nchw to nhwc + outputs_as_nchw: transpose inputs in list from nhwc to nchw large_model: use the ONNX external tensor storage format output_path: save model to output_path @@ -354,7 +356,7 @@ import tf2onnx model_proto, external_tensor_storage = tf2onnx.convert.from_graph_def(graph_def, name=None, input_names=None, output_names=None, opset=None, custom_ops=None, custom_op_handlers=None, custom_rewriter=None, - inputs_as_nchw=None, extra_opset=None, + inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, shape_override=None, target=None, large_model=False, output_path=None) @@ -370,6 +372,7 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_graph_def(graph_def, extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow inputs_as_nchw: transpose inputs in list from nchw to nhwc + outputs_as_nchw: transpose inputs in list from nhwc to nchw large_model: use the ONNX external tensor storage format output_path: save model to output_path @@ -383,8 +386,8 @@ import tf2onnx model_proto, external_tensor_storage = tf2onnx.convert.from_tflite(tflite_path, input_names=None, output_names=None, opset=None, custom_ops=None, custom_op_handlers=None, - custom_rewriter=None, inputs_as_nchw=None, extra_opset=None, shape_override=None, target=None, - large_model=False, output_path=None): + custom_rewriter=None, inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, + shape_override=None, target=None, large_model=False, output_path=None): Args: tflite_path: the tflite model file full path @@ -397,6 +400,7 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_tflite(tflite_path, custom_op_handlers: dictionary of custom ops handlers custom_rewriter: list of custom graph rewriters inputs_as_nchw: transpose inputs in list from nchw to nhwc + outputs_as_nchw: transpose inputs in list from nhwc to nchw extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow target: list of workarounds applied to help certain platforms diff --git a/tf2onnx/convert.py b/tf2onnx/convert.py index eb37a89b8..32c28f0bc 100644 --- a/tf2onnx/convert.py +++ b/tf2onnx/convert.py @@ -360,8 +360,8 @@ def _is_legacy_keras_model(model): def _from_keras_tf1(model, opset=None, custom_ops=None, custom_op_handlers=None, custom_rewriter=None, - inputs_as_nchw=None, extra_opset=None, shape_override=None, target=None, - large_model=False, output_path=None): + inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, shape_override=None, + target=None, large_model=False, output_path=None): """from_keras for tf 1.15""" input_names = [t.name for t in model.inputs] output_names = [t.name for t in model.outputs] @@ -396,6 +396,7 @@ def _from_keras_tf1(model, opset=None, custom_ops=None, custom_op_handlers=None, input_names=input_names, output_names=output_names, inputs_as_nchw=inputs_as_nchw, + outputs_as_nchw=outputs_as_nchw, large_model=large_model, tensors_to_rename=tensors_to_rename, initialized_tables=initialized_tables, @@ -405,7 +406,7 @@ def _from_keras_tf1(model, opset=None, custom_ops=None, custom_op_handlers=None, def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_op_handlers=None, - custom_rewriter=None, inputs_as_nchw=None, extra_opset=None, shape_override=None, + custom_rewriter=None, inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, shape_override=None, target=None, large_model=False, output_path=None, optimizers=None): """Returns a ONNX model_proto for a tf.keras model. @@ -421,7 +422,8 @@ def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_ custom_rewriter: list of custom graph rewriters extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow - inputs_as_nchw: transpose inputs in list from nchw to nhwc + inputs_as_nchw: transpose inputs in list from nhwc to nchw + outputs_as_nchw: transpose outputs in list from nhwc to nchw large_model: use the ONNX external tensor storage format output_path: save model to output_path optimizers: list (subset) of tf2onnx optimizers if applying all optimizers is not desired. @@ -431,7 +433,7 @@ def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_ """ if LooseVersion(tf.__version__) < "2.0": return _from_keras_tf1(model, opset, custom_ops, custom_op_handlers, custom_rewriter, inputs_as_nchw, - extra_opset, shape_override, target, large_model, output_path) + outputs_as_nchw, extra_opset, shape_override, target, large_model, output_path) old_out_names = _rename_duplicate_keras_model_names(model) from tensorflow.python.keras.saving import saving_utils as _saving_utils # pylint: disable=import-outside-toplevel @@ -504,6 +506,7 @@ def wrap_call(*args, training=False, **kwargs): input_names=input_names, output_names=output_names, inputs_as_nchw=inputs_as_nchw, + outputs_as_nchw=outputs_as_nchw, large_model=large_model, tensors_to_rename=tensors_to_rename, initialized_tables=initialized_tables, @@ -513,8 +516,8 @@ def wrap_call(*args, training=False, **kwargs): def from_function(function, input_signature=None, opset=None, custom_ops=None, custom_op_handlers=None, - custom_rewriter=None, inputs_as_nchw=None, extra_opset=None, shape_override=None, target=None, - large_model=False, output_path=None): + custom_rewriter=None, inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, + shape_override=None, target=None, large_model=False, output_path=None): """Returns a ONNX model_proto for a tf.function. Args: @@ -529,7 +532,8 @@ def from_function(function, input_signature=None, opset=None, custom_ops=None, c custom_rewriter: list of custom graph rewriters extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow - inputs_as_nchw: transpose inputs in list from nchw to nhwc + inputs_as_nchw: transpose inputs in list from nhwc to nchw + outputs_as_nchw: transpose outputs in list from nhwc to nchw large_model: use the ONNX external tensor storage format output_path: save model to output_path @@ -568,6 +572,7 @@ def from_function(function, input_signature=None, opset=None, custom_ops=None, c input_names=input_names, output_names=output_names, inputs_as_nchw=inputs_as_nchw, + outputs_as_nchw=outputs_as_nchw, large_model=large_model, tensors_to_rename=tensors_to_rename, initialized_tables=initialized_tables, @@ -577,8 +582,9 @@ def from_function(function, input_signature=None, opset=None, custom_ops=None, c def from_graph_def(graph_def, name=None, input_names=None, output_names=None, opset=None, custom_ops=None, - custom_op_handlers=None, custom_rewriter=None, inputs_as_nchw=None, extra_opset=None, - shape_override=None, target=None, large_model=False, tensors_to_rename=None, output_path=None): + custom_op_handlers=None, custom_rewriter=None, inputs_as_nchw=None, outputs_as_nchw=None, + extra_opset=None, shape_override=None, target=None, large_model=False, + tensors_to_rename=None, output_path=None): """Returns a ONNX model_proto for a tensorflow graphdef. Args: @@ -595,7 +601,8 @@ def from_graph_def(graph_def, name=None, input_names=None, output_names=None, op custom_rewriter: list of custom graph rewriters extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow - inputs_as_nchw: transpose inputs in list from nchw to nhwc + inputs_as_nchw: transpose inputs in list from nhwc to nchw + outputs_as_nchw: transpose outputs in list from nhwc to nchw large_model: use the ONNX external tensor storage format output_path: save model to output_path @@ -632,6 +639,7 @@ def from_graph_def(graph_def, name=None, input_names=None, output_names=None, op input_names=input_names, output_names=output_names, inputs_as_nchw=inputs_as_nchw, + outputs_as_nchw=outputs_as_nchw, large_model=large_model, tensors_to_rename=tensors_to_rename, initialized_tables=initialized_tables, @@ -641,8 +649,8 @@ def from_graph_def(graph_def, name=None, input_names=None, output_names=None, op def from_tflite(tflite_path, input_names=None, output_names=None, opset=None, custom_ops=None, custom_op_handlers=None, - custom_rewriter=None, inputs_as_nchw=None, extra_opset=None, shape_override=None, target=None, - large_model=False, output_path=None): + custom_rewriter=None, inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, shape_override=None, + target=None, large_model=False, output_path=None): """Returns a ONNX model_proto for a tflite model file. Args: @@ -655,7 +663,8 @@ def from_tflite(tflite_path, input_names=None, output_names=None, opset=None, cu runtime can still open the model. Type is a dictionary `{op name: domain}`. custom_op_handlers: dictionary of custom ops handlers custom_rewriter: list of custom graph rewriters - inputs_as_nchw: transpose inputs in list from nchw to nhwc + inputs_as_nchw: transpose inputs in list from nhwc to nchw + outputs_as_nchw: transpose outputs in list from nhwc to nchw extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow target: list of workarounds applied to help certain platforms @@ -684,6 +693,7 @@ def from_tflite(tflite_path, input_names=None, output_names=None, opset=None, cu input_names=input_names, output_names=output_names, inputs_as_nchw=inputs_as_nchw, + outputs_as_nchw=outputs_as_nchw, large_model=large_model, tensors_to_rename=None, initialized_tables=None, From b20234ea5e6c658b6f1d25afdcd473aa66671db8 Mon Sep 17 00:00:00 2001 From: Deyu Huang Date: Wed, 6 Jul 2022 20:39:09 +0800 Subject: [PATCH 07/10] fix output_names_with_port range Signed-off-by: Deyu Huang --- tests/backend_test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/backend_test_base.py b/tests/backend_test_base.py index dd7bf66ce..38cc52dcf 100644 --- a/tests/backend_test_base.py +++ b/tests/backend_test_base.py @@ -423,7 +423,7 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit if 'outputs_as_nchw' in tfl_process_args: nchw_outps_with_port = tfl_process_args['outputs_as_nchw'] tfl_process_args['outputs_as_nchw'] = [i.split(':')[0] for i in nchw_outps_with_port] - output_names_with_port = [i.split(':')[0] for i in nchw_outps_with_port] + output_names_with_port = [i.split(':')[0] for i in nchw_outps_with_port] g = process_tf_graph(None, opset=self.config.opset, input_names=input_names_without_port, output_names=tfl_outputs, From 464fffa8f02db0098e488c208666cf0dff581d83 Mon Sep 17 00:00:00 2001 From: Deyu Huang Date: Thu, 7 Jul 2022 11:26:21 +0800 Subject: [PATCH 08/10] fix the input_as_nchw description Signed-off-by: Deyu Huang --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index aa6aad20b..c8b92e4d5 100644 --- a/README.md +++ b/README.md @@ -340,7 +340,7 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_function(function, custom_rewriter: list of custom graph rewriters extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow - inputs_as_nchw: transpose inputs in list from nchw to nhwc + inputs_as_nchw: transpose inputs in list from nhwc to nchw outputs_as_nchw: transpose inputs in list from nhwc to nchw large_model: use the ONNX external tensor storage format output_path: save model to output_path @@ -371,7 +371,7 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_graph_def(graph_def, custom_rewriter: list of custom graph rewriters extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow - inputs_as_nchw: transpose inputs in list from nchw to nhwc + inputs_as_nchw: transpose inputs in list from nhwc to nchw outputs_as_nchw: transpose inputs in list from nhwc to nchw large_model: use the ONNX external tensor storage format output_path: save model to output_path @@ -399,7 +399,7 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_tflite(tflite_path, runtime can still open the model. Type is a dictionary `{op name: domain}`. custom_op_handlers: dictionary of custom ops handlers custom_rewriter: list of custom graph rewriters - inputs_as_nchw: transpose inputs in list from nchw to nhwc + inputs_as_nchw: transpose inputs in list from nhwc to nchw outputs_as_nchw: transpose inputs in list from nhwc to nchw extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow From bc21b0378ff6b4625f3b164bb2156cddd4be5f1d Mon Sep 17 00:00:00 2001 From: Deyu Huang Date: Thu, 7 Jul 2022 14:21:35 +0800 Subject: [PATCH 09/10] fix typo Signed-off-by: Deyu Huang --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index c8b92e4d5..bd4753e6b 100644 --- a/README.md +++ b/README.md @@ -308,7 +308,7 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_keras(model, extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow inputs_as_nchw: transpose inputs in list from nhwc to nchw - outputs_as_nchw: transpose inputs in list from nhwc to nchw + outputs_as_nchw: transpose outputs in list from nhwc to nchw large_model: use the ONNX external tensor storage format output_path: save model to output_path @@ -341,7 +341,7 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_function(function, extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow inputs_as_nchw: transpose inputs in list from nhwc to nchw - outputs_as_nchw: transpose inputs in list from nhwc to nchw + outputs_as_nchw: transpose outputs in list from nhwc to nchw large_model: use the ONNX external tensor storage format output_path: save model to output_path @@ -372,7 +372,7 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_graph_def(graph_def, extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow inputs_as_nchw: transpose inputs in list from nhwc to nchw - outputs_as_nchw: transpose inputs in list from nhwc to nchw + outputs_as_nchw: transpose outputs in list from nhwc to nchw large_model: use the ONNX external tensor storage format output_path: save model to output_path @@ -400,7 +400,7 @@ model_proto, external_tensor_storage = tf2onnx.convert.from_tflite(tflite_path, custom_op_handlers: dictionary of custom ops handlers custom_rewriter: list of custom graph rewriters inputs_as_nchw: transpose inputs in list from nhwc to nchw - outputs_as_nchw: transpose inputs in list from nhwc to nchw + outputs_as_nchw: transpose outputs in list from nhwc to nchw extra_opset: list of extra opset's, for example the opset's used by custom ops shape_override: dict with inputs that override the shapes given by tensorflow target: list of workarounds applied to help certain platforms From 50cf1701edd699d8b4d75d155ef478bc095e30e4 Mon Sep 17 00:00:00 2001 From: Deyu Huang Date: Fri, 8 Jul 2022 14:45:53 +0800 Subject: [PATCH 10/10] change tests name Signed-off-by: Deyu Huang --- tests/test_backend.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/test_backend.py b/tests/test_backend.py index 73f02ed38..fe50e0591 100755 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -712,7 +712,7 @@ def func(x): graph_validator=lambda g: (check_op_count(g, "RandomUniform", 0) and check_op_count(g, "RandomUniformLike", 0))) - def test_conv2d_with_input_transpose(self): + def test_inputs_as_nchw_arg(self): x_shape = [2, 32, 32, 3] kernel_shape = [3, 3, 3, 3] x_val = make_xval(x_shape) @@ -725,21 +725,16 @@ def func(x): process_args={"inputs_as_nchw": [_INPUT]}, onnx_feed_dict={_INPUT: x_val_for_onnx}) - def test_conv2d_with_output_transpose(self): + def test_outputs_as_nchw_arg(self): x_shape = [2, 32, 32, 3] kernel_shape = [3, 3, 3, 3] x_val = make_xval(x_shape) - x_val_for_onnx = x_val.transpose(NHWC_TO_NCHW) - def func(x): kernel = tf.constant(make_xval(kernel_shape), dtype=tf.float32, name='kernel') conv = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding="SAME") return tf.identity(conv, name=_TFOUTPUT) self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05, - process_args={"inputs_as_nchw": [_INPUT], - "outputs_as_nchw": [_OUTPUT]}, - onnx_feed_dict={_INPUT: x_val_for_onnx}, - graph_validator=lambda g: check_op_count(g, "Transpose", 0, disabled=False)) + process_args={"outputs_as_nchw": [_OUTPUT]}) @skip_tflite("TFlite adds ops that obscure pattern") @check_tf_min_version("1.15")