From e9c8ac006e55787152b9cbb53195aaa26e6bf513 Mon Sep 17 00:00:00 2001 From: Tom Wildenhain Date: Tue, 25 May 2021 17:54:51 -0400 Subject: [PATCH 1/5] Remove captured inputs when using CLI Signed-off-by: Tom Wildenhain --- tf2onnx/tf_loader.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tf2onnx/tf_loader.py b/tf2onnx/tf_loader.py index 9f8778885..393058852 100644 --- a/tf2onnx/tf_loader.py +++ b/tf2onnx/tf_loader.py @@ -498,6 +498,9 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d structured_inputs = set(inp + ":0" for inp in structured_inputs) if any(inp in structured_inputs for inp in inputs): inputs = [inp for inp in inputs if inp in structured_inputs] + graph_captures = concrete_func.graph._captures # pylint: disable=protected-access + captured_inputs = [t_name.name for _, t_name in graph_captures.values()] + inputs = [inp for inp in inputs if inp in captured_inputs] else: inputs = input_names From c442ea6577d4f3ae87921620cc726f58b74bfc5a Mon Sep 17 00:00:00 2001 From: Tom Wildenhain Date: Tue, 25 May 2021 18:16:33 -0400 Subject: [PATCH 2/5] Switch to use structured input names by default Signed-off-by: Tom Wildenhain --- tf2onnx/convert.py | 4 +++- tf2onnx/tf_loader.py | 39 ++++++++++++++++++++------------------- 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/tf2onnx/convert.py b/tf2onnx/convert.py index 2afacd94b..df50c8282 100644 --- a/tf2onnx/convert.py +++ b/tf2onnx/convert.py @@ -65,6 +65,7 @@ def get_args(): "change into Identity ops using their default value") parser.add_argument("--rename-inputs", help="input names to use in final model (optional)") parser.add_argument("--rename-outputs", help="output names to use in final model (optional)") + parser.add_argument("--use-graph-names", help="skip renaming io using concrete names if loading a saved model") parser.add_argument("--opset", type=int, default=None, help="opset version to use for onnx domain") parser.add_argument("--dequantize", help="Remove quantization from model. Only supported for tflite currently.", action="store_true") @@ -212,7 +213,8 @@ def main(): if args.saved_model: graph_def, inputs, outputs, initialized_tables, tensors_to_rename = tf_loader.from_saved_model( args.saved_model, args.inputs, args.outputs, args.tag, args.signature_def, args.concrete_function, - args.large_model, return_initialized_tables=True, return_tensors_to_rename=True) + args.large_model, return_initialized_tables=True, return_tensors_to_rename=True, + use_graph_names=args.use_graph_names) model_path = args.saved_model if args.keras: graph_def, inputs, outputs = tf_loader.from_keras( diff --git a/tf2onnx/tf_loader.py b/tf2onnx/tf_loader.py index 393058852..fa8747d12 100644 --- a/tf2onnx/tf_loader.py +++ b/tf2onnx/tf_loader.py @@ -310,7 +310,7 @@ def from_checkpoint(model_path, input_names, output_names): return frozen_graph, input_names, output_names -def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signature_names): +def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signature_names, use_graph_names): """Load tensorflow graph from saved_model.""" wrn_no_tag = "'--tag' not specified for saved_model. Using --tag serve" @@ -345,14 +345,16 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signa # TF1.12 changed the api get_signature_def = lambda meta_graph_def, k: meta_graph_def.signature_def[k] + tensors_to_rename = {} if input_names is None: input_names = [] for k in signatures: inputs_tensor_info = get_signature_def(imported, k).inputs - for _, input_tensor in inputs_tensor_info.items(): + for structured_name, input_tensor in inputs_tensor_info.items(): if input_tensor.name not in input_names: input_names.append(input_tensor.name) - tensors_to_rename = {} + if not use_graph_names: + tensors_to_rename[input_tensor.name] = structured_name if output_names is None: output_names = [] for k in signatures: @@ -360,7 +362,8 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signa for structured_name, output_tensor in outputs_tensor_info.items(): if output_tensor.name not in output_names: output_names.append(output_tensor.name) - tensors_to_rename[output_tensor.name] = structured_name + if not use_graph_names: + tensors_to_rename[output_tensor.name] = structured_name frozen_graph, initialized_tables = \ freeze_session(sess, input_names=input_names, output_names=output_names, get_tables=True) return frozen_graph, input_names, output_names, initialized_tables, tensors_to_rename @@ -447,7 +450,7 @@ def _restore_captured_resources(concrete_func, graph_captures_copy, func_capture def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_def, - concrete_function_index, large_model): + concrete_function_index, large_model, use_graph_names): """Load tensorflow graph from saved_model.""" wrn_no_tag = "'--tag' not specified for saved_model. Using --tag serve" @@ -492,21 +495,19 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d tensors_to_rename = {} if input_names is None: inputs = [tensor.name for tensor in concrete_func.inputs if tensor.dtype != tf.dtypes.resource] - if concrete_func.structured_input_signature is not None: - args, kwargs = concrete_func.structured_input_signature - structured_inputs = [t.name for t in args if isinstance(t, tf.TensorSpec)] + sorted(kwargs.keys()) - structured_inputs = set(inp + ":0" for inp in structured_inputs) - if any(inp in structured_inputs for inp in inputs): - inputs = [inp for inp in inputs if inp in structured_inputs] - graph_captures = concrete_func.graph._captures # pylint: disable=protected-access - captured_inputs = [t_name.name for _, t_name in graph_captures.values()] - inputs = [inp for inp in inputs if inp in captured_inputs] + graph_captures = concrete_func.graph._captures # pylint: disable=protected-access + captured_inputs = [t_name.name for _, t_name in graph_captures.values()] + inputs = [inp for inp in inputs if inp in captured_inputs] + if concrete_func.structured_input_signature is not None and not use_graph_names: + flat_structured_inp = tf.nest.flatten(concrete_func.structured_input_signature) + structured_inputs = [t.name for t in flat_structured_inp if isinstance(t, tf.TensorSpec)] + tensors_to_rename.update(zip(input_names, structured_inputs)) else: inputs = input_names if output_names is None: outputs = [tensor.name for tensor in concrete_func.outputs if tensor.dtype != tf.dtypes.resource] - if isinstance(concrete_func.structured_outputs, dict): + if isinstance(concrete_func.structured_outputs, dict) and not use_graph_names: # outputs are sorted, sort structured_outputs the same way structured_outputs = sorted(concrete_func.structured_outputs.keys()) tensors_to_rename.update(zip(outputs, structured_outputs)) @@ -515,7 +516,6 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d logger.info("Output names: %r", outputs) else: outputs = output_names - logger.info("Outputs not left as None; will use provided names not structured output names.") frozen_graph, initialized_tables = from_trackable(imported, concrete_func, inputs, outputs, large_model) @@ -524,7 +524,8 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d def from_saved_model(model_path, input_names, output_names, tag=None, signatures=None, concrete_function=None, large_model=False, - return_concrete_func=False, return_initialized_tables=False, return_tensors_to_rename=False): + return_concrete_func=False, return_initialized_tables=False, + return_tensors_to_rename=False, use_graph_names=True): """Load tensorflow graph from saved_model.""" if signatures is None: signatures = [] @@ -533,7 +534,7 @@ def from_saved_model(model_path, input_names, output_names, tag=None, if is_tf2(): frozen_graph, input_names, output_names, concrete_func, imported, initialized_tables, tensors_to_rename = \ _from_saved_model_v2(model_path, input_names, output_names, - tag, signatures, concrete_function, large_model) + tag, signatures, concrete_function, large_model, use_graph_names) result = [frozen_graph, input_names, output_names] if return_concrete_func: result += [concrete_func, imported] @@ -544,7 +545,7 @@ def from_saved_model(model_path, input_names, output_names, tag=None, else: with tf_session() as sess: frozen_graph, input_names, output_names, initialized_tables, tensors_to_rename = \ - _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signatures) + _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signatures, use_graph_names) result = [frozen_graph, input_names, output_names] if return_initialized_tables: result += [initialized_tables] From cd5e37662780e383940d14a6487e80e08fc3e8ff Mon Sep 17 00:00:00 2001 From: Tom Wildenhain Date: Tue, 25 May 2021 19:13:09 -0400 Subject: [PATCH 3/5] Bugfixes Signed-off-by: Tom Wildenhain --- tests/run_pretrained_models.py | 6 ++++-- tf2onnx/convert.py | 3 ++- tf2onnx/tf_loader.py | 6 +++--- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/run_pretrained_models.py b/tests/run_pretrained_models.py index 7d05afb04..068c04d5d 100644 --- a/tests/run_pretrained_models.py +++ b/tests/run_pretrained_models.py @@ -438,7 +438,7 @@ def run_tflite(): inputs = {} for k in input_names: v = self.input_names[k] - inputs[k.split(":")[0]] = tf.constant(self.make_input(v)) + inputs[to_rename[k]] = tf.constant(self.make_input(v)) tf_func = tf.function(concrete_func) logger.info("Running TF") tf_results_d = tf_func(**inputs) @@ -557,7 +557,9 @@ def run_tflite(): struc_outputs = self.output_names else: struc_outputs = [to_rename.get(k, k) for k in self.output_names] - onnx_results = self.run_onnxruntime(name, model_proto, inputs, struc_outputs, external_tensor_storage) + struc_inputs = {to_rename.get(k, k): v for k, v in inputs.items()} + onnx_results = self.run_onnxruntime( + name, model_proto, struc_inputs, struc_outputs, external_tensor_storage) else: raise ValueError("unknown backend") logger.info("Run_ONNX OK") diff --git a/tf2onnx/convert.py b/tf2onnx/convert.py index df50c8282..2401dc3f6 100644 --- a/tf2onnx/convert.py +++ b/tf2onnx/convert.py @@ -65,7 +65,8 @@ def get_args(): "change into Identity ops using their default value") parser.add_argument("--rename-inputs", help="input names to use in final model (optional)") parser.add_argument("--rename-outputs", help="output names to use in final model (optional)") - parser.add_argument("--use-graph-names", help="skip renaming io using concrete names if loading a saved model") + parser.add_argument("--use-graph-names", help="(saved model only) skip renaming io using signature names", + action="store_true") parser.add_argument("--opset", type=int, default=None, help="opset version to use for onnx domain") parser.add_argument("--dequantize", help="Remove quantization from model. Only supported for tflite currently.", action="store_true") diff --git a/tf2onnx/tf_loader.py b/tf2onnx/tf_loader.py index fa8747d12..60b701830 100644 --- a/tf2onnx/tf_loader.py +++ b/tf2onnx/tf_loader.py @@ -497,11 +497,11 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d inputs = [tensor.name for tensor in concrete_func.inputs if tensor.dtype != tf.dtypes.resource] graph_captures = concrete_func.graph._captures # pylint: disable=protected-access captured_inputs = [t_name.name for _, t_name in graph_captures.values()] - inputs = [inp for inp in inputs if inp in captured_inputs] + inputs = [inp for inp in inputs if inp not in captured_inputs] if concrete_func.structured_input_signature is not None and not use_graph_names: flat_structured_inp = tf.nest.flatten(concrete_func.structured_input_signature) structured_inputs = [t.name for t in flat_structured_inp if isinstance(t, tf.TensorSpec)] - tensors_to_rename.update(zip(input_names, structured_inputs)) + tensors_to_rename.update(zip(inputs, structured_inputs)) else: inputs = input_names @@ -525,7 +525,7 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d def from_saved_model(model_path, input_names, output_names, tag=None, signatures=None, concrete_function=None, large_model=False, return_concrete_func=False, return_initialized_tables=False, - return_tensors_to_rename=False, use_graph_names=True): + return_tensors_to_rename=False, use_graph_names=False): """Load tensorflow graph from saved_model.""" if signatures is None: signatures = [] From bb33a1f3ad91f9bb646418147f9610f17839cf5f Mon Sep 17 00:00:00 2001 From: Tom Wildenhain Date: Tue, 25 May 2021 21:13:24 -0400 Subject: [PATCH 4/5] Update tutorials Signed-off-by: Tom Wildenhain --- examples/benchmark_tfmodel_ort.py | 2 +- examples/end2end_tfhub.py | 4 ++-- examples/end2end_tfkeras.py | 4 ++-- examples/getting_started.py | 2 +- tests/run_pretrained_models.py | 12 ++++++------ 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/benchmark_tfmodel_ort.py b/examples/benchmark_tfmodel_ort.py index ac1f7c3c1..363d97f8d 100644 --- a/examples/benchmark_tfmodel_ort.py +++ b/examples/benchmark_tfmodel_ort.py @@ -38,7 +38,7 @@ def measure_time(fct, imgs): # Download model from https://tfhub.dev/captain-pool/esrgan-tf2/1 # python -m tf2onnx.convert --saved-model esrgan --output "esrgan-tf2.onnx" --opset 12 ort = ort.InferenceSession('esrgan-tf2.onnx') -fct_ort = lambda img: ort.run(None, {'input_0:0': img}) +fct_ort = lambda img: ort.run(None, {'input_0': img}) results_ort, duration_ort = measure_time(fct_ort, imgs) print(len(imgs), duration_ort) diff --git a/examples/end2end_tfhub.py b/examples/end2end_tfhub.py index d61a9a595..25e5251a4 100644 --- a/examples/end2end_tfhub.py +++ b/examples/end2end_tfhub.py @@ -62,7 +62,7 @@ ######################################## # Runs onnxruntime. session = InferenceSession("efficientnetb0clas.onnx") -got = session.run(None, {'input_1:0': input}) +got = session.run(None, {'input_1': input}) print(got[0]) ######################################## @@ -73,5 +73,5 @@ # Measures processing time. print('tf:', timeit.timeit('model.predict(input)', number=10, globals=globals())) -print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})", +print('ort:', timeit.timeit("session.run(None, {'input_1': input})", number=10, globals=globals())) diff --git a/examples/end2end_tfkeras.py b/examples/end2end_tfkeras.py index f25455cf6..19da4d3b3 100644 --- a/examples/end2end_tfkeras.py +++ b/examples/end2end_tfkeras.py @@ -57,7 +57,7 @@ ######################################## # Runs onnxruntime. session = InferenceSession("simple_rnn.onnx") -got = session.run(None, {'input_1:0': input}) +got = session.run(None, {'input_1': input}) print(got[0]) ######################################## @@ -68,5 +68,5 @@ # Measures processing time. print('tf:', timeit.timeit('model.predict(input)', number=100, globals=globals())) -print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})", +print('ort:', timeit.timeit("session.run(None, {'input_1': input})", number=100, globals=globals())) diff --git a/examples/getting_started.py b/examples/getting_started.py index 5118fd3f9..d00ea0334 100644 --- a/examples/getting_started.py +++ b/examples/getting_started.py @@ -58,7 +58,7 @@ def f(a, b): print("ORT result") sess = ort.InferenceSession("model.onnx") -res = sess.run(None, {'dense_input:0': x_val}) +res = sess.run(None, {'dense_input': x_val}) print(res[0]) print("Conversion succeeded") \ No newline at end of file diff --git a/tests/run_pretrained_models.py b/tests/run_pretrained_models.py index 068c04d5d..4d9d97933 100644 --- a/tests/run_pretrained_models.py +++ b/tests/run_pretrained_models.py @@ -375,7 +375,7 @@ def run_test(self, name, backend="onnxruntime", onnx_file=None, opset=None, extr initialized_tables = {} outputs = self.output_names tflite_path = None - to_rename = None + to_rename = {} if self.model_type in ["checkpoint"]: graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs) elif self.model_type in ["saved_model"]: @@ -400,6 +400,7 @@ def run_test(self, name, backend="onnxruntime", onnx_file=None, opset=None, extr if utils.is_debug_mode(): utils.save_protobuf(os.path.join(TEMP_DIR, name + "_after_tf_optimize.pb"), graph_def) + logger.info("Input names %s", input_names) if tflite_path is not None: inputs = {} for k in input_names: @@ -438,7 +439,7 @@ def run_tflite(): inputs = {} for k in input_names: v = self.input_names[k] - inputs[to_rename[k]] = tf.constant(self.make_input(v)) + inputs[to_rename.get(k, k)] = tf.constant(self.make_input(v)) tf_func = tf.function(concrete_func) logger.info("Running TF") tf_results_d = tf_func(**inputs) @@ -507,6 +508,7 @@ def run_tflite(): elif self.run_tf_frozen: if self.tf_profile is not None: tf.profiler.experimental.start(self.tf_profile) + logger.info("TF inputs %s", list(inputs.keys())) tf_results = self.run_tensorflow(sess, inputs) if self.tf_profile is not None: tf.profiler.experimental.stop() @@ -553,11 +555,9 @@ def run_tflite(): try: onnx_results = None if backend == "onnxruntime": - if to_rename is None: - struc_outputs = self.output_names - else: - struc_outputs = [to_rename.get(k, k) for k in self.output_names] + struc_outputs = [to_rename.get(k, k) for k in self.output_names] struc_inputs = {to_rename.get(k, k): v for k, v in inputs.items()} + logger.info("ORT inputs %s", list(struc_inputs.keys())) onnx_results = self.run_onnxruntime( name, model_proto, struc_inputs, struc_outputs, external_tensor_storage) else: From 7247a1d26b9daed0dff9d8a933116789174a18d6 Mon Sep 17 00:00:00 2001 From: Tom Wildenhain Date: Wed, 26 May 2021 13:20:21 -0400 Subject: [PATCH 5/5] Remove logging lines Signed-off-by: Tom Wildenhain --- tests/run_pretrained_models.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/run_pretrained_models.py b/tests/run_pretrained_models.py index 4d9d97933..ae1254ded 100644 --- a/tests/run_pretrained_models.py +++ b/tests/run_pretrained_models.py @@ -400,7 +400,6 @@ def run_test(self, name, backend="onnxruntime", onnx_file=None, opset=None, extr if utils.is_debug_mode(): utils.save_protobuf(os.path.join(TEMP_DIR, name + "_after_tf_optimize.pb"), graph_def) - logger.info("Input names %s", input_names) if tflite_path is not None: inputs = {} for k in input_names: @@ -508,7 +507,6 @@ def run_tflite(): elif self.run_tf_frozen: if self.tf_profile is not None: tf.profiler.experimental.start(self.tf_profile) - logger.info("TF inputs %s", list(inputs.keys())) tf_results = self.run_tensorflow(sess, inputs) if self.tf_profile is not None: tf.profiler.experimental.stop() @@ -557,7 +555,6 @@ def run_tflite(): if backend == "onnxruntime": struc_outputs = [to_rename.get(k, k) for k in self.output_names] struc_inputs = {to_rename.get(k, k): v for k, v in inputs.items()} - logger.info("ORT inputs %s", list(struc_inputs.keys())) onnx_results = self.run_onnxruntime( name, model_proto, struc_inputs, struc_outputs, external_tensor_storage) else: