@@ -19,9 +19,11 @@
 import tempfile
 import time
 import zipfile
+import random
 
 from collections import namedtuple
 from distutils.version import LooseVersion
 
+
 import yaml
 import numpy as np
 import PIL.Image
@@ -38,7 +40,7 @@
     # not needed for tf-2.0
     pass
 
-from tf2onnx import tf_loader, logging, optimizer, utils, tf_utils
+from tf2onnx import tf_loader, logging, optimizer, utils, tf_utils, constants
 from tf2onnx.tfonnx import process_tf_graph
 from tf2onnx.tf_loader import tf_session, tf_reset_default_graph
 from tf2onnx.graph import ExternalTensorStorage
@@ -62,11 +64,13 @@ def get_beach(shape):
 
 def get_random(shape):
     """Get random input."""
+    np.random.seed(42)
     return np.random.sample(shape).astype(np.float32)
 
 
 def get_random256(shape):
     """Get random input between 0 and 255."""
+    np.random.seed(42)
     return np.round(np.random.sample(shape) * 256).astype(np.float32)
 
 
@@ -98,6 +102,7 @@ def get_ones_int32(shape):
 
 def get_small_rand_int32(shape):
     """Get random ints in range [1, 99]"""
+    np.random.seed(42)
     return np.random.randint(low=1, high=100, size=shape, dtype=np.int32)
 
 def get_zeros_then_ones(shape):
@@ -111,6 +116,15 @@ def get_wav(shape):
     """Get sound data."""
     return np.sin(np.linspace(-np.pi, np.pi, shape[0]), dtype=np.float32)
 
+def get_sentences(shape):
+    """Get sentences of shape"""
+    words = "the quick brown fox jumps over a lazy dog".split(' ')
+    random.seed(42)
+    def get_sentence():
+        length = random.randint(2, 7)
+        return ' '.join(random.choice(words) for _ in range(length))
+    return np.array([get_sentence() for _ in range(np.product(shape))]).reshape(shape)
+
 
 _INPUT_FUNC_MAPPING = {
     "get_beach": get_beach,
@@ -124,7 +138,8 @@ def get_wav(shape):
     "get_zeros_int64": get_zeros_int64,
     "get_ones_int32": get_ones_int32,
     "get_small_rand_int32": get_small_rand_int32,
-    "get_zeros_then_ones": get_zeros_then_ones
+    "get_zeros_then_ones": get_zeros_then_ones,
+    "get_sentences": get_sentences,
 }
 
 
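As a quick reference (not part of the diff): this mapping is what lets a YAML test entry name its input generator as a plain string, and the new get_sentences helper is deterministic thanks to the fixed seed, returning a numpy array of Python strings with the requested shape. A minimal sketch, with an illustrative entry name and shape:

    input_func = _INPUT_FUNC_MAPPING["get_sentences"]
    batch = input_func([2, 3])        # numpy array of short phrases, shape (2, 3)
    print(batch.shape, batch[0, 0])   # same sentences on every run because of random.seed(42)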
@@ -142,14 +157,18 @@ def __init__(self, url, local, input_func, input_names, output_names,
                  check_only_shape=False, model_type="frozen", force_input_shape=False,
                  skip_tensorflow=False, opset_constraints=None, tf_min_version=None, tag=None,
                  skip_conversion=False, converted_model=None, signature_def=None, concrete_function=None,
-                 large_model=False, structured_outputs=None):
+                 large_model=False, structured_outputs=None, run_tf_frozen=None, use_custom_ops=False):
         self.url = url
         self.input_func = input_func
         self.local = local
         self.input_names = input_names
         self.output_names = output_names
         self.disabled = disabled
         self.large_model = large_model
+        self.use_custom_ops = use_custom_ops
+        if run_tf_frozen is None:
+            run_tf_frozen = not self.large_model
+        self.run_tf_frozen = run_tf_frozen
         self.structured_outputs = structured_outputs  # Needed to determine output order for tf_function
         self.rtol = rtol
         self.atol = atol
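A small sketch of the defaulting rule added to Test.__init__ (the standalone helper below is hypothetical; its logic mirrors the three new lines): an explicit run_tf_frozen setting wins, otherwise large models skip the frozen-graph TensorFlow run.

    def default_run_tf_frozen(run_tf_frozen, large_model):
        # explicit setting wins; otherwise only non-large models run the frozen graph
        if run_tf_frozen is None:
            run_tf_frozen = not large_model
        return run_tf_frozen

    assert default_run_tf_frozen(None, large_model=False) is True
    assert default_run_tf_frozen(None, large_model=True) is False
    assert default_run_tf_frozen(True, large_model=True) is True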
@@ -242,12 +261,17 @@ def run_tensorflow(self, sess, inputs):
         return result
 
     def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, input_names=None,
-                const_node_values=None):
+                const_node_values=None, initialized_tables=None):
         """Convert graph to ONNX."""
+        if extra_opset is None:
+            extra_opset = []
+        if self.use_custom_ops:
+            extra_opset.append(utils.make_opsetid(constants.CONTRIB_OPS_DOMAIN, 1))
         return process_tf_graph(tf_graph, continue_on_error=False, opset=opset,
                                 extra_opset=extra_opset, target=Test.target, shape_override=shape_override,
                                 input_names=input_names, output_names=self.output_names,
-                                const_node_values=const_node_values)
+                                const_node_values=const_node_values,
+                                initialized_tables=initialized_tables)
 
     def run_caffe2(self, name, model_proto, inputs):
         """Run test against caffe2 backend."""
@@ -268,7 +292,13 @@ def run_onnxruntime(self, name, model_proto, inputs, external_tensor_storage=None):
                                            as_text=utils.is_debug_mode(),
                                            external_tensor_storage=external_tensor_storage)
         logger.info("Model saved to %s", model_path)
-        m = rt.InferenceSession(model_path)
+        if self.use_custom_ops:
+            from ortcustomops import get_library_path
+            opt = rt.SessionOptions()
+            opt.register_custom_ops_library(get_library_path())
+            m = rt.InferenceSession(model_path, opt)
+        else:
+            m = rt.InferenceSession(model_path)
         results = m.run(self.output_names, inputs)
         if self.perf:
             start = time.time()
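Outside the test harness, loading a converted model that relies on contrib/string ops follows the same pattern as the new branch above. A minimal sketch, assuming the onnxruntime-customops package is installed and "model.onnx" stands in for a real path:

    import onnxruntime as rt
    from ortcustomops import get_library_path

    so = rt.SessionOptions()
    so.register_custom_ops_library(get_library_path())
    sess = rt.InferenceSession("model.onnx", so)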
@@ -303,19 +333,21 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
 
         logger.info("Load model from %s", model_path)
         input_names = list(self.input_names.keys())
+        initialized_tables = {}
         outputs = self.output_names
         if self.model_type in ["checkpoint"]:
             graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs)
         elif self.model_type in ["saved_model"]:
             loaded = tf_loader.from_saved_model(model_path, input_names, outputs, self.tag, self.signatures,
                                                 self.concrete_function, self.large_model,
-                                                return_concrete_func=self.large_model)
-            if self.large_model:
+                                                return_concrete_func=not self.run_tf_frozen,
+                                                return_initialized_tables=True)
+            if not self.run_tf_frozen:
                 # Must maintain ref to imported since concrete_func uses weak refs
                 # pylint: disable=unused-variable
-                graph_def, input_names, outputs, concrete_func, imported = loaded
+                graph_def, input_names, outputs, concrete_func, imported, initialized_tables = loaded
             else:
-                graph_def, input_names, outputs = loaded
+                graph_def, input_names, outputs, initialized_tables = loaded
         elif self.model_type in ["keras"]:
             graph_def, input_names, outputs = tf_loader.from_keras(model_path, input_names, outputs)
         else:
@@ -324,7 +356,7 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
         if utils.is_debug_mode():
             utils.save_protobuf(os.path.join(TEMP_DIR, name + "_after_tf_optimize.pb"), graph_def)
 
-        if self.large_model:
+        if not self.run_tf_frozen:
             inputs = {}
             for k in input_names:
                 v = self.input_names[k]
@@ -368,7 +400,10 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
                                np_value.dtype)
                 inputs[k] = np_value.astype(expected_dtype)
             else:
-                inputs[k] = self.make_input(v).astype(expected_dtype)
+                if expected_dtype == "string":
+                    inputs[k] = self.make_input(v).astype(np.str).astype(np.object)
+                else:
+                    inputs[k] = self.make_input(v).astype(expected_dtype)
 
         if self.force_input_shape:
             for k, v in inputs.items():
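The double cast for string inputs exists because onnxruntime expects string tensors as numpy object arrays holding Python str values rather than fixed-width unicode arrays. A short illustration using the plain str/object equivalents of the np.str/np.object aliases above:

    import numpy as np

    raw = np.array([["the quick brown fox", "a lazy dog"]])  # fixed-width dtype '<U19'
    fed = raw.astype(str).astype(object)                     # dtype 'O', the form fed to InferenceSession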
@@ -377,7 +412,7 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
         # run the model with tensorflow
         if self.skip_tensorflow:
             logger.info("TensorFlow SKIPPED")
-        elif not self.large_model:
+        elif self.run_tf_frozen:
             tf_results = self.run_tensorflow(sess, inputs)
             logger.info("TensorFlow OK")
 
@@ -395,7 +430,8 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
         # convert model to onnx
         onnx_graph = self.to_onnx(sess.graph, opset=opset, extra_opset=extra_opset,
                                   shape_override=shape_override, input_names=inputs.keys(),
-                                  const_node_values=const_node_values)
+                                  const_node_values=const_node_values,
+                                  initialized_tables=initialized_tables)
         onnx_graph = optimizer.optimize_graph(onnx_graph)
         print("ONNX", onnx_graph.dump_node_statistics())
         external_tensor_storage = ExternalTensorStorage() if self.large_model else None
@@ -559,7 +595,8 @@ def load_tests_from_yaml(path):
         kwargs = {}
         for kw in ["rtol", "atol", "disabled", "check_only_shape", "model_type", "concrete_function",
                    "skip_tensorflow", "force_input_shape", "tf_min_version", "tag", "skip_conversion",
-                   "converted_model", "signature_def", "large_model", "structured_outputs"]:
+                   "converted_model", "signature_def", "large_model", "structured_outputs", "run_tf_frozen",
+                   "use_custom_ops"]:
             if settings.get(kw) is not None:
                 kwargs[kw] = settings[kw]
 
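The two new keys are read straight from each YAML test entry. A hedged sketch of what such an entry looks like once parsed into the settings dict (only keys handled by the loop above are shown; the other fields of a real entry are omitted):

    settings = {
        "model_type": "saved_model",
        "large_model": True,      # on its own this already implies run_tf_frozen=False
        "run_tf_frozen": False,   # can also be set explicitly in the YAML file
        "use_custom_ops": True,   # register the onnxruntime-customops library before inference
    }
    kwargs = {kw: settings[kw] for kw in ("run_tf_frozen", "use_custom_ops") if settings.get(kw) is not None}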