From 684fb67d87bc9c711bf3ec74ac3f279d8e0e6ca5 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Wed, 26 May 2021 10:23:05 -0700 Subject: [PATCH 1/2] add script to run optimizer on onnx files Signed-off-by: Guenther Schmuelling --- tools/onnx-optimize.py | 69 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 tools/onnx-optimize.py diff --git a/tools/onnx-optimize.py b/tools/onnx-optimize.py new file mode 100644 index 000000000..8a68dd9ff --- /dev/null +++ b/tools/onnx-optimize.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 + + +""" +A simple tool to try optimizations on onnx graphs. +This makes use of the fact that tensorflow-onnx internal graph representation is onnx +so all graph, rewrite, matching and utility libaries do work which makes things easy. +""" + +# pylint: disable=invalid-name,missing-docstring, unused-argument + +import argparse +import logging +import traceback + +import numpy as np +import onnx +from onnx import helper + +from tf2onnx.graph import GraphUtil +from tf2onnx import logging, optimizer, utils + + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("onnx-optimize") + + +def get_args(): + """Parse commandline.""" + parser = argparse.ArgumentParser() + parser.add_argument("--input", required=True, help="onnx input model file") + parser.add_argument("--output", help="output model file") + args = parser.parse_args() + return args + + +def load_graph(fname): + model_proto = onnx.ModelProto() + with open(fname, "rb") as f: + data = f.read() + model_proto.ParseFromString(data) + g = GraphUtil.create_graph_from_onnx_model(model_proto) + return g, model_proto + + +def main(): + args = get_args() + + g, org_model_proto = load_graph(args.input) + + g = optimizer.optimize_graph(g) + + onnx_graph = g.make_graph(org_model_proto.graph.doc_string + " (+tf2onnx/onnx-optimize)") + + kwargs = {"producer_name": org_model_proto.producer_name, + "producer_version": org_model_proto.producer_version, + "opset_imports": org_model_proto.opset_import, + "ir_version": org_model_proto.ir_version} + + model_proto = helper.make_model(onnx_graph, **kwargs) + + # write onnx graph + if args.output: + with open(args.output, "wb") as f: + f.write(model_proto.SerializeToString()) + + +if __name__ == "__main__": + main() From 3a090f7579f522898b74591f952b7ffd53eca75e Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Wed, 26 May 2021 13:37:09 -0700 Subject: [PATCH 2/2] pylint Signed-off-by: Guenther Schmuelling --- tools/onnx-optimize.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tools/onnx-optimize.py b/tools/onnx-optimize.py index 8a68dd9ff..1fd5ff092 100644 --- a/tools/onnx-optimize.py +++ b/tools/onnx-optimize.py @@ -11,14 +11,12 @@ import argparse import logging -import traceback -import numpy as np import onnx from onnx import helper from tf2onnx.graph import GraphUtil -from tf2onnx import logging, optimizer, utils +from tf2onnx import logging, optimizer logging.basicConfig(level=logging.INFO)