Skip to content

Commit a31772c

Browse files
committed
Add script for bert and mobilenet
1 parent 23ac300 commit a31772c

File tree

4 files changed

+71
-7
lines changed

4 files changed

+71
-7
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,7 @@ __pycache__
1313
.eggs
1414
*.egg-info
1515
run.sh
16+
tests/tfhub/*/*.onnx
17+
tests/tfhub/*/*.tar.gz
18+
tests/tfhub/*/*.tflite
19+
tests/tfhub/*/**

tests/tfhub/_tools.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,17 @@
1111
import zipfile
1212
import subprocess
1313
import datetime
14+
from collections import OrderedDict
1415
import numpy
1516
from tqdm import tqdm
1617
import onnxruntime
1718

1819

19-
def generate_random_images(shape=(1, 100, 100, 3), n=10, dtype=numpy.float32):
20+
def generate_random_images(shape=(1, 100, 100, 3), n=10, dtype=numpy.float32, scale=255):
2021
imgs = []
2122
for i in range(n):
2223
sh = shape
23-
img = numpy.clip(numpy.abs(numpy.random.randn(*sh)), 0, 1) * 255
24+
img = numpy.clip(numpy.abs(numpy.random.randn(*sh)), 0, 1) * scale
2425
img = img.astype(dtype)
2526
imgs.append(img)
2627
return imgs
@@ -180,19 +181,27 @@ def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
180181
print(" {}: {}, {}".format(a.name, a.type, a.shape))
181182

182183
# onnxruntime
183-
input_name = ort.get_inputs()[0].name
184-
fct_ort = lambda img: ort.run(None, {input_name: img})[0]
184+
if isinstance(imgs[0], dict):
185+
fct_ort = lambda img: ort.run(None, img)[0]
186+
else:
187+
input_name = ort.get_inputs()[0].name
188+
fct_ort = lambda img: ort.run(None, {input_name: img})[0]
185189
results_ort, duration_ort = measure_time(fct_ort, imgs)
186190
if verbose:
187191
print("ORT", len(imgs), duration_ort)
188192

189193
# tensorflow
190194
import tensorflow_hub as hub
191195
from tensorflow import convert_to_tensor
196+
if isinstance(imgs[0], OrderedDict):
197+
imgs_tf = [
198+
OrderedDict((k, convert_to_tensor(v)) for k, v in img.items())
199+
for img in imgs]
200+
else:
201+
imgs_tf = [convert_to_tensor(img) for img in imgs]
192202
model = hub.load(url.split("?")[0])
193203
if signature is not None:
194204
model = model.signatures['serving_default']
195-
imgs_tf = [convert_to_tensor(img) for img in imgs]
196205
results_tf, duration_tf = measure_time(model, imgs_tf)
197206

198207
if verbose:
@@ -205,7 +214,9 @@ def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
205214
res = model(imgs_tf[0])
206215
if isinstance(res, dict):
207216
if len(res) != 1:
208-
raise NotImplementedError("TF output contains more than one output: %r." % res)
217+
raise NotImplementedError(
218+
"TF output contains more than one output=%r and output names=%r." % (
219+
res, [o.name for o in ort.get_outputs()]))
209220
output_name = ort.get_outputs()[0].name
210221
if output_name not in res:
211222
raise AssertionError("Unable to find output %r in %r." % (output_name, list(sorted(res))))
@@ -252,10 +263,15 @@ def benchmark_tflite(url, dest, onnx_name, opset, imgs, verbose=True, threshold=
252263
# tensorflow
253264
import tensorflow_hub as hub
254265
from tensorflow import convert_to_tensor
266+
if isinstance(imgs[0], OrderedDict):
267+
imgs_tf = [
268+
OrderedDict((k, convert_to_tensor(v)) for k, v in img.items())
269+
for img in imgs]
270+
else:
271+
imgs_tf = [convert_to_tensor(img) for img in imgs]
255272
model = hub.load(url.split("?")[0])
256273
if signature is not None:
257274
model = model.signatures['serving_default']
258-
imgs_tf = [convert_to_tensor(img) for img in imgs]
259275
results_tf, duration_tf = measure_time(model, imgs_tf)
260276

261277
if verbose:
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import os
3+
from collections import OrderedDict
4+
import numpy
5+
import numpy.random as rnd
6+
from _tools import generate_random_images, benchmark
7+
8+
9+
def main(opset=13):
10+
url = "https://tfhub.dev/tensorflow/bert_en_wwm_uncased_L-24_H-1024_A-16/4?tf-hub-format=compressed"
11+
dest = "tf-bert-en-wwm-uncased-L-24-H-1024-A-16"
12+
name = "bert-en-wwm-uncased-L-24-H-1024-A-16"
13+
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
14+
15+
inputs = [OrderedDict([
16+
('input_word_ids', numpy.array([rnd.randint(0, 1000) for i in range(0, 32)], dtype=numpy.int32).reshape((1, -1))),
17+
('input_mask', numpy.array([rnd.randint(0, 1) for i in range(0, 32)], dtype=numpy.int32).reshape((1, -1))),
18+
('input_type_ids', numpy.array([i//5 for i in range(0, 32)], dtype=numpy.int32).reshape((1, -1)))
19+
]) for i in range(0, 10)]
20+
21+
benchmark(url, dest, onnx_name, opset, inputs)
22+
23+
24+
if __name__ == "__main__":
25+
main()
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import os
3+
import numpy
4+
from _tools import generate_random_images, benchmark
5+
6+
7+
def main(opset=13):
8+
url = "https://tfhub.dev/google/imagenet/mobilenet_v3_large_075_224/classification/5?tf-hub-format=compressed"
9+
dest = "tf-mobilenet-v3-large-075-224"
10+
name = "mobilenet-v3-large-075-224"
11+
onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
12+
13+
imgs = generate_random_images(shape=(1, 224, 224, 3), scale=1.)
14+
15+
benchmark(url, dest, onnx_name, opset, imgs)
16+
17+
18+
if __name__ == "__main__":
19+
main()

0 commit comments

Comments
 (0)