Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/benchmark_tfmodel_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def measure_time(fct, imgs):
# Download model from https://tfhub.dev/captain-pool/esrgan-tf2/1
# python -m tf2onnx.convert --saved-model esrgan --output "esrgan-tf2.onnx" --opset 12
ort = ort.InferenceSession('esrgan-tf2.onnx')
fct_ort = lambda img: ort.run(None, {'input_0:0': img})
fct_ort = lambda img: ort.run(None, {'input_0': img})
results_ort, duration_ort = measure_time(fct_ort, imgs)
print(len(imgs), duration_ort)

Expand Down
4 changes: 2 additions & 2 deletions examples/end2end_tfhub.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
########################################
# Runs onnxruntime.
session = InferenceSession("efficientnetb0clas.onnx")
got = session.run(None, {'input_1:0': input})
got = session.run(None, {'input_1': input})
print(got[0])

########################################
Expand All @@ -73,5 +73,5 @@
# Measures processing time.
print('tf:', timeit.timeit('model.predict(input)',
number=10, globals=globals()))
print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
print('ort:', timeit.timeit("session.run(None, {'input_1': input})",
number=10, globals=globals()))
4 changes: 2 additions & 2 deletions examples/end2end_tfkeras.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
########################################
# Runs onnxruntime.
session = InferenceSession("simple_rnn.onnx")
got = session.run(None, {'input_1:0': input})
got = session.run(None, {'input_1': input})
print(got[0])

########################################
Expand All @@ -68,5 +68,5 @@
# Measures processing time.
print('tf:', timeit.timeit('model.predict(input)',
number=100, globals=globals()))
print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
print('ort:', timeit.timeit("session.run(None, {'input_1': input})",
number=100, globals=globals()))
2 changes: 1 addition & 1 deletion examples/getting_started.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def f(a, b):

print("ORT result")
sess = ort.InferenceSession("model.onnx")
res = sess.run(None, {'dense_input:0': x_val})
res = sess.run(None, {'dense_input': x_val})
print(res[0])

print("Conversion succeeded")
13 changes: 6 additions & 7 deletions tests/run_pretrained_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def run_test(self, name, backend="onnxruntime", onnx_file=None, opset=None, extr
initialized_tables = {}
outputs = self.output_names
tflite_path = None
to_rename = None
to_rename = {}
if self.model_type in ["checkpoint"]:
graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs)
elif self.model_type in ["saved_model"]:
Expand Down Expand Up @@ -438,7 +438,7 @@ def run_tflite():
inputs = {}
for k in input_names:
v = self.input_names[k]
inputs[k.split(":")[0]] = tf.constant(self.make_input(v))
inputs[to_rename.get(k, k)] = tf.constant(self.make_input(v))
tf_func = tf.function(concrete_func)
logger.info("Running TF")
tf_results_d = tf_func(**inputs)
Expand Down Expand Up @@ -553,11 +553,10 @@ def run_tflite():
try:
onnx_results = None
if backend == "onnxruntime":
if to_rename is None:
struc_outputs = self.output_names
else:
struc_outputs = [to_rename.get(k, k) for k in self.output_names]
onnx_results = self.run_onnxruntime(name, model_proto, inputs, struc_outputs, external_tensor_storage)
struc_outputs = [to_rename.get(k, k) for k in self.output_names]
struc_inputs = {to_rename.get(k, k): v for k, v in inputs.items()}
onnx_results = self.run_onnxruntime(
name, model_proto, struc_inputs, struc_outputs, external_tensor_storage)
else:
raise ValueError("unknown backend")
logger.info("Run_ONNX OK")
Expand Down
5 changes: 4 additions & 1 deletion tf2onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def get_args():
"change into Identity ops using their default value")
parser.add_argument("--rename-inputs", help="input names to use in final model (optional)")
parser.add_argument("--rename-outputs", help="output names to use in final model (optional)")
parser.add_argument("--use-graph-names", help="(saved model only) skip renaming io using signature names",
action="store_true")
parser.add_argument("--opset", type=int, default=None, help="opset version to use for onnx domain")
parser.add_argument("--dequantize", help="Remove quantization from model. Only supported for tflite currently.",
action="store_true")
Expand Down Expand Up @@ -212,7 +214,8 @@ def main():
if args.saved_model:
graph_def, inputs, outputs, initialized_tables, tensors_to_rename = tf_loader.from_saved_model(
args.saved_model, args.inputs, args.outputs, args.tag, args.signature_def, args.concrete_function,
args.large_model, return_initialized_tables=True, return_tensors_to_rename=True)
args.large_model, return_initialized_tables=True, return_tensors_to_rename=True,
use_graph_names=args.use_graph_names)
model_path = args.saved_model
if args.keras:
graph_def, inputs, outputs = tf_loader.from_keras(
Expand Down
33 changes: 17 additions & 16 deletions tf2onnx/tf_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def from_checkpoint(model_path, input_names, output_names):
return frozen_graph, input_names, output_names


def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signature_names):
def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signature_names, use_graph_names):
"""Load tensorflow graph from saved_model."""

wrn_no_tag = "'--tag' not specified for saved_model. Using --tag serve"
Expand Down Expand Up @@ -345,22 +345,25 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signa
# TF1.12 changed the api
get_signature_def = lambda meta_graph_def, k: meta_graph_def.signature_def[k]

tensors_to_rename = {}
if input_names is None:
input_names = []
for k in signatures:
inputs_tensor_info = get_signature_def(imported, k).inputs
for _, input_tensor in inputs_tensor_info.items():
for structured_name, input_tensor in inputs_tensor_info.items():
if input_tensor.name not in input_names:
input_names.append(input_tensor.name)
tensors_to_rename = {}
if not use_graph_names:
tensors_to_rename[input_tensor.name] = structured_name
if output_names is None:
output_names = []
for k in signatures:
outputs_tensor_info = get_signature_def(imported, k).outputs
for structured_name, output_tensor in outputs_tensor_info.items():
if output_tensor.name not in output_names:
output_names.append(output_tensor.name)
tensors_to_rename[output_tensor.name] = structured_name
if not use_graph_names:
tensors_to_rename[output_tensor.name] = structured_name
frozen_graph, initialized_tables = \
freeze_session(sess, input_names=input_names, output_names=output_names, get_tables=True)
return frozen_graph, input_names, output_names, initialized_tables, tensors_to_rename
Expand Down Expand Up @@ -447,7 +450,7 @@ def _restore_captured_resources(concrete_func, graph_captures_copy, func_capture


def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_def,
concrete_function_index, large_model):
concrete_function_index, large_model, use_graph_names):
"""Load tensorflow graph from saved_model."""

wrn_no_tag = "'--tag' not specified for saved_model. Using --tag serve"
Expand Down Expand Up @@ -495,18 +498,16 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
graph_captures = concrete_func.graph._captures # pylint: disable=protected-access
captured_inputs = [t_name.name for _, t_name in graph_captures.values()]
inputs = [inp for inp in inputs if inp not in captured_inputs]
if concrete_func.structured_input_signature is not None:
args, kwargs = concrete_func.structured_input_signature
structured_inputs = [t.name for t in args if isinstance(t, tf.TensorSpec)] + sorted(kwargs.keys())
structured_inputs = set(inp + ":0" for inp in structured_inputs)
if any(inp in structured_inputs for inp in inputs):
inputs = [inp for inp in inputs if inp in structured_inputs]
if concrete_func.structured_input_signature is not None and not use_graph_names:
flat_structured_inp = tf.nest.flatten(concrete_func.structured_input_signature)
structured_inputs = [t.name for t in flat_structured_inp if isinstance(t, tf.TensorSpec)]
tensors_to_rename.update(zip(inputs, structured_inputs))
else:
inputs = input_names

if output_names is None:
outputs = [tensor.name for tensor in concrete_func.outputs if tensor.dtype != tf.dtypes.resource]
if isinstance(concrete_func.structured_outputs, dict):
if isinstance(concrete_func.structured_outputs, dict) and not use_graph_names:
# outputs are sorted, sort structured_outputs the same way
structured_outputs = sorted(concrete_func.structured_outputs.keys())
tensors_to_rename.update(zip(outputs, structured_outputs))
Expand All @@ -515,7 +516,6 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
logger.info("Output names: %r", outputs)
else:
outputs = output_names
logger.info("Outputs not left as None; will use provided names not structured output names.")

frozen_graph, initialized_tables = from_trackable(imported, concrete_func, inputs, outputs, large_model)

Expand All @@ -524,7 +524,8 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d

def from_saved_model(model_path, input_names, output_names, tag=None,
signatures=None, concrete_function=None, large_model=False,
return_concrete_func=False, return_initialized_tables=False, return_tensors_to_rename=False):
return_concrete_func=False, return_initialized_tables=False,
return_tensors_to_rename=False, use_graph_names=False):
"""Load tensorflow graph from saved_model."""
if signatures is None:
signatures = []
Expand All @@ -533,7 +534,7 @@ def from_saved_model(model_path, input_names, output_names, tag=None,
if is_tf2():
frozen_graph, input_names, output_names, concrete_func, imported, initialized_tables, tensors_to_rename = \
_from_saved_model_v2(model_path, input_names, output_names,
tag, signatures, concrete_function, large_model)
tag, signatures, concrete_function, large_model, use_graph_names)
result = [frozen_graph, input_names, output_names]
if return_concrete_func:
result += [concrete_func, imported]
Expand All @@ -544,7 +545,7 @@ def from_saved_model(model_path, input_names, output_names, tag=None,
else:
with tf_session() as sess:
frozen_graph, input_names, output_names, initialized_tables, tensors_to_rename = \
_from_saved_model_v1(sess, model_path, input_names, output_names, tag, signatures)
_from_saved_model_v1(sess, model_path, input_names, output_names, tag, signatures, use_graph_names)
result = [frozen_graph, input_names, output_names]
if return_initialized_tables:
result += [initialized_tables]
Expand Down