Skip to content

Commit 6088587

Browse files
Merge pull request #1193 from onnx/tom/UpdateRunPretrainedScript
Updated run_pretrained_models script to support hash tables and string ops
2 parents e67f5bb + 8b9e041 commit 6088587

File tree

1 file changed

+52
-15
lines changed

1 file changed

+52
-15
lines changed

tests/run_pretrained_models.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919
import tempfile
2020
import time
2121
import zipfile
22+
import random
2223
from collections import namedtuple
2324
from distutils.version import LooseVersion
2425

26+
2527
import yaml
2628
import numpy as np
2729
import PIL.Image
@@ -38,7 +40,7 @@
3840
# not needed for tf-2.0
3941
pass
4042

41-
from tf2onnx import tf_loader, logging, optimizer, utils, tf_utils
43+
from tf2onnx import tf_loader, logging, optimizer, utils, tf_utils, constants
4244
from tf2onnx.tfonnx import process_tf_graph
4345
from tf2onnx.tf_loader import tf_session, tf_reset_default_graph
4446
from tf2onnx.graph import ExternalTensorStorage
@@ -62,11 +64,13 @@ def get_beach(shape):
6264

6365
def get_random(shape):
6466
"""Get random input."""
67+
np.random.seed(42)
6568
return np.random.sample(shape).astype(np.float32)
6669

6770

6871
def get_random256(shape):
6972
"""Get random input between 0 and 255."""
73+
np.random.seed(42)
7074
return np.round(np.random.sample(shape) * 256).astype(np.float32)
7175

7276

@@ -98,6 +102,7 @@ def get_ones_int32(shape):
98102

99103
def get_small_rand_int32(shape):
100104
"""Get random ints in range [1, 99]"""
105+
np.random.seed(42)
101106
return np.random.randint(low=1, high=100, size=shape, dtype=np.int32)
102107

103108
def get_zeros_then_ones(shape):
@@ -111,6 +116,15 @@ def get_wav(shape):
111116
"""Get sound data."""
112117
return np.sin(np.linspace(-np.pi, np.pi, shape[0]), dtype=np.float32)
113118

119+
def get_sentences(shape):
120+
"""Get sentences of shape"""
121+
words = "the quick brown fox jumps over a lazy dog".split(' ')
122+
random.seed(42)
123+
def get_sentence():
124+
length = random.randint(2, 7)
125+
return ' '.join(random.choice(words) for _ in range(length))
126+
return np.array([get_sentence() for _ in range(np.product(shape))]).reshape(shape)
127+
114128

115129
_INPUT_FUNC_MAPPING = {
116130
"get_beach": get_beach,
@@ -124,7 +138,8 @@ def get_wav(shape):
124138
"get_zeros_int64": get_zeros_int64,
125139
"get_ones_int32": get_ones_int32,
126140
"get_small_rand_int32": get_small_rand_int32,
127-
"get_zeros_then_ones": get_zeros_then_ones
141+
"get_zeros_then_ones": get_zeros_then_ones,
142+
"get_sentences": get_sentences,
128143
}
129144

130145

@@ -142,14 +157,18 @@ def __init__(self, url, local, input_func, input_names, output_names,
142157
check_only_shape=False, model_type="frozen", force_input_shape=False,
143158
skip_tensorflow=False, opset_constraints=None, tf_min_version=None, tag=None,
144159
skip_conversion=False, converted_model=None, signature_def=None, concrete_function=None,
145-
large_model=False, structured_outputs=None):
160+
large_model=False, structured_outputs=None, run_tf_frozen=None, use_custom_ops=False):
146161
self.url = url
147162
self.input_func = input_func
148163
self.local = local
149164
self.input_names = input_names
150165
self.output_names = output_names
151166
self.disabled = disabled
152167
self.large_model = large_model
168+
self.use_custom_ops = use_custom_ops
169+
if run_tf_frozen is None:
170+
run_tf_frozen = not self.large_model
171+
self.run_tf_frozen = run_tf_frozen
153172
self.structured_outputs = structured_outputs # Needed to determine output order for tf_function
154173
self.rtol = rtol
155174
self.atol = atol
@@ -242,12 +261,17 @@ def run_tensorflow(self, sess, inputs):
242261
return result
243262

244263
def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, input_names=None,
245-
const_node_values=None):
264+
const_node_values=None, initialized_tables=None):
246265
"""Convert TensorFlow graph to ONNX."""
266+
if extra_opset is None:
267+
extra_opset = []
268+
if self.use_custom_ops:
269+
extra_opset.append(utils.make_opsetid(constants.CONTRIB_OPS_DOMAIN, 1))
247270
return process_tf_graph(tf_graph, continue_on_error=False, opset=opset,
248271
extra_opset=extra_opset, target=Test.target, shape_override=shape_override,
249272
input_names=input_names, output_names=self.output_names,
250-
const_node_values=const_node_values)
273+
const_node_values=const_node_values,
274+
initialized_tables=initialized_tables)
251275

252276
def run_caffe2(self, name, model_proto, inputs):
253277
"""Run test against caffe2 backend."""
@@ -268,7 +292,13 @@ def run_onnxruntime(self, name, model_proto, inputs, external_tensor_storage=Non
268292
as_text=utils.is_debug_mode(),
269293
external_tensor_storage=external_tensor_storage)
270294
logger.info("Model saved to %s", model_path)
271-
m = rt.InferenceSession(model_path)
295+
if self.use_custom_ops:
296+
from ortcustomops import get_library_path
297+
opt = rt.SessionOptions()
298+
opt.register_custom_ops_library(get_library_path())
299+
m = rt.InferenceSession(model_path, opt)
300+
else:
301+
m = rt.InferenceSession(model_path)
272302
results = m.run(self.output_names, inputs)
273303
if self.perf:
274304
start = time.time()
@@ -303,19 +333,21 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
303333

304334
logger.info("Load model from %s", model_path)
305335
input_names = list(self.input_names.keys())
336+
initialized_tables = {}
306337
outputs = self.output_names
307338
if self.model_type in ["checkpoint"]:
308339
graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs)
309340
elif self.model_type in ["saved_model"]:
310341
loaded = tf_loader.from_saved_model(model_path, input_names, outputs, self.tag, self.signatures,
311342
self.concrete_function, self.large_model,
312-
return_concrete_func=self.large_model)
313-
if self.large_model:
343+
return_concrete_func=not self.run_tf_frozen,
344+
return_initialized_tables=True)
345+
if not self.run_tf_frozen:
314346
# Must maintain ref to imported since concrete_func uses weak refs
315347
# pylint: disable=unused-variable
316-
graph_def, input_names, outputs, concrete_func, imported = loaded
348+
graph_def, input_names, outputs, concrete_func, imported, initialized_tables = loaded
317349
else:
318-
graph_def, input_names, outputs = loaded
350+
graph_def, input_names, outputs, initialized_tables = loaded
319351
elif self.model_type in ["keras"]:
320352
graph_def, input_names, outputs = tf_loader.from_keras(model_path, input_names, outputs)
321353
else:
@@ -324,7 +356,7 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
324356
if utils.is_debug_mode():
325357
utils.save_protobuf(os.path.join(TEMP_DIR, name + "_after_tf_optimize.pb"), graph_def)
326358

327-
if self.large_model:
359+
if not self.run_tf_frozen:
328360
inputs = {}
329361
for k in input_names:
330362
v = self.input_names[k]
@@ -368,7 +400,10 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
368400
np_value.dtype)
369401
inputs[k] = np_value.astype(expected_dtype)
370402
else:
371-
inputs[k] = self.make_input(v).astype(expected_dtype)
403+
if expected_dtype == "string":
404+
inputs[k] = self.make_input(v).astype(np.str).astype(np.object)
405+
else:
406+
inputs[k] = self.make_input(v).astype(expected_dtype)
372407

373408
if self.force_input_shape:
374409
for k, v in inputs.items():
@@ -377,7 +412,7 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
377412
# run the model with tensorflow
378413
if self.skip_tensorflow:
379414
logger.info("TensorFlow SKIPPED")
380-
elif not self.large_model:
415+
elif self.run_tf_frozen:
381416
tf_results = self.run_tensorflow(sess, inputs)
382417
logger.info("TensorFlow OK")
383418

@@ -395,7 +430,8 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
395430
# convert model to onnx
396431
onnx_graph = self.to_onnx(sess.graph, opset=opset, extra_opset=extra_opset,
397432
shape_override=shape_override, input_names=inputs.keys(),
398-
const_node_values=const_node_values)
433+
const_node_values=const_node_values,
434+
initialized_tables=initialized_tables)
399435
onnx_graph = optimizer.optimize_graph(onnx_graph)
400436
print("ONNX", onnx_graph.dump_node_statistics())
401437
external_tensor_storage = ExternalTensorStorage() if self.large_model else None
@@ -559,7 +595,8 @@ def load_tests_from_yaml(path):
559595
kwargs = {}
560596
for kw in ["rtol", "atol", "disabled", "check_only_shape", "model_type", "concrete_function",
561597
"skip_tensorflow", "force_input_shape", "tf_min_version", "tag", "skip_conversion",
562-
"converted_model", "signature_def", "large_model", "structured_outputs"]:
598+
"converted_model", "signature_def", "large_model", "structured_outputs", "run_tf_frozen",
599+
"use_custom_ops"]:
563600
if settings.get(kw) is not None:
564601
kwargs[kw] = settings[kw]
565602

0 commit comments

Comments
 (0)