148 changes: 112 additions & 36 deletions tests/run_pretrained_models.py
@@ -41,6 +41,7 @@
from tf2onnx import tf_loader, logging, optimizer, utils, tf_utils
from tf2onnx.tfonnx import process_tf_graph
from tf2onnx.tf_loader import tf_session, tf_reset_default_graph
from tf2onnx.graph import ExternalTensorStorage

logger = logging.getLogger("run_pretrained")

@@ -102,16 +103,20 @@ class Test(object):
cache_dir = None
target = []

def __init__(self, url, local, make_input, input_names, output_names,
def __init__(self, url, local, input_func, input_names, output_names,
disabled=False, rtol=0.01, atol=1e-6,
check_only_shape=False, model_type="frozen", force_input_shape=False,
skip_tensorflow=False, opset_constraints=None, tf_min_version=None, tag=None):
skip_tensorflow=False, opset_constraints=None, tf_min_version=None, tag=None,
skip_conversion=False, converted_model=None, signature_def=None, concrete_function=None,
large_model=False, structured_outputs=None):
self.url = url
self.make_input = make_input
self.input_func = input_func
self.local = local
self.input_names = input_names
self.output_names = output_names
self.disabled = disabled
self.large_model = large_model
self.structured_outputs = structured_outputs # Needed to determine output order for tf_function
self.rtol = rtol
self.atol = atol
self.check_only_shape = check_only_shape
@@ -122,8 +127,18 @@ def __init__(self, url, local, make_input, input_names, output_names,
self.tag = tag
self.force_input_shape = force_input_shape
self.skip_tensorflow = skip_tensorflow
self.skip_conversion = skip_conversion
self.converted_model = converted_model
self.opset_constraints = opset_constraints
self.tf_min_version = tf_min_version
self.signatures = [signature_def] if signature_def else None
self.concrete_function = concrete_function

def make_input(self, v):
"""Allows each input to specify its own function while defaulting to the input_get function"""
if isinstance(v, dict):
return _INPUT_FUNC_MAPPING[v["input_get"]](v["shape"])
return self.input_func(v)

def download_model(self):
"""Download model from url."""
@@ -149,7 +164,7 @@ def download_model(self):
if not os.path.exists(fpath):
utils.get_url(url, fpath)
model_path = os.path.join(dir_name, self.local)
if not os.path.exists(model_path):
if not os.path.exists(model_path) or self.local == ".":
if ftype == 'tgz':
tar = tarfile.open(fpath)
tar.extractall(dir_name)
@@ -179,19 +194,23 @@ def run_tensorflow(self, sess, inputs):
for k, v in inputs.items():
k = sess.graph.get_tensor_by_name(k)
feed_dict[k] = v
logger.info("Running TF")
result = sess.run(self.output_names, feed_dict=feed_dict)
if self.perf:
logger.info("Running TF perf")
start = time.time()
for _ in range(PERFITER):
_ = sess.run(self.output_names, feed_dict=feed_dict)
self.tf_runtime = time.time() - start
return result

def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, input_names=None):
def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, input_names=None,
const_node_values=None):
"""Convert graph to tensorflow."""
return process_tf_graph(tf_graph, continue_on_error=False, opset=opset,
extra_opset=extra_opset, target=Test.target, shape_override=shape_override,
input_names=input_names, output_names=self.output_names)
input_names=input_names, output_names=self.output_names,
const_node_values=const_node_values)

def run_caffe2(self, name, model_proto, inputs):
"""Run test again caffe2 backend."""
@@ -205,11 +224,12 @@ def run_caffe2(self, name, model_proto, inputs):
self.onnx_runtime = time.time() - start
return results

def run_onnxruntime(self, name, model_proto, inputs):
def run_onnxruntime(self, name, model_proto, inputs, external_tensor_storage=None):
"""Run test against onnxruntime backend."""
import onnxruntime as rt
model_path = utils.save_onnx_model(TEMP_DIR, name, inputs, model_proto, include_test_data=True,
as_text=utils.is_debug_mode())
as_text=utils.is_debug_mode(),
external_tensor_storage=external_tensor_storage)
logger.info("Model saved to %s", model_path)
m = rt.InferenceSession(model_path)
results = m.run(self.output_names, inputs)
@@ -221,10 +241,14 @@ def run_onnxruntime(self, name, model_proto, inputs):
return results

@staticmethod
def create_onnx_file(name, model_proto, inputs, outdir):
def create_onnx_file(name, model_proto, inputs, outdir, external_tensor_storage=None):
os.makedirs(outdir, exist_ok=True)
model_path = os.path.join(outdir, name + ".onnx")
utils.save_protobuf(model_path, model_proto)
if external_tensor_storage is None:
model_path = os.path.join(outdir, name + ".onnx")
utils.save_protobuf(model_path, model_proto)
else:
model_path = os.path.join(outdir, name + ".zip")
utils.save_onnx_zip(model_path, model_proto, external_tensor_storage)
logger.info("Created %s", model_path)

def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_opset=None,
@@ -236,7 +260,7 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
if self.url:
_, dir_name = self.download_model()
logger.info("Downloaded to %s", dir_name)
model_path = os.path.join(dir_name, self.local)
model_path = os.path.join(dir_name, self.local) if self.local != "." else dir_name
else:
model_path = self.local

@@ -246,13 +270,15 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
if self.model_type in ["checkpoint"]:
graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs)
elif self.model_type in ["saved_model"]:
try:
res = tf_loader.from_saved_model(model_path, input_names, outputs, self.tag)
except OSError:
model_path = dir_name
logger.info("Load model(2) from %r", model_path)
res = tf_loader.from_saved_model(model_path, input_names, outputs, self.tag)
graph_def, input_names, outputs = res[:3]
loaded = tf_loader.from_saved_model(model_path, input_names, outputs, self.tag, self.signatures,
self.concrete_function, self.large_model,
return_concrete_func=self.large_model)
if self.large_model:
# Must maintain ref to imported since concrete_func uses weak refs
# pylint: disable=unused-variable
graph_def, input_names, outputs, concrete_func, imported = loaded
else:
graph_def, input_names, outputs = loaded
elif self.model_type in ["keras"]:
graph_def, input_names, outputs = tf_loader.from_keras(model_path, input_names, outputs)
else:
@@ -261,9 +287,34 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
if utils.is_debug_mode():
utils.save_protobuf(os.path.join(TEMP_DIR, name + "_after_tf_optimize.pb"), graph_def)

if self.large_model:
inputs = {}
for k in input_names:
v = self.input_names[k]
inputs[k.split(":")[0]] = tf.constant(self.make_input(v))
tf_func = tf.function(concrete_func)
logger.info("Running TF")
tf_results_d = tf_func(**inputs)
if self.structured_outputs is None:
tf_results = list(tf_results_d.values())
else:
tf_results = [tf_results_d[output] for output in self.structured_outputs]
if self.perf:
logger.info("Running TF perf")
start = time.time()
for _ in range(PERFITER):
_ = concrete_func(**inputs)
self.tf_runtime = time.time() - start
logger.info("TensorFlow OK")

inputs = {}
shape_override = {}
tf_reset_default_graph()

from tf2onnx.tf_utils import compress_graph_def
const_node_values = None
if self.large_model:
const_node_values = compress_graph_def(graph_def)
g = tf.import_graph_def(graph_def, name='')
# with tf_session(config=tf.ConfigProto(allow_soft_placement=True), graph=g) as sess:
with tf_session(graph=g) as sess:
Expand All @@ -288,30 +339,50 @@ def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_ops
# run the model with tensorflow
if self.skip_tensorflow:
logger.info("TensorFlow SKIPPED")
else:
elif not self.large_model:
tf_results = self.run_tensorflow(sess, inputs)
logger.info("TensorFlow OK")

model_proto = None
try:
# convert model to onnx
onnx_graph = self.to_onnx(sess.graph, opset=opset, extra_opset=extra_opset,
shape_override=shape_override, input_names=inputs.keys())
onnx_graph = optimizer.optimize_graph(onnx_graph)
model_proto = onnx_graph.make_model("converted from tf2onnx")
logger.info("To_ONNX, OK")
if onnx_file:
self.create_onnx_file(name, model_proto, inputs, onnx_file)
except Exception:
logger.error("To_ONNX FAIL", exc_info=1)
return False
if self.skip_conversion:
if self.large_model:
external_tensor_storage = ExternalTensorStorage()
model_proto = utils.model_proto_from_zip(self.converted_model, external_tensor_storage)
else:
external_tensor_storage = None
model_proto = utils.model_proto_from_file(self.converted_model)
logger.info("ONNX loaded from file")
else:
try:
# convert model to onnx
onnx_graph = self.to_onnx(sess.graph, opset=opset, extra_opset=extra_opset,
shape_override=shape_override, input_names=inputs.keys(),
const_node_values=const_node_values)
onnx_graph = optimizer.optimize_graph(onnx_graph)
print("ONNX", onnx_graph.dump_node_statistics())
external_tensor_storage = ExternalTensorStorage() if self.large_model else None
model_proto = onnx_graph.make_model("converted from tf2onnx",
external_tensor_storage=external_tensor_storage)
logger.info("To_ONNX, OK")
if onnx_file:
self.create_onnx_file(name, model_proto, inputs, onnx_file, external_tensor_storage)
if self.converted_model:
if self.large_model:
utils.save_onnx_zip(self.converted_model, model_proto, external_tensor_storage)
else:
utils.save_protobuf(self.converted_model, model_proto)
logger.info("Created %s", self.converted_model)

except Exception:
logger.error("To_ONNX FAIL", exc_info=1)
return False

try:
onnx_results = None
if backend == "caffe2":
onnx_results = self.run_caffe2(name, model_proto, inputs)
elif backend == "onnxruntime":
onnx_results = self.run_onnxruntime(name, model_proto, inputs)
onnx_results = self.run_onnxruntime(name, model_proto, inputs, external_tensor_storage)
else:
raise ValueError("unknown backend")
logger.info("Run_ONNX OK")
@@ -390,6 +461,7 @@ def get_args():
parser.add_argument("--list", help="list tests", action="store_true")
parser.add_argument("--onnx-file", help="create onnx file in directory")
parser.add_argument("--perf", help="capture performance numbers")
parser.add_argument("--perfiter", type=int, default=PERFITER, help="number of inferences for perf testing")
parser.add_argument("--fold_const", help="enable tf constant_folding transformation before conversion",
action="store_true")
parser.add_argument("--include-disabled", help="include disabled tests", action="store_true")
@@ -447,8 +519,9 @@ def load_tests_from_yaml(path):
opset_constraints.append(c)

kwargs = {}
for kw in ["rtol", "atol", "disabled", "check_only_shape", "model_type",
"skip_tensorflow", "force_input_shape", "tf_min_version", "tag"]:
for kw in ["rtol", "atol", "disabled", "check_only_shape", "model_type", "concrete_function",
"skip_tensorflow", "force_input_shape", "tf_min_version", "tag", "skip_conversion",
"converted_model", "signature_def", "large_model", "structured_outputs"]:
if settings.get(kw) is not None:
kwargs[kw] = settings[kw]

@@ -459,6 +532,7 @@


def main():
global PERFITER
args = get_args()
logging.basicConfig(level=logging.get_verbosity_level(args.verbose))
if args.debug:
@@ -477,6 +551,7 @@

failed = 0
count = 0
PERFITER = args.perfiter
for test in test_keys:
logger.info("===================================")

@@ -520,7 +595,8 @@ def main():
for test in test_keys:
t = tests[test]
if t.perf:
f.write("{},{},{}\n".format(test, t.tf_runtime, t.onnx_runtime))
# Report perf in ms per inference
f.write("{},{},{}\n".format(test, t.tf_runtime * 1000 / PERFITER, t.onnx_runtime * 1000 / PERFITER))
return failed


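Taken together, the large-model pieces added above compose into one flow: compress the big constants out of the GraphDef, convert with const_node_values, and emit the weights into an ExternalTensorStorage zipped next to the .onnx file. A minimal sketch of that flow follows; the saved_model path, opset, and shapes are assumptions, while the tf2onnx names are the ones used in this diff:

import tensorflow as tf
from tf2onnx import optimizer, tf_loader, utils
from tf2onnx.graph import ExternalTensorStorage
from tf2onnx.tf_loader import tf_session, tf_reset_default_graph
from tf2onnx.tf_utils import compress_graph_def
from tf2onnx.tfonnx import process_tf_graph

# Keep a reference to `imported`: concrete_func only holds weak refs into it.
graph_def, inputs, outputs, concrete_func, imported = tf_loader.from_saved_model(
    "path/to/saved_model", None, None, large_model=True, return_concrete_func=True)

tf_reset_default_graph()
const_node_values = compress_graph_def(graph_def)  # pull large constants out of the GraphDef
g = tf.import_graph_def(graph_def, name='')
with tf_session(graph=g) as sess:
    onnx_graph = process_tf_graph(sess.graph, opset=12, input_names=inputs,
                                  output_names=outputs,
                                  const_node_values=const_node_values)

onnx_graph = optimizer.optimize_graph(onnx_graph)
storage = ExternalTensorStorage()
model_proto = onnx_graph.make_model("converted from tf2onnx",
                                    external_tensor_storage=storage)
utils.save_onnx_zip("model.zip", model_proto, storage)  # one .onnx entry plus raw tensors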
2 changes: 1 addition & 1 deletion tests/run_pretrained_models.yaml
@@ -119,7 +119,7 @@ benchtf-gru:
esrgan-tf2:
# url: https://tfhub.dev/captain-pool/esrgan-tf2/1/esrgan-tf2_1.tar.gz
url: https://github.com/captain-pool/GSOC/releases/download/1.0.0/esrgan.tar.gz
model: ersgan
model: "."
model_type: saved_model
input_get: get_beach
opset_constraints:
9 changes: 6 additions & 3 deletions tf2onnx/tf_loader.py
@@ -316,18 +316,21 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
raise ValueError(err_large_model)
raise e

return frozen_graph, inputs, outputs
return frozen_graph, inputs, outputs, concrete_func, imported


def from_saved_model(model_path, input_names, output_names, tag=None,
signatures=None, concrete_function=None, large_model=False):
signatures=None, concrete_function=None, large_model=False, return_concrete_func=False):
"""Load tensorflow graph from saved_model."""
if signatures is None:
signatures = []
tf_reset_default_graph()
if is_tf2():
frozen_graph, input_names, output_names = \
frozen_graph, input_names, output_names, concrete_func, imported = \
_from_saved_model_v2(model_path, input_names, output_names, tag, signatures, concrete_function, large_model)
if return_concrete_func:
tf_reset_default_graph()
return frozen_graph, input_names, output_names, concrete_func, imported
else:
with tf_session() as sess:
frozen_graph, input_names, output_names = \
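For callers, the practical consequence of return_concrete_func is the five-value return: concrete_func only holds weak references into imported, so imported must stay alive while the function is used, mirroring how run_test drives it above. A small sketch, where the saved_model path and the input shape are assumptions:

import tensorflow as tf
from tf2onnx import tf_loader

frozen_graph, input_names, output_names, concrete_func, imported = \
    tf_loader.from_saved_model("path/to/saved_model", None, None,
                               large_model=True, return_concrete_func=True)

# `imported` must stay referenced for as long as concrete_func is called.
feed = {name.split(":")[0]: tf.zeros([1, 224, 224, 3])  # assumed input shape
        for name in input_names}
results = tf.function(concrete_func)(**feed)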
18 changes: 17 additions & 1 deletion tf2onnx/utils.py
@@ -20,7 +20,7 @@
from urllib3.util.retry import Retry
import numpy as np
from google.protobuf import text_format
from onnx import helper, onnx_pb, defs, numpy_helper, __version__
from onnx import helper, onnx_pb, defs, numpy_helper, ModelProto, __version__

from . import constants

@@ -269,6 +269,22 @@ def save_protobuf(path, message, as_text=False):
with open(path, "wb") as f:
f.write(message.SerializeToString())

def model_proto_from_file(model_path):
model_proto = ModelProto()
with open(model_path, "rb") as f:
model_proto.ParseFromString(f.read())
return model_proto

def model_proto_from_zip(zip_path, external_tensor_storage):
model_proto = ModelProto()
with zipfile.ZipFile(zip_path, 'r') as z:
for n in z.namelist():
f = z.open(n)
if n.endswith(".onnx"):
model_proto.ParseFromString(f.read())
else:
external_tensor_storage.name_to_tensor_data[n] = f.read()
return model_proto

def is_list_or_tuple(obj):
return isinstance(obj, (list, tuple))
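These helpers round-trip the archive written by save_onnx_zip: one .onnx entry plus one entry per external tensor, the latter collected into ExternalTensorStorage.name_to_tensor_data. A minimal sketch, assuming a placeholder zip produced earlier by utils.save_onnx_zip:

from tf2onnx import utils
from tf2onnx.graph import ExternalTensorStorage

storage = ExternalTensorStorage()
model_proto = utils.model_proto_from_zip("model.zip", storage)  # placeholder path
# Every non-.onnx archive entry is now raw tensor bytes keyed by entry name.
for tensor_name, raw_bytes in storage.name_to_tensor_data.items():
    print(tensor_name, len(raw_bytes), "bytes")
print(model_proto.graph.name)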