
Commit d46fb97

Added handlers for tflite
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 99eb959 commit d46fb97

File tree: 6 files changed, +541 −0 lines changed

tf2onnx/tflite_handlers/__init__.py
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
"""tf2onnx.tflite_handlers module"""

from . import (
    tfl_math,
    tfl_nn,
    tfl_controlflow,
    tfl_direct,
    tfl_tensor
)
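
The package __init__ above matters mainly for its side effect: importing the submodules runs their @tfl_op class decorators, which registers each handler with tf2onnx's dispatch machinery (the decorator is imported from tf2onnx.handler in the files below). A minimal sketch of that registration pattern, using illustrative names that are not part of this commit:

_TFL_HANDLERS = {}

def tfl_op(names, **mapping_kwargs):
    """Toy stand-in for tf2onnx.handler.tfl_op: map TFLite op names to a handler class."""
    names = [names] if isinstance(names, str) else names
    def _register(cls):
        for name in names:
            _TFL_HANDLERS[name] = (cls, mapping_kwargs)
        return cls
    return _register

# Importing a module that applies @tfl_op is then enough to populate the table.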
tf2onnx/tflite_handlers/tfl_controlflow.py
Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

"""
tfl_controlflow
"""

import copy

from tf2onnx.handler import tfl_op
from tf2onnx import utils
import numpy as np

from tf2onnx.tf_loader import find_function
from tf2onnx.onnx_opset.controlflow import parameter_binding, inline_subgraph
from onnx.onnx_pb import TensorProto

@tfl_op(["TFL_WHILE"])
class TflWhile:
    @classmethod
    def version_7(cls, ctx, node, **kwargs):
        tfl_while_inputs = node.input
        output_shapes = node.output_shapes
        output_dtypes = node.output_dtypes
        output_names = node.output

        cond_name = node.get_attr_str("cond_subgraph_index")
        cond_graph = find_function(cond_name)
        cond_graph.parent_graph = ctx

        body_name = node.get_attr_str("body_subgraph_index")
        body = find_function(body_name)
        body.parent_graph = ctx

        ctx.remove_node(node.name)

        cond_binding = parameter_binding(cond_graph, tfl_while_inputs)
        cond_outputs = inline_subgraph(ctx, cond_graph, cond_name, cond_binding)

        max_iterations = ctx.make_const(utils.make_name("max_iterations"), np.array(np.iinfo(np.int64).max))

        loop_node = ctx.make_node("Loop", [max_iterations.output[0], cond_outputs[0]] + tfl_while_inputs,
                                  output_count=len(output_shapes), name=node.name + "_loop",
                                  shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True)

        output_map = dict(zip(output_names, loop_node.output))

        # shift output consumers
        for k, v in output_map.items():
            ctx.replace_all_inputs(k, v)  # ops=ctx.get_nodes()

        body = wire_tfl_while_body(body, loop_node.inputs, output_shapes, output_dtypes, cond_graph)

        loop_node.set_body_graph_as_attr("body", body)

def wire_tfl_while_body(g, loop_node_inputs, output_shapes,
                        output_dtypes, cond_graph):
    """Wire subgraph graph into main."""

    g = copy.deepcopy(g)

    # onnx will pass in cond as argument
    iter_node = g.make_node("Placeholder", [], name=utils.make_name("iteration_num"),
                            output_count=1, dtypes=[TensorProto.INT64], shapes=[[]])
    cond_node = g.make_node("Placeholder", [], name=utils.make_name("cond"),
                            output_count=1, dtypes=[TensorProto.BOOL], shapes=[[]])
    cond_binding = parameter_binding(cond_graph, g.outputs)

    # in onnx the body inputs are: index, cond, [loop_vars]
    g.func_inputs = [iter_node.output[0], cond_node.output[0]] + g.func_inputs
    # tell graph lib to keep inputs in order
    g._order_sensitive_inputs = \
        [g.get_node_by_output(name) for name in g.func_inputs]  # pylint: disable=protected-access

    for p, c in zip(loop_node_inputs, g.func_inputs):
        shape = p.output_shapes[0]
        g.set_shape(c, shape)

    cond_outputs = inline_subgraph(g, cond_graph, "cond__", cond_binding)

    g.outputs = [cond_outputs[0]] + g.outputs
    return g

@tfl_op(["TFL_IF"], tf_op="If")
class TflIfOp:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        node.attr["then_branch"] = node.attr["then_subgraph_index"]
        del node.attr["then_subgraph_index"]
        node.attr["else_branch"] = node.attr["else_subgraph_index"]
        del node.attr["else_subgraph_index"]
tf2onnx/tflite_handlers/tfl_direct.py
Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

"""
tfl_direct
"""

from tf2onnx.handler import tfl_op

@tfl_op("TFL_ABS", tf_op="Abs")
@tfl_op("TFL_CEIL", tf_op="Ceil")
@tfl_op("TFL_COS", tf_op="Cos")
@tfl_op("TFL_ELU", tf_op="Elu")
@tfl_op("TFL_EQUAL", tf_op="Equal")
@tfl_op("TFL_EXP", tf_op="Exp")
@tfl_op("TFL_FLOOR", tf_op="Floor")
@tfl_op("TFL_FLOOR_DIV", tf_op="FloorDiv")
@tfl_op("TFL_FLOOR_MOD", tf_op="FloorMod")
@tfl_op("TFL_GREATER", tf_op="Greater")
@tfl_op("TFL_GREATER_EQUAL", tf_op="GreaterEqual")
@tfl_op("TFL_LESS", tf_op="Less")
@tfl_op("TFL_LESS_EQUAL", tf_op="LessEqual")
@tfl_op("TFL_LOG", tf_op="Log")
@tfl_op("TFL_LOG_SOFTMAX", tf_op="LogSoftmax")
@tfl_op("TFL_LOGICAL_AND", tf_op="LogicalAnd")
@tfl_op("TFL_LOGICAL_NOT", tf_op="LogicalNot")
@tfl_op("TFL_LOGICAL_OR", tf_op="LogicalOr")
@tfl_op("TFL_MATRIX_DIAG", tf_op="MatrixDiag")
@tfl_op("TFL_MATRIX_SET_DIAG", tf_op="MatrixSetDiag")
@tfl_op("TFL_MAXIMUM", tf_op="Maximum")
@tfl_op("TFL_MINIMUM", tf_op="Minimum")
@tfl_op("TFL_NEG", tf_op="Neg")
@tfl_op("TFL_NOT_EQUAL", tf_op="NotEqual")
@tfl_op("TFL_POW", tf_op="Pow")
@tfl_op("TFL_RANK", tf_op="Rank")
@tfl_op("TFL_RELU", tf_op="Relu")
@tfl_op("TFL_RELU6", tf_op="Relu6")
@tfl_op("TFL_ROUND", tf_op="Round")
@tfl_op("TFL_RSQRT", tf_op="Rsqrt")
@tfl_op("TFL_SELECT", tf_op="Select")
@tfl_op("TFL_SELECT_V2", tf_op="SelectV2")
@tfl_op("TFL_SIN", tf_op="Sin")
@tfl_op("TFL_SQRT", tf_op="Sqrt")
@tfl_op("TFL_SQUARE", tf_op="Square")
@tfl_op("TFL_SQUARED_DIFFERENCE", tf_op="SquaredDifference")
@tfl_op("TFL_TANH", tf_op="Tanh")
@tfl_op("TFL_WHERE", tf_op="Where")
@tfl_op("TFL_ZEROS_LIKE", tf_op="ZerosLike")
@tfl_op("TFL_FILL", tf_op="Fill")
@tfl_op("TFL_GATHER_ND", tf_op="GatherNd")
@tfl_op("TFL_PAD", tf_op="Pad")
@tfl_op("TFL_REVERSE_V2", tf_op="ReverseV2")
@tfl_op("TFL_SCATTER_ND", tf_op="ScatterNd")
@tfl_op("TFL_SEGMENT_SUM", tf_op="SegmentSum")
@tfl_op("TFL_SHAPE", tf_op="Shape")
@tfl_op("TFL_SLICE", tf_op="Slice")
@tfl_op("TFL_SQUEEZE", tf_op="Squeeze")
@tfl_op("TFL_TILE", tf_op="Tile")
@tfl_op("TFL_EXPAND_DIMS", tf_op="ExpandDims")
@tfl_op("TFL_TRANSPOSE", tf_op="Transpose")
@tfl_op("TFL_UNPACK", tf_op="Unpack")
@tfl_op("TFL_ADD_N", tf_op="AddN")
@tfl_op("TFL_ONE_HOT", tf_op="OneHot")
@tfl_op("TFL_DEPTH_TO_SPACE", tf_op="DepthToSpace")
@tfl_op("TFL_ARG_MIN", tf_op="ArgMin")
@tfl_op("TFL_ARG_MAX", tf_op="ArgMax")
@tfl_op("TFL_NON_MAX_SUPPRESSION_V5", tf_op="NonMaxSuppressionV5")
@tfl_op("TFL_RESIZE_NEAREST_NEIGHBOR", tf_op="ResizeNearestNeighbor")
@tfl_op("TFL_LEAKY_RELU", tf_op="LeakyRelu")
@tfl_op("TFL_STRIDED_SLICE", tf_op="StridedSlice")
@tfl_op("TFL_MEAN", tf_op="Mean")
@tfl_op("TFL_SUM", tf_op="Sum")
@tfl_op("TFL_MIRROR_PAD", tf_op="MirrorPad")
@tfl_op("TFL_RESIZE_BILINEAR", tf_op="ResizeBilinear")
@tfl_op("TFL_REVERSE_SEQUENCE", tf_op="ReverseSequence")
@tfl_op("TFL_SPARSE_TO_DENSE", tf_op="SparseToDense")
@tfl_op("TFL_CUMSUM", tf_op="Cumsum")
class TflDirectOp:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        pass
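
Each op listed above corresponds one-to-one to a TensorFlow op, so the shared to_tf handler is intentionally a no-op: only the node type is renamed via the tf_op argument, after which the regular TF-to-ONNX handlers take over. Adding another direct mapping would look roughly like this (the op names here are hypothetical, for illustration only):

@tfl_op("TFL_SOME_NEW_OP", tf_op="SomeNewOp")
class TflSomeNewOp:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        pass  # nothing to rewrite; the node type is simply changed to the TF op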
tf2onnx/tflite_handlers/tfl_math.py
Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

"""
tfl_math
"""

import logging
import numpy as np
from tf2onnx.handler import tfl_op
from tf2onnx import utils

logger = logging.getLogger(__name__)

def separate_fused_activation_function(ctx, node):
    activation_fn = node.attr['fused_activation_function'].s
    del node.attr['fused_activation_function']
    if activation_fn == b'RELU':
        ctx.insert_new_node_on_output("Relu", node.output[0])
    elif activation_fn == b'RELU6':
        new_node = ctx.insert_new_node_on_output("Relu6", node.output[0])
        new_node.skip_conversion = False
    elif activation_fn == b'TANH':
        ctx.insert_new_node_on_output("Tanh", node.output[0])
    else:
        # TODO: SIGN_BIT and RELU_N1_TO_1 not supported yet
        utils.make_sure(activation_fn == b'NONE', "Unsupported fused activation function %s on node %s",
                        activation_fn, node.name)

@tfl_op(["TFL_ADD"], tf_op="Add")
class TflAdd:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        separate_fused_activation_function(ctx, node)

@tfl_op(["TFL_SUB"], tf_op="Sub")
class TflSub:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        separate_fused_activation_function(ctx, node)

@tfl_op(["TFL_MUL"], tf_op="Mul")
class TflMul:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        separate_fused_activation_function(ctx, node)

@tfl_op(["TFL_DIV"], tf_op="Div")
class TflDiv:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        separate_fused_activation_function(ctx, node)

@tfl_op(["TFL_LOGISTIC"], tf_op="Sigmoid")
class TflLogistic:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        pass

@tfl_op(["TFL_REDUCE_MAX"], tf_op="Max")
@tfl_op(["TFL_REDUCE_ANY"], tf_op="Any")
@tfl_op(["TFL_REDUCE_PROD"], tf_op="Prod")
class TflReduceOp:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        pass

@tfl_op(["TFL_LOCAL_RESPONSE_NORMALIZATION"], tf_op="LRN")
class TFlLocalResponseNormalizationOp:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        node.attr["depth_radius"] = node.attr["radius"]
        del node.attr["radius"]

@tfl_op(["TFL_RANGE"], tf_op="Range")
class TflRangeOp:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        node.set_attr("Tidx", ctx.get_dtype(node.output[0]))

@tfl_op(["TFL_QUANTIZE"], onnx_op="QuantizeLinear")
class TflQuantizeOp:
    @classmethod
    def version_10(cls, ctx, node, **kwargs):
        scale = node.get_attr_value('scale')
        zero_point = node.get_attr_value('zero_point')
        axis = node.get_attr_value('quantized_dimension')
        np_q_type = utils.map_onnx_to_numpy_type(ctx.get_dtype(node.output[0]))
        if len(scale) > 1 or len(zero_point) > 1:
            node.set_attr("axis", axis)
        scale_node = ctx.make_const(utils.make_name("scale"), np.array(scale[0], dtype=np.float32))
        zero_point_node = ctx.make_const(utils.make_name("zero_point"), np.array(zero_point[0], dtype=np_q_type))
        ctx.replace_inputs(node, [node.input[0], scale_node.output[0], zero_point_node.output[0]])
        del node.attr["scale"]
        del node.attr["zero_point"]
        del node.attr["quantized_dimension"]

@tfl_op(["TFL_DEQUANTIZE"], onnx_op="DequantizeLinear")
class TflDequantizeOp:
    @classmethod
    def version_10(cls, ctx, node, **kwargs):
        scale = node.get_attr_value('scale')
        zero_point = node.get_attr_value('zero_point')
        axis = node.get_attr_value('quantized_dimension')
        np_q_type = utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[0]))
        if len(scale) > 1 or len(zero_point) > 1:
            utils.make_sure(ctx.opset >= 13, "Opset 13 is required for per-axis quantization")
            node.set_attr("axis", axis)
            scale_node = ctx.make_const(utils.make_name("scale"), np.array(scale, dtype=np.float32))
            zero_point_node = ctx.make_const(utils.make_name("zero_point"), np.array(zero_point, dtype=np_q_type))
        else:
            scale_node = ctx.make_const(utils.make_name("scale"), np.array(scale[0], dtype=np.float32))
            zero_point_node = ctx.make_const(utils.make_name("zero_point"), np.array(zero_point[0], dtype=np_q_type))
        ctx.replace_inputs(node, [node.input[0], scale_node.output[0], zero_point_node.output[0]])
        del node.attr["scale"]
        del node.attr["zero_point"]
        del node.attr["quantized_dimension"]

def dynamic_quantize_inputs(ctx, node):
    if ctx.opset < 11:
        logger.warning("Opset 11 is required for asymmetric_quantize_inputs of node %s", node.name)
        return
    for i in range(len(node.input)):
        # Don't quantize inputs that are already quantized
        if node.inputs[i].type in ["DequantizeLinear", "TFL_DEQUANTIZE"]:
            continue
        dyn_quant = ctx.make_node("DynamicQuantizeLinear", [node.input[i]], output_count=3, op_name_scope=node.name)
        dyn_quant.skip_conversion = True
        dequant = ctx.make_node("DequantizeLinear", dyn_quant.output, op_name_scope=node.name)
        dequant.skip_conversion = True
        ctx.replace_input(node, node.input[i], dequant.output[0], input_index=i)

@tfl_op(["TFL_FULLY_CONNECTED"])
class TflFullyConnectedOp:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        separate_fused_activation_function(ctx, node)
        utils.make_sure(node.attr['weights_format'].s == b'DEFAULT',
                        "Only default weights format supported for fully connected op")
        utils.make_sure(node.attr['keep_num_dims'].i == 0,
                        "Only keep_num_dims=False supported for fully connected op")
        if node.attr['asymmetric_quantize_inputs'].i == 1:
            dynamic_quantize_inputs(ctx, node)

        transpose_node = ctx.insert_new_node_on_input(node, "Transpose", node.input[1],
                                                      name=None, input_index=1, perm=[1, 0])
        transpose_node.skip_conversion = True
        node.set_attr("transpose_a", 0)
        node.set_attr("transpose_b", 0)
        node.type = "MatMul"

        if len(node.input) == 3:
            # FIXME: Add a test for this
            bias_inp = node.input[2]
            ctx.replace_inputs(node, node.input[:2])
            add_node = ctx.insert_new_node_on_output("Add", node.output[0], inputs=[node.output[0], bias_inp])
            add_node.skip_conversion = True

        del node.attr["weights_format"]
        del node.attr["keep_num_dims"]
        del node.attr["asymmetric_quantize_inputs"]

@tfl_op(["TFL_SOFTMAX"], tf_op="Softmax")
class TFlSoftmaxOp:
    @classmethod
    def to_tf(cls, ctx, node, **kwargs):
        beta = node.get_attr_value("beta")
        beta_node = ctx.make_const(utils.make_name("beta"), np.array(beta, dtype=np.float32))
        mul_node = ctx.insert_new_node_on_output("Mul", node.output[0], name=utils.make_name(node.name))
        ctx.replace_inputs(mul_node, [node.output[0], beta_node.output[0]])
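
The TFL_QUANTIZE and TFL_DEQUANTIZE handlers above move the TFLite scale and zero_point attributes into constant inputs, which is the form ONNX QuantizeLinear and DequantizeLinear expect. As a reminder of what the resulting ONNX nodes compute, here is a per-tensor NumPy reference (a sketch only; uint8 output is assumed for illustration):

import numpy as np

def quantize_linear(x, scale, zero_point):
    # ONNX QuantizeLinear: divide by scale, round half to even, add zero_point, saturate.
    q = np.rint(x / scale) + zero_point
    return np.clip(q, 0, 255).astype(np.uint8)

def dequantize_linear(q, scale, zero_point):
    # ONNX DequantizeLinear: subtract zero_point and rescale.
    return (q.astype(np.float32) - zero_point) * scale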
