Skip to content

Commit bb33a1f

Browse files
Update tutorials
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent cd5e376 commit bb33a1f

File tree

5 files changed

+12
-12
lines changed

5 files changed

+12
-12
lines changed

examples/benchmark_tfmodel_ort.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def measure_time(fct, imgs):
3838
# Download model from https://tfhub.dev/captain-pool/esrgan-tf2/1
3939
# python -m tf2onnx.convert --saved-model esrgan --output "esrgan-tf2.onnx" --opset 12
4040
ort = ort.InferenceSession('esrgan-tf2.onnx')
41-
fct_ort = lambda img: ort.run(None, {'input_0:0': img})
41+
fct_ort = lambda img: ort.run(None, {'input_0': img})
4242
results_ort, duration_ort = measure_time(fct_ort, imgs)
4343
print(len(imgs), duration_ort)
4444

examples/end2end_tfhub.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
########################################
6363
# Runs onnxruntime.
6464
session = InferenceSession("efficientnetb0clas.onnx")
65-
got = session.run(None, {'input_1:0': input})
65+
got = session.run(None, {'input_1': input})
6666
print(got[0])
6767

6868
########################################
@@ -73,5 +73,5 @@
7373
# Measures processing time.
7474
print('tf:', timeit.timeit('model.predict(input)',
7575
number=10, globals=globals()))
76-
print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
76+
print('ort:', timeit.timeit("session.run(None, {'input_1': input})",
7777
number=10, globals=globals()))

examples/end2end_tfkeras.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
########################################
5858
# Runs onnxruntime.
5959
session = InferenceSession("simple_rnn.onnx")
60-
got = session.run(None, {'input_1:0': input})
60+
got = session.run(None, {'input_1': input})
6161
print(got[0])
6262

6363
########################################
@@ -68,5 +68,5 @@
6868
# Measures processing time.
6969
print('tf:', timeit.timeit('model.predict(input)',
7070
number=100, globals=globals()))
71-
print('ort:', timeit.timeit("session.run(None, {'input_1:0': input})",
71+
print('ort:', timeit.timeit("session.run(None, {'input_1': input})",
7272
number=100, globals=globals()))

examples/getting_started.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def f(a, b):
5858

5959
print("ORT result")
6060
sess = ort.InferenceSession("model.onnx")
61-
res = sess.run(None, {'dense_input:0': x_val})
61+
res = sess.run(None, {'dense_input': x_val})
6262
print(res[0])
6363

6464
print("Conversion succeeded")

tests/run_pretrained_models.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ def run_test(self, name, backend="onnxruntime", onnx_file=None, opset=None, extr
375375
initialized_tables = {}
376376
outputs = self.output_names
377377
tflite_path = None
378-
to_rename = None
378+
to_rename = {}
379379
if self.model_type in ["checkpoint"]:
380380
graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs)
381381
elif self.model_type in ["saved_model"]:
@@ -400,6 +400,7 @@ def run_test(self, name, backend="onnxruntime", onnx_file=None, opset=None, extr
400400
if utils.is_debug_mode():
401401
utils.save_protobuf(os.path.join(TEMP_DIR, name + "_after_tf_optimize.pb"), graph_def)
402402

403+
logger.info("Input names %s", input_names)
403404
if tflite_path is not None:
404405
inputs = {}
405406
for k in input_names:
@@ -438,7 +439,7 @@ def run_tflite():
438439
inputs = {}
439440
for k in input_names:
440441
v = self.input_names[k]
441-
inputs[to_rename[k]] = tf.constant(self.make_input(v))
442+
inputs[to_rename.get(k, k)] = tf.constant(self.make_input(v))
442443
tf_func = tf.function(concrete_func)
443444
logger.info("Running TF")
444445
tf_results_d = tf_func(**inputs)
@@ -507,6 +508,7 @@ def run_tflite():
507508
elif self.run_tf_frozen:
508509
if self.tf_profile is not None:
509510
tf.profiler.experimental.start(self.tf_profile)
511+
logger.info("TF inputs %s", list(inputs.keys()))
510512
tf_results = self.run_tensorflow(sess, inputs)
511513
if self.tf_profile is not None:
512514
tf.profiler.experimental.stop()
@@ -553,11 +555,9 @@ def run_tflite():
553555
try:
554556
onnx_results = None
555557
if backend == "onnxruntime":
556-
if to_rename is None:
557-
struc_outputs = self.output_names
558-
else:
559-
struc_outputs = [to_rename.get(k, k) for k in self.output_names]
558+
struc_outputs = [to_rename.get(k, k) for k in self.output_names]
560559
struc_inputs = {to_rename.get(k, k): v for k, v in inputs.items()}
560+
logger.info("ORT inputs %s", list(struc_inputs.keys()))
561561
onnx_results = self.run_onnxruntime(
562562
name, model_proto, struc_inputs, struc_outputs, external_tensor_storage)
563563
else:

0 commit comments

Comments
 (0)