
Commit 143db68

Add target for converting to channels last format
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent cd64e4e commit 143db68

7 files changed: +165 -14 lines changed

tf2onnx/constants.py

Lines changed: 3 additions & 1 deletion

@@ -29,7 +29,9 @@
 TARGET_RS6 = "rs6"
 TARGET_CAFFE2 = "caffe2"
 TARGET_TENSORRT = "tensorrt"
-POSSIBLE_TARGETS = [TARGET_RS4, TARGET_RS5, TARGET_RS6, TARGET_CAFFE2, TARGET_TENSORRT]
+TARGET_CHANNELS_LAST = "nhwc"
+TARGET_CHANNELS_FIRST = "nchw"
+POSSIBLE_TARGETS = [TARGET_RS4, TARGET_RS5, TARGET_RS6, TARGET_CAFFE2, TARGET_TENSORRT, TARGET_CHANNELS_LAST]
 DEFAULT_TARGET = []

 NCHW_TO_NHWC = [0, 2, 3, 1]
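For context, these layout constants are plain axis permutations used both as Transpose `perm` attributes and to permute shapes. A minimal illustration (NHWC_TO_NCHW is restated from constants.py so the snippet is self-contained; the shape values are made up):

    NCHW_TO_NHWC = [0, 2, 3, 1]
    NHWC_TO_NCHW = [0, 3, 1, 2]

    def apply_perm(shape, perm):
        # new_shape[i] = shape[perm[i]], mirroring how a Transpose with `perm` reorders axes
        return [shape[p] for p in perm]

    print(apply_perm([1, 3, 224, 224], NCHW_TO_NHWC))  # [1, 224, 224, 3]  (channels last)
    print(apply_perm([1, 224, 224, 3], NHWC_TO_NCHW))  # [1, 3, 224, 224]  (channels first)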

tf2onnx/graph.py

Lines changed: 26 additions & 6 deletions

@@ -1124,13 +1124,29 @@ def make_graph(self, doc, graph_name=None, external_tensor_storage=None):
         # create output_tensor_values
         output_tensor_values = self.make_onnx_graph_io(self.outputs)

+        tensor_value_info = []
+
+        for op in ops:
+            if op.domain in [constants.ONNX_DOMAIN, constants.AI_ONNX_ML_DOMAIN]:
+                continue
+            # We still don't 100% trust the accuracy of all the shapes in graph.py, but for custom ops they are
+            # almost certainly accurate and onnx has no other way of knowing them.
+            for out in op.output:
+                if out == '' or out in self.outputs:
+                    continue
+                dtype = self.get_dtype(out)
+                shape = self.get_shape(out)
+                v = utils.make_onnx_inputs_outputs(out, dtype, shape)
+                tensor_value_info.append(v)
+
         # create graph proto
         graph = helper.make_graph([op.op for op in ops],
                                   graph_name,
                                   input_tensor_values,
                                   output_tensor_values,
                                   initializer=initializers,
-                                  doc_string=doc)
+                                  doc_string=doc,
+                                  value_info=tensor_value_info)

         return graph
@@ -1628,10 +1644,11 @@ def get_onnx_model_properties(onnx_model_proto):
         return kwargs

     @staticmethod
-    def create_graph_from_onnx_model(onnx_model_proto):
+    def create_graph_from_onnx_model(onnx_model_proto, target=None):
         """Create Graph loading onnx model proto."""
         # apply shape inference on the model
         inferred_model = shape_inference.infer_shapes(onnx_model_proto)
+        utils.initialize_name_counter(inferred_model)
         graph_proto = inferred_model.graph

         opset_version = None
@@ -1644,11 +1661,11 @@ def create_graph_from_onnx_model(onnx_model_proto):
             extra_opset.append(opset)

         utils.make_sure(opset_version is not None, "opset version is not specified for onnx domain")
-        main_graph = GraphUtil.create_graph_from_onnx_graph(graph_proto, opset_version, extra_opset)
+        main_graph = GraphUtil.create_graph_from_onnx_graph(graph_proto, opset_version, extra_opset, target)
         return main_graph

     @staticmethod
-    def create_graph_from_onnx_graph(graph_proto, opset_version=None, extra_opset=None):
+    def create_graph_from_onnx_graph(graph_proto, opset_version=None, extra_opset=None, target=None):
         """Create Graph loading onnx graph proto."""
         output_shapes = {}
         output_dtypes = {}
@@ -1675,7 +1692,7 @@ def create_graph_from_onnx_graph(graph_proto, opset_version=None, extra_opset=No
         for n in graph_proto.output:
             output_names.append(n.name)

-        g = Graph(nodes_to_append, output_shapes, output_dtypes, None, opset_version, extra_opset, None, output_names)
+        g = Graph(nodes_to_append, output_shapes, output_dtypes, target, opset_version, extra_opset, None, output_names)
         const_nodes = GraphUtil._parse_graph_initializer(g, graph_proto)
         GraphUtil._parse_graph_input(g, graph_proto, [n.name for n in const_nodes])
@@ -1702,6 +1719,10 @@ def _parse_shape_and_type_from_value_infos(value_infos):
         for shape_info in value_infos:
             type_proto = shape_info.type
             elem_type = type_proto.tensor_type.elem_type
+            output_dtypes[shape_info.name] = elem_type
+            if not type_proto.tensor_type.HasField("shape"):
+                output_shapes[shape_info.name] = None
+                continue
             shape = type_proto.tensor_type.shape
             tuned_shape = []
             for d in shape.dim:
@@ -1713,7 +1734,6 @@ def _parse_shape_and_type_from_value_infos(value_infos):
                 # it is found, some unknown dims is missing after inference.
                 tuned_shape.append(-1)
             output_shapes[shape_info.name] = tuned_shape
-            output_dtypes[shape_info.name] = elem_type

         return output_shapes, output_dtypes
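For context, each entry passed via `value_info` is an ONNX ValueInfoProto recording an intermediate tensor's dtype and shape; `utils.make_onnx_inputs_outputs` is tf2onnx's helper for building one. A minimal sketch using only the public onnx helpers (the tensor name and shape below are invented):

    from onnx import TensorProto, helper

    # Hypothetical intermediate output of a custom (non-ONNX-domain) op.
    # Recording its dtype/shape as value_info exposes type information that
    # ONNX shape inference cannot derive for an unknown custom op.
    inter = helper.make_tensor_value_info("custom_op:0", TensorProto.FLOAT, [1, 224, 224, 3])

    graph = helper.make_graph(
        nodes=[],              # nodes, inputs, outputs elided for brevity
        name="example",
        inputs=[],
        outputs=[],
        initializer=[],
        value_info=[inter],    # same keyword the commit passes to make_graph
    )
    print(graph.value_info[0].name)  # -> custom_op:0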

tf2onnx/late_rewriters/channel_order_rewriters.py

Lines changed: 83 additions & 0 deletions

@@ -0,0 +1,83 @@
+# SPDX-License-Identifier: Apache-2.0
+
+
+"""
+TODO: fill this
+"""
+from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
+from tf2onnx import utils, constants
+
+_CHANNELS_FIRST_OPS = [
+    "AveragePool",
+    "BatchNormalization",
+    "Conv",
+    "ConvInteger",
+    "ConvTranspose",
+    "GlobalAveragePool",
+    "GlobalLpPool",
+    "GlobalMaxPool",
+    "InstanceNormalization",
+    "LpPool",
+    "LRN",
+    "MaxPool",
+    "MaxRoiPool",
+    "MaxUnpool",
+    "QLinearConv",
+]
+
+def channel_last_to_first_perm(rank):
+    return [0, rank - 1] + list(range(1, rank - 1))
+
+def channel_first_to_last_perm(rank):
+    return [0] + list(range(2, rank)) + [1]
+
+def _to_channel_last_handler(g, op):
+    # For now, all ops can use the same handlers (input[0] and output[0] are always correct)
+    rank = g.get_rank(op.output[0])
+    utils.make_sure(rank is not None, "Cannot convert %s node %s with unknown rank to channels last", op.type, op.name)
+    op.type = "ChannelsLast" + op.type
+    op.domain = constants.CONTRIB_OPS_DOMAIN
+    inp_perm = channel_first_to_last_perm(rank)
+    out_perm = channel_last_to_first_perm(rank)
+    output_shape = g.get_shape(op.output[0])
+    if output_shape is not None:
+        output_shape = [output_shape[i] for i in inp_perm]
+        g.set_shape(op.output[0], output_shape)
+
+    g.insert_new_node_on_input(op, "Transpose", op.input[0], input_index=0, perm=inp_perm)
+    g.insert_new_node_on_output("Transpose", op.output[0], perm=out_perm)
+
+def _to_channel_first_handler(g, op):
+    rank = g.get_rank(op.output[0])
+    utils.make_sure(rank is not None, "Cannot convert %s node %s with unknown rank to channels last", op.type, op.name)
+    op.type = op.type.replace("ChannelsLast", "")
+    op.domain = constants.ONNX_DOMAIN
+    inp_perm = channel_last_to_first_perm(rank)
+    out_perm = channel_first_to_last_perm(rank)
+    output_shape = g.get_shape(op.output[0])
+    if output_shape is not None:
+        output_shape = [output_shape[i] for i in inp_perm]
+        g.set_shape(op.output[0], output_shape)
+
+    g.insert_new_node_on_input(op, "Transpose", op.input[0], input_index=0, perm=inp_perm)
+    g.insert_new_node_on_output("Transpose", op.output[0], perm=out_perm)
+
+def get_channels_first_ops(opset=None):
+    # opset doesn't matter for now
+    return set(_CHANNELS_FIRST_OPS)
+
+
+# pylint: disable=missing-docstring
+
+def rewrite_channels_last(g, ops):
+    channel_first_ops = get_channels_first_ops(g.opset)
+    for op in ops:
+        if op.type in channel_first_ops:
+            _to_channel_last_handler(g, op)
+    return g.get_nodes()
+
+def rewrite_channels_first(g, ops):
+    for op in ops:
+        if op.type.startswith("ChannelsLast"):
+            _to_channel_first_handler(g, op)
+    return g.get_nodes()
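For reference, the two permutation helpers generalize the 4-D and 5-D constants in constants.py to any rank. A quick standalone check (the expected values below are computed by hand, not taken from the commit):

    def channel_last_to_first_perm(rank):
        return [0, rank - 1] + list(range(1, rank - 1))

    def channel_first_to_last_perm(rank):
        return [0] + list(range(2, rank)) + [1]

    print(channel_last_to_first_perm(4))   # [0, 3, 1, 2]     == NHWC_TO_NCHW
    print(channel_first_to_last_perm(4))   # [0, 2, 3, 1]     == NCHW_TO_NHWC
    print(channel_last_to_first_perm(5))   # [0, 4, 1, 2, 3]  == NDHWC_TO_NCDHW
    print(channel_first_to_last_perm(5))   # [0, 2, 3, 4, 1]  == NCDHW_TO_NDHWC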

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 8 additions & 3 deletions

@@ -7,7 +7,7 @@

 import numpy as np
 import onnx
-from tf2onnx.constants import NCHW_TO_NHWC, NHWC_TO_NCHW, NCDHW_TO_NDHWC, NDHWC_TO_NCDHW
+from tf2onnx.constants import NCHW_TO_NHWC, NHWC_TO_NCHW, NCDHW_TO_NDHWC, NDHWC_TO_NCDHW, TARGET_CHANNELS_LAST
 from .. import utils
 from .optimizer_base import GraphOptimizerBase
@@ -362,14 +362,19 @@ def _should_push_transpose(self, trans, node):
         perm = trans.get_attr_value("perm")
         optimization_gains = 0
         removed_nchws = 0
+        perm_to_push_down = [NCHW_TO_NHWC, NCDHW_TO_NDHWC]
+        perm_to_push_up = [NHWC_TO_NCHW, NDHWC_TO_NCDHW]
+        if self._g.is_target(TARGET_CHANNELS_LAST):
+            perm_to_push_down, perm_to_push_up = perm_to_push_up, perm_to_push_down
+
         for n, inp_id in zip(node.inputs, node.input):
             if is_tranpose_of_type(n, perm):
                 optimization_gains += self._cost_to_transpose(n.inputs[0], n.input[0])
-                if perm in [NCHW_TO_NHWC, NCDHW_TO_NDHWC]:
+                if perm in perm_to_push_down:
                     removed_nchws += 1
             else:
                 optimization_gains -= self._cost_to_transpose(n, inp_id)
-                if perm in [NHWC_TO_NCHW, NDHWC_TO_NCDHW]:
+                if perm in perm_to_push_up:
                     removed_nchws -= 1
         if removed_nchws != 0:
             # Always push nchw transposes if possible
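In short: by default the optimizer treats an NCHW-to-NHWC transpose as the kind worth pushing through and removing; with the nhwc target that preference flips, so NHWC-to-NCHW transposes become the ones to eliminate. A standalone sketch of just that selection step (the perm constants are restated here and the function name is illustrative):

    NCHW_TO_NHWC = [0, 2, 3, 1]
    NHWC_TO_NCHW = [0, 3, 1, 2]
    NCDHW_TO_NDHWC = [0, 2, 3, 4, 1]
    NDHWC_TO_NCDHW = [0, 4, 1, 2, 3]

    def pick_push_direction(channels_last_target):
        # Which perms the optimizer prefers to push (and thereby remove),
        # and which it tries to avoid introducing, given the target layout.
        perm_to_push_down = [NCHW_TO_NHWC, NCDHW_TO_NDHWC]
        perm_to_push_up = [NHWC_TO_NCHW, NDHWC_TO_NCDHW]
        if channels_last_target:
            perm_to_push_down, perm_to_push_up = perm_to_push_up, perm_to_push_down
        return perm_to_push_down, perm_to_push_up

    down, _ = pick_push_direction(channels_last_target=True)
    print(down)  # [[0, 3, 1, 2], [0, 4, 1, 2, 3]] -- NHWC->NCHW perms now count as removable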

tf2onnx/tfonnx.py

Lines changed: 3 additions & 0 deletions

@@ -19,6 +19,7 @@
 from tf2onnx.graph import Graph
 from tf2onnx.rewriter import *  # pylint: disable=wildcard-import
 from tf2onnx.tflite_rewriters import *  # pylint: disable=wildcard-import
+from tf2onnx.late_rewriters.channel_order_rewriters import rewrite_channels_last
 from tf2onnx.shape_inference import infer_shape
 from tf2onnx.tf_loader import is_function, resolve_functions, set_function
 from tf2onnx.tf_utils import tensorflow_to_onnx, get_tf_version, compute_const_folding_using_tf
@@ -640,6 +641,8 @@ def compat_handler(ctx, node, **kwargs):
         late_rewriters.append(rewrite_incomplete_type_support_rs5)
     if constants.TARGET_RS6 in target:
         late_rewriters.append(rewrite_incomplete_type_support_rs6)
+    if constants.TARGET_CHANNELS_LAST in target:
+        late_rewriters.append(rewrite_channels_last)
     if late_rewriters:
         run_rewriters(g, late_rewriters, continue_on_error)
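A hypothetical usage sketch of how a conversion would request this rewriter: `process_tf_graph` and its `target`, `input_names`, and `output_names` arguments come from tf2onnx's existing API as I understand it (not from this diff), and "nhwc" (constants.TARGET_CHANNELS_LAST) is the value this commit adds to the allowed targets. The tiny graph exists only to show where the target is passed.

    import tensorflow as tf
    from tf2onnx import constants, tfonnx

    # Build a trivial TF graph; a real model with Conv/Pool ops is what the
    # channels-last rewriter actually acts on.
    with tf.Graph().as_default() as tf_graph:
        x = tf.compat.v1.placeholder(tf.float32, [1, 3, 224, 224], name="input")
        y = tf.identity(x, name="output")

    g = tfonnx.process_tf_graph(tf_graph,
                                input_names=["input:0"],
                                output_names=["output:0"],
                                target=[constants.TARGET_CHANNELS_LAST])
    model_proto = g.make_model("converted with the channels-last target")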

tf2onnx/utils.py

Lines changed: 28 additions & 0 deletions

@@ -184,6 +184,34 @@ def find_opset(opset):
     return opset


+def get_subgraphs_from_onnx(model_proto):
+    stack = [model_proto.graph]
+    while stack:
+        g = stack.pop()
+        yield g
+        for node in g.node:
+            for attr in node.attribute:
+                if hasattr(attr, "g"):
+                    stack.append(attr.g)
+                if hasattr(attr, "graphs"):
+                    stack.extend(attr.graphs)
+
+
+def initialize_name_counter(model_proto):
+    """Avoid name conflicts by initializing the counter used by make_name based on the provided model"""
+    suffix_regex = re.compile(r"__(\d+)(:\d+)?$")
+    def avoid_name(name):
+        global INTERNAL_NAME
+        suffix = suffix_regex.search(name)
+        if suffix:
+            INTERNAL_NAME = max(INTERNAL_NAME, int(suffix.group(1)) + 1)
+    for g in get_subgraphs_from_onnx(model_proto):
+        for n in g.node:
+            avoid_name(n.name)
+            for out in n.output:
+                avoid_name(out)
+
+
 def save_onnx_model(save_path_root, onnx_file_name, feed_dict, model_proto, include_test_data=False, as_text=False,
                     external_tensor_storage=None):
     """Save onnx model as file. Save a pbtxt file as well if as_text is True"""

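The suffix pattern matches names of the kind tf2onnx's make_name produces, which end in a double-underscore counter optionally followed by an output index such as `:0`. A small standalone demonstration of the counting logic (the sample names are invented):

    import re

    suffix_regex = re.compile(r"__(\d+)(:\d+)?$")

    internal_name = 1  # stand-in for the module-level counter the commit bumps
    for name in ["Add__12", "Conv__7:0", "plain_name", "model/block__3:1"]:
        m = suffix_regex.search(name)
        if m:
            internal_name = max(internal_name, int(m.group(1)) + 1)

    print(internal_name)  # 13 -- the next generated name cannot collide with Add__12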
tools/onnx-optimize.py

Lines changed: 14 additions & 4 deletions

@@ -16,7 +16,8 @@
 from onnx import helper

 from tf2onnx.graph import GraphUtil
-from tf2onnx import logging, optimizer
+from tf2onnx import logging, optimizer, constants
+from tf2onnx.late_rewriters.channel_order_rewriters import rewrite_channels_first, rewrite_channels_last


 logging.basicConfig(level=logging.INFO)
@@ -28,23 +29,32 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--input", required=True, help="onnx input model file")
     parser.add_argument("--output", help="output model file")
+    target_options = [constants.TARGET_CHANNELS_LAST, constants.TARGET_CHANNELS_FIRST]
+    parser.add_argument("--target", default=",".join(constants.DEFAULT_TARGET), choices=target_options,
+                        help="target platform")
     args = parser.parse_args()
+    args.target = args.target.split(",")
     return args


-def load_graph(fname):
+def load_graph(fname, target):
     model_proto = onnx.ModelProto()
     with open(fname, "rb") as f:
         data = f.read()
     model_proto.ParseFromString(data)
-    g = GraphUtil.create_graph_from_onnx_model(model_proto)
+    g = GraphUtil.create_graph_from_onnx_model(model_proto, target)
     return g, model_proto


 def main():
     args = get_args()

-    g, org_model_proto = load_graph(args.input)
+    g, org_model_proto = load_graph(args.input, args.target)
+
+    if g.is_target(constants.TARGET_CHANNELS_FIRST):
+        g.reset_nodes(rewrite_channels_first(g, g.get_nodes()))
+    if g.is_target(constants.TARGET_CHANNELS_LAST):
+        g.reset_nodes(rewrite_channels_last(g, g.get_nodes()))

     g = optimizer.optimize_graph(g)
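The same flow can also be driven from Python. A rough sketch under the assumption that Graph.make_model and optimizer.optimize_graph behave as elsewhere in tf2onnx (the file names are placeholders):

    import onnx
    from tf2onnx import constants, optimizer
    from tf2onnx.graph import GraphUtil
    from tf2onnx.late_rewriters.channel_order_rewriters import rewrite_channels_last

    model_proto = onnx.load("model.onnx")  # placeholder path
    target = [constants.TARGET_CHANNELS_LAST]

    g = GraphUtil.create_graph_from_onnx_model(model_proto, target)
    g.reset_nodes(rewrite_channels_last(g, g.get_nodes()))  # wrap channels-first ops
    g = optimizer.optimize_graph(g)                         # push the inserted Transposes outward
    onnx.save(g.make_model("optimized for channels last"), "model_nhwc.onnx")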