5
5
6
6
# pyre-unsafe
7
7
8
- from typing import List
8
+ from typing import Any , List
9
9
10
10
import executorch .backends .arm .tosa_quant_utils as tqutils
11
11
import executorch .backends .arm .tosa_utils as tutils
12
12
13
- import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
14
13
from executorch .backends .arm .operators .node_visitor import (
15
14
NodeVisitor ,
16
15
register_node_visitor ,
@@ -34,10 +33,13 @@ def __init__(self, *args):
34
33
def define_node (
35
34
self ,
36
35
node : Node ,
37
- tosa_graph : ts . TosaSerializer ,
36
+ tosa_graph : Any ,
38
37
inputs : List [TosaArg ],
39
38
output : TosaArg ,
40
39
) -> None :
40
+
41
+ import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
42
+
41
43
# Specification (0.80) states that input and output types
42
44
# should all be the same
43
45
if inputs [0 ].dtype != inputs [1 ].dtype or inputs [0 ].dtype != output .dtype :
@@ -58,7 +60,7 @@ def define_node(
58
60
if len (inputs [0 ].shape ) > len (inputs [1 ].shape )
59
61
else inputs [1 ].dim_order
60
62
)
61
-
63
+ scale_back = 1.0
62
64
if inputs [0 ].dtype == ts .DType .INT8 :
63
65
rescaled_inputs , scale_back = tqutils .insert_rescale_ops_to_int32 (
64
66
tosa_graph , inputs , node
@@ -90,7 +92,9 @@ def define_node(
90
92
if output .dtype == ts .DType .INT8 :
91
93
# Scale output back to 8 bit
92
94
# pyre-ignore
93
- tqutils .insert_rescale_op_to_int8 (tosa_graph , add_output , scale_back , node ) # type: ignore[possibly-undefined]
95
+ tqutils .insert_rescale_op_to_int8 (
96
+ tosa_graph , add_output , scale_back , node
97
+ ) # type: ignore[possibly-undefined]
94
98
95
99
96
100
@register_node_visitor
@@ -107,10 +111,13 @@ def __init__(self, *args):
107
111
def define_node (
108
112
self ,
109
113
node : Node ,
110
- tosa_graph : ts . TosaSerializer ,
114
+ tosa_graph : Any ,
111
115
inputs : List [TosaArg ],
112
116
output : TosaArg ,
113
117
) -> None :
118
+
119
+ import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
120
+
114
121
# Specification (0.80) states that input and output types
115
122
# should all be the same
116
123
if inputs [0 ].dtype != inputs [1 ].dtype or inputs [0 ].dtype != output .dtype :
@@ -130,7 +137,7 @@ def define_node(
130
137
f"Expected IO data type to be FP32, got { inputs [0 ].dtype } "
131
138
)
132
139
133
- input1 , input2 = tutils . reshape_for_broadcast ( tosa_graph , inputs )
140
+ input1 , input2 = inputs
134
141
135
142
# MI lowering
136
143
tosa_graph .addOperator (
@@ -139,3 +146,122 @@ def define_node(
139
146
[output .name ],
140
147
None ,
141
148
)
149
+
150
+
151
+ @register_node_visitor
152
+ class AddVisitor_INT (NodeVisitor ):
153
+ target = "aten.add.Tensor"
154
+
155
+ tosa_specs = [
156
+ TosaSpecification .create_from_string ("TOSA-1.0+INT" ),
157
+ ]
158
+
159
+ def __init__ (self , * args ):
160
+ super ().__init__ (* args )
161
+
162
+ def define_node (
163
+ self ,
164
+ node : Node ,
165
+ tosa_graph : Any ,
166
+ inputs : List [TosaArg ],
167
+ output : TosaArg ,
168
+ ) -> None :
169
+
170
+ import serializer .tosa_serializer as ts # type: ignore
171
+
172
+ # Specification (1.0) states that input and output types
173
+ # should all be the same
174
+ if inputs [0 ].dtype != inputs [1 ].dtype or inputs [0 ].dtype != output .dtype :
175
+ raise TypeError (
176
+ f"All IO needs to have the same data type, got input 1: "
177
+ f"{ inputs [0 ].dtype } , input 2: { inputs [1 ].dtype } and output: "
178
+ f"{ output .dtype } "
179
+ )
180
+ # Handle int8 (quantized) and int32
181
+ supported_dtypes = [ts .DType .INT8 , ts .DType .INT32 ]
182
+ if inputs [0 ].dtype not in supported_dtypes :
183
+ raise TypeError (
184
+ f'IO data type needs to be { supported_dtypes } , got "{ inputs [0 ].dtype } "'
185
+ )
186
+ scale_back = 1.0
187
+ if inputs [0 ].dtype == ts .DType .INT8 :
188
+ rescaled_inputs , scale_back = tqutils .insert_rescale_ops_to_int32 (
189
+ tosa_graph , inputs , node , self .tosa_specs
190
+ )
191
+ else :
192
+ # input[0].dtype == ts.DType.INT32
193
+ # Non quantized input, natively support by TOSA.ADD
194
+ rescaled_inputs = inputs
195
+
196
+ if output .dtype == ts .DType .INT8 :
197
+ broadcasted_shape = tutils .tosa_shape (output .shape , output .dim_order )
198
+ add_output = tosa_graph .addIntermediate (broadcasted_shape , ts .DType .INT32 )
199
+ else :
200
+ # output.dtype == ts.DType.INT32
201
+ add_output = output
202
+
203
+ input1 , input2 = rescaled_inputs
204
+
205
+ # Do the INT32 Add
206
+ tosa_graph .addOperator (
207
+ ts .TosaOp .Op ().ADD ,
208
+ [input1 .name , input2 .name ],
209
+ [add_output .name ],
210
+ None ,
211
+ )
212
+
213
+ if output .dtype == ts .DType .INT8 :
214
+ # Scale output back to 8 bit
215
+ # pyre-ignore
216
+ tqutils .insert_rescale_op_to_int8 (
217
+ tosa_graph , add_output , scale_back , node , self .tosa_specs
218
+ ) # type: ignore[possibly-undefined]
219
+
220
+
221
+ @register_node_visitor
222
+ class AddVisitor_FP (AddVisitor_INT ):
223
+ # inheriting 'target' from INT class
224
+
225
+ tosa_specs = [TosaSpecification .create_from_string ("TOSA-1.0+FP" )]
226
+
227
+ def __init__ (self , * args ):
228
+ super ().__init__ (* args )
229
+
230
+ def define_node (
231
+ self ,
232
+ node : Node ,
233
+ tosa_graph : Any ,
234
+ inputs : List [TosaArg ],
235
+ output : TosaArg ,
236
+ ) -> None :
237
+
238
+ import serializer .tosa_serializer as ts # type: ignore
239
+
240
+ # Specification (1.0) states that input and output types
241
+ # should all be the same
242
+ if inputs [0 ].dtype != inputs [1 ].dtype or inputs [0 ].dtype != output .dtype :
243
+ raise TypeError (
244
+ f"All IO needs to have the same data type, got input 1: "
245
+ f"{ inputs [0 ].dtype } , input 2: { inputs [1 ].dtype } and output: "
246
+ f"{ output .dtype } "
247
+ )
248
+
249
+ if inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]:
250
+ # Call the inherited define_node for handling integers
251
+ super ().define_node (node , tosa_graph , inputs , output )
252
+ else :
253
+ # FP32 Add lowering
254
+ if inputs [0 ].dtype != ts .DType .FP32 :
255
+ raise TypeError (
256
+ f"Expected IO data type to be FP32, got { inputs [0 ].dtype } "
257
+ )
258
+
259
+ input1 , input2 = inputs
260
+
261
+ # FP lowering
262
+ tosa_graph .addOperator (
263
+ ts .TosaOp .Op ().ADD ,
264
+ [input1 .name , input2 .name ],
265
+ [output .name ],
266
+ None ,
267
+ )
0 commit comments