11
11
import zipfile
12
12
import subprocess
13
13
import datetime
14
+ from collections import OrderedDict
14
15
import numpy
15
16
from tqdm import tqdm
16
17
import onnxruntime
17
18
18
19
19
- def generate_random_images (shape = (1 , 100 , 100 , 3 ), n = 10 , dtype = numpy .float32 ):
20
+ def generate_random_images (shape = (1 , 100 , 100 , 3 ), n = 10 , dtype = numpy .float32 , scale = 255 ):
20
21
imgs = []
21
22
for i in range (n ):
22
23
sh = shape
23
- img = numpy .clip (numpy .abs (numpy .random .randn (* sh )), 0 , 1 ) * 255
24
+ img = numpy .clip (numpy .abs (numpy .random .randn (* sh )), 0 , 1 ) * scale
24
25
img = img .astype (dtype )
25
26
imgs .append (img )
26
27
return imgs
@@ -180,19 +181,27 @@ def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
180
181
print (" {}: {}, {}" .format (a .name , a .type , a .shape ))
181
182
182
183
# onnxruntime
183
- input_name = ort .get_inputs ()[0 ].name
184
- fct_ort = lambda img : ort .run (None , {input_name : img })[0 ]
184
+ if isinstance (imgs [0 ], dict ):
185
+ fct_ort = lambda img : ort .run (None , img )[0 ]
186
+ else :
187
+ input_name = ort .get_inputs ()[0 ].name
188
+ fct_ort = lambda img : ort .run (None , {input_name : img })[0 ]
185
189
results_ort , duration_ort = measure_time (fct_ort , imgs )
186
190
if verbose :
187
191
print ("ORT" , len (imgs ), duration_ort )
188
192
189
193
# tensorflow
190
194
import tensorflow_hub as hub
191
195
from tensorflow import convert_to_tensor
196
+ if isinstance (imgs [0 ], OrderedDict ):
197
+ imgs_tf = [
198
+ OrderedDict ((k , convert_to_tensor (v )) for k , v in img .items ())
199
+ for img in imgs ]
200
+ else :
201
+ imgs_tf = [convert_to_tensor (img ) for img in imgs ]
192
202
model = hub .load (url .split ("?" )[0 ])
193
203
if signature is not None :
194
204
model = model .signatures ['serving_default' ]
195
- imgs_tf = [convert_to_tensor (img ) for img in imgs ]
196
205
results_tf , duration_tf = measure_time (model , imgs_tf )
197
206
198
207
if verbose :
@@ -205,7 +214,9 @@ def benchmark(url, dest, onnx_name, opset, imgs, verbose=True, threshold=1e-3,
205
214
res = model (imgs_tf [0 ])
206
215
if isinstance (res , dict ):
207
216
if len (res ) != 1 :
208
- raise NotImplementedError ("TF output contains more than one output: %r." % res )
217
+ raise NotImplementedError (
218
+ "TF output contains more than one output=%r and output names=%r." % (
219
+ res , [o .name for o in ort .get_outputs ()]))
209
220
output_name = ort .get_outputs ()[0 ].name
210
221
if output_name not in res :
211
222
raise AssertionError ("Unable to find output %r in %r." % (output_name , list (sorted (res ))))
@@ -252,10 +263,15 @@ def benchmark_tflite(url, dest, onnx_name, opset, imgs, verbose=True, threshold=
252
263
# tensorflow
253
264
import tensorflow_hub as hub
254
265
from tensorflow import convert_to_tensor
266
+ if isinstance (imgs [0 ], OrderedDict ):
267
+ imgs_tf = [
268
+ OrderedDict ((k , convert_to_tensor (v )) for k , v in img .items ())
269
+ for img in imgs ]
270
+ else :
271
+ imgs_tf = [convert_to_tensor (img ) for img in imgs ]
255
272
model = hub .load (url .split ("?" )[0 ])
256
273
if signature is not None :
257
274
model = model .signatures ['serving_default' ]
258
- imgs_tf = [convert_to_tensor (img ) for img in imgs ]
259
275
results_tf , duration_tf = measure_time (model , imgs_tf )
260
276
261
277
if verbose :
0 commit comments