From dd45f8ea783e5065e9e7351c0caf82b6d9460601 Mon Sep 17 00:00:00 2001 From: Tom Wildenhain Date: Mon, 26 Apr 2021 21:25:55 -0400 Subject: [PATCH 1/2] Fix bug that renamed subgraph i/o twice Signed-off-by: Tom Wildenhain --- tf2onnx/tfonnx.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tf2onnx/tfonnx.py b/tf2onnx/tfonnx.py index 7ec35e6fd..d4a3a442d 100644 --- a/tf2onnx/tfonnx.py +++ b/tf2onnx/tfonnx.py @@ -511,8 +511,6 @@ def rename_tensors_in_nodes(onnx_nodes): for func in ordered_func: f_inputs_names = [t.name for t in func.inputs] f_output_names = [t.name for t in func.outputs] - f_inputs_names = rename_tensors_in_list(f_inputs_names) - f_output_names = rename_tensors_in_list(f_output_names) fg = process_tf_graph(func, continue_on_error, False, target, opset, custom_op_handlers, custom_rewriter, extra_opset, shape_override, inputs_as_nchw, From e350e7efac259bea0b0a3370e02641d735a8209c Mon Sep 17 00:00:00 2001 From: Tom Wildenhain Date: Tue, 27 Apr 2021 15:49:38 -0400 Subject: [PATCH 2/2] Don't rename tensors in subgraphs at all Signed-off-by: Tom Wildenhain --- tf2onnx/tfonnx.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tf2onnx/tfonnx.py b/tf2onnx/tfonnx.py index d4a3a442d..ccea803db 100644 --- a/tf2onnx/tfonnx.py +++ b/tf2onnx/tfonnx.py @@ -522,12 +522,13 @@ def rename_tensors_in_nodes(onnx_nodes): check_io(input_names, output_names, output_shapes) - rename_tensors_in_nodes(onnx_nodes) - input_names = rename_tensors_in_list(input_names) - output_names = rename_tensors_in_list(output_names) - output_shapes = rename_tensors_in_dict(output_shapes) - dtypes = rename_tensors_in_dict(dtypes) - inputs_as_nchw = rename_tensors_in_list(inputs_as_nchw) + if not is_subgraph: + rename_tensors_in_nodes(onnx_nodes) + input_names = rename_tensors_in_list(input_names) + output_names = rename_tensors_in_list(output_names) + output_shapes = rename_tensors_in_dict(output_shapes) + dtypes = rename_tensors_in_dict(dtypes) + inputs_as_nchw = rename_tensors_in_list(inputs_as_nchw) g = Graph(onnx_nodes, output_shapes, dtypes, target, opset, extra_opset, input_names, output_names, is_subgraph) g = process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_error, custom_rewriter, target, output_names, initialized_tables, outputs_to_values, outputs_to_dtypes, op_cnt, attr_cnt)