Skip to content

Commit c54b00e

Browse files
Add tflite support to run_pretrained_models.py
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 43d21a7 commit c54b00e

File tree

2 files changed

+96
-43
lines changed

2 files changed

+96
-43
lines changed

tests/car.JPEG

29.9 KB
Loading

tests/run_pretrained_models.py

Lines changed: 96 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,27 @@
5656
PERFITER = 1000
5757

5858

59-
def get_beach(shape):
60-
"""Get beach image as input."""
59+
def get_img(shape, path, dtype, should_scale=True):
60+
"""Get image as input."""
6161
resize_to = shape[1:3]
62-
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "beach.jpg")
62+
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), path)
6363
img = PIL.Image.open(path)
6464
img = img.resize(resize_to, PIL.Image.ANTIALIAS)
65-
img_np = np.array(img).astype(np.float32)
65+
img_np = np.array(img).astype(dtype)
6666
img_np = np.stack([img_np] * shape[0], axis=0).reshape(shape)
67-
return img_np / 255
67+
if should_scale:
68+
img_np = img_np / 255
69+
return img_np
70+
71+
72+
def get_beach(shape):
73+
"""Get beach image as input."""
74+
return get_img(shape, "beach.jpg", np.float32, should_scale=True)
75+
76+
77+
def get_car(shape):
78+
"""Get car image as input."""
79+
return get_img(shape, "car.JPEG", np.float32, should_scale=True)
6880

6981

7082
def get_random(shape):
@@ -133,6 +145,7 @@ def get_sentence():
133145

134146
_INPUT_FUNC_MAPPING = {
135147
"get_beach": get_beach,
148+
"get_car": get_car,
136149
"get_random": get_random,
137150
"get_random256": get_random256,
138151
"get_ramp": get_ramp,
@@ -219,6 +232,9 @@ def download_model(self):
219232
elif url.endswith('.zip'):
220233
ftype = 'zip'
221234
dir_name = fname.replace(".zip", "")
235+
elif url.endswith('.tflite'):
236+
ftype = 'tflite'
237+
dir_name = fname.replace(".tflite", "")
222238
dir_name = os.path.join(cache_dir, dir_name)
223239
os.makedirs(dir_name, exist_ok=True)
224240
fpath = os.path.join(dir_name, fname)
@@ -266,7 +282,7 @@ def run_tensorflow(self, sess, inputs):
266282
return result
267283

268284
def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, input_names=None,
269-
const_node_values=None, initialized_tables=None):
285+
const_node_values=None, initialized_tables=None, tflite_path=None):
270286
"""Convert graph to tensorflow."""
271287
if extra_opset is None:
272288
extra_opset = []
@@ -275,8 +291,8 @@ def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, i
275291
return process_tf_graph(tf_graph, continue_on_error=False, opset=opset,
276292
extra_opset=extra_opset, target=Test.target, shape_override=shape_override,
277293
input_names=input_names, output_names=self.output_names,
278-
const_node_values=const_node_values,
279-
initialized_tables=initialized_tables)
294+
const_node_values=const_node_values, initialized_tables=initialized_tables,
295+
tflite_path=tflite_path)
280296

281297
def run_caffe2(self, name, model_proto, inputs):
282298
"""Run test again caffe2 backend."""
@@ -340,6 +356,7 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
340356
input_names = list(self.input_names.keys())
341357
initialized_tables = {}
342358
outputs = self.output_names
359+
tflite_path = None
343360
if self.model_type in ["checkpoint"]:
344361
graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs)
345362
elif self.model_type in ["saved_model"]:
@@ -355,12 +372,43 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
355372
graph_def, input_names, outputs, initialized_tables = loaded
356373
elif self.model_type in ["keras"]:
357374
graph_def, input_names, outputs = tf_loader.from_keras(model_path, input_names, outputs)
375+
elif self.model_type in ["tflite"]:
376+
tflite_path = model_path
377+
graph_def = None
358378
else:
359379
graph_def, input_names, outputs = tf_loader.from_graphdef(model_path, input_names, outputs)
360380

361381
if utils.is_debug_mode():
362382
utils.save_protobuf(os.path.join(TEMP_DIR, name + "_after_tf_optimize.pb"), graph_def)
363383

384+
if tflite_path is not None:
385+
inputs = {}
386+
for k in input_names:
387+
v = self.input_names[k]
388+
inputs[k] = self.make_input(v)
389+
390+
interpreter = tf.lite.Interpreter(tflite_path)
391+
input_details = interpreter.get_input_details()
392+
output_details = interpreter.get_output_details()
393+
input_name_to_index = {n['name'].split(':')[0]: n['index'] for n in input_details}
394+
for k, v in inputs.items():
395+
interpreter.resize_tensor_input(input_name_to_index[k], v.shape)
396+
interpreter.allocate_tensors()
397+
def run_tflite():
398+
for k, v in inputs.items():
399+
interpreter.set_tensor(input_name_to_index[k], v)
400+
interpreter.invoke()
401+
result = [interpreter.get_tensor(output['index']) for output in output_details]
402+
return result
403+
tf_results = run_tflite()
404+
if self.perf:
405+
logger.info("Running TFLite perf")
406+
start = time.time()
407+
for _ in range(PERFITER):
408+
_ = run_tflite()
409+
self.tf_runtime = time.time() - start
410+
logger.info("TFLite OK")
411+
364412
if not self.run_tf_frozen:
365413
inputs = {}
366414
for k in input_names:
@@ -384,45 +432,50 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
384432
self.tf_runtime = time.time() - start
385433
logger.info("TensorFlow OK")
386434

387-
inputs = {}
388435
shape_override = {}
389-
tf_reset_default_graph()
390-
391-
from tf2onnx.tf_utils import compress_graph_def
392436
const_node_values = None
393-
with tf.Graph().as_default() as tf_graph:
394-
if self.large_model:
395-
const_node_values = compress_graph_def(graph_def)
396-
tf.import_graph_def(graph_def, name='')
437+
tf_graph = None
397438

398-
with tf_session(graph=tf_graph) as sess:
399-
# create the input data
400-
for k in input_names:
401-
v = self.input_names[k]
402-
t = sess.graph.get_tensor_by_name(k)
403-
expected_dtype = tf.as_dtype(t.dtype).name
404-
if isinstance(v, six.text_type) and v.startswith("np."):
405-
np_value = eval(v) # pylint: disable=eval-used
406-
if expected_dtype != np_value.dtype:
407-
logger.warning("dtype mismatch for input %s: expected=%s, actual=%s", k, expected_dtype,
408-
np_value.dtype)
409-
inputs[k] = np_value.astype(expected_dtype)
410-
else:
411-
if expected_dtype == "string":
412-
inputs[k] = self.make_input(v).astype(np.str).astype(np.object)
439+
if graph_def is not None:
440+
inputs = {}
441+
tf_reset_default_graph()
442+
443+
with tf.Graph().as_default() as tf_graph:
444+
from tf2onnx.tf_utils import compress_graph_def
445+
if self.large_model:
446+
const_node_values = compress_graph_def(graph_def)
447+
tf.import_graph_def(graph_def, name='')
448+
449+
with tf_session(graph=tf_graph) as sess:
450+
# create the input data
451+
for k in input_names:
452+
v = self.input_names[k]
453+
t = sess.graph.get_tensor_by_name(k)
454+
expected_dtype = tf.as_dtype(t.dtype).name
455+
if isinstance(v, six.text_type) and v.startswith("np."):
456+
np_value = eval(v) # pylint: disable=eval-used
457+
if expected_dtype != np_value.dtype:
458+
logger.warning("dtype mismatch for input %s: expected=%s, actual=%s", k, expected_dtype,
459+
np_value.dtype)
460+
inputs[k] = np_value.astype(expected_dtype)
413461
else:
414-
inputs[k] = self.make_input(v).astype(expected_dtype)
462+
if expected_dtype == "string":
463+
inputs[k] = self.make_input(v).astype(np.str).astype(np.object)
464+
else:
465+
inputs[k] = self.make_input(v).astype(expected_dtype)
415466

416-
if self.force_input_shape:
417-
for k, v in inputs.items():
418-
shape_override[k] = list(v.shape)
467+
if self.force_input_shape:
468+
for k, v in inputs.items():
469+
shape_override[k] = list(v.shape)
470+
471+
# run the model with tensorflow
472+
if self.skip_tensorflow:
473+
logger.info("TensorFlow SKIPPED")
474+
elif self.run_tf_frozen:
475+
tf_results = self.run_tensorflow(sess, inputs)
476+
logger.info("TensorFlow OK")
477+
tf_graph = sess.graph
419478

420-
# run the model with tensorflow
421-
if self.skip_tensorflow:
422-
logger.info("TensorFlow SKIPPED")
423-
elif self.run_tf_frozen:
424-
tf_results = self.run_tensorflow(sess, inputs)
425-
logger.info("TensorFlow OK")
426479

427480
model_proto = None
428481
if self.skip_conversion:
@@ -436,10 +489,10 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
436489
else:
437490
try:
438491
# convert model to onnx
439-
onnx_graph = self.to_onnx(sess.graph, opset=opset, extra_opset=extra_opset,
492+
onnx_graph = self.to_onnx(tf_graph, opset=opset, extra_opset=extra_opset,
440493
shape_override=shape_override, input_names=inputs.keys(),
441494
const_node_values=const_node_values,
442-
initialized_tables=initialized_tables)
495+
initialized_tables=initialized_tables, tflite_path=tflite_path)
443496
onnx_graph = optimizer.optimize_graph(onnx_graph)
444497
print("ONNX", onnx_graph.dump_node_statistics())
445498
external_tensor_storage = ExternalTensorStorage() if self.large_model else None

0 commit comments

Comments
 (0)