7
7
8
8
# pyre-unsafe
9
9
from executorch .backends .arm ._passes import (
10
+ AddBiasPass ,
10
11
AnnotateChannelsLastDimOrder ,
11
12
AnnotateDecomposedMatmulPass ,
12
13
BroadcastArgsPass ,
14
+ CastBoolToInt8Pass ,
13
15
CastInt64BuffersToInt32Pass ,
14
16
CastToInt32Pass ,
15
17
ComputeConstantOpsAOT ,
23
25
ConvertSplitToSlicePass ,
24
26
ConvertSqueezesToViewPass ,
25
27
ConvertToClampPass ,
28
+ DecomposeAvgPool2d ,
26
29
DecomposeCosineSimilarityPass ,
27
30
DecomposeDivPass ,
28
31
DecomposeEmbeddingPass ,
29
32
DecomposeGeluPass ,
33
+ DecomposeGroupedConv ,
30
34
DecomposeGroupNormPass ,
31
35
DecomposeLayerNormPass ,
32
36
DecomposeLeakyReLUPass ,
35
39
DecomposeMaxPool2DPass ,
36
40
DecomposeMeanDimPass ,
37
41
DecomposeNotEqualPass ,
42
+ DecomposeRoundPass ,
38
43
DecomposeSelectPass ,
39
44
DecomposeSiluPass ,
40
45
DecomposeSoftmaxPass ,
63
68
UnsqueezeBeforeRepeatPass ,
64
69
UnsqueezeScalarPlaceholdersPass ,
65
70
)
66
-
67
71
from executorch .backends .arm .tosa_specification import (
68
72
TosaLoweringContext ,
69
73
TosaSpecification ,
@@ -105,6 +109,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
105
109
if self .tosa_spec .is_U55_subset :
106
110
self .add_pass (CastToInt32Pass ())
107
111
112
+ self .add_pass (CastBoolToInt8Pass ())
108
113
self .add_pass (ReplaceScalarWithTensorArgPassTOSABI ())
109
114
self .add_pass (AnnotateDecomposedMatmulPass ())
110
115
self .add_pass (QuantizeOperatorArguments ())
@@ -115,8 +120,10 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
115
120
if self .tosa_spec .is_U55_subset :
116
121
self .add_pass (BroadcastArgsPass ())
117
122
self .add_pass (DecomposeLinearPass ())
123
+ self .add_pass (DecomposeAvgPool2d ())
118
124
self .add_pass (ComputeConstantOpsAOT (exported_program ))
119
125
126
+ self .add_pass (DecomposeGroupedConv ())
120
127
self .add_pass (RemoveClonePass ())
121
128
self .add_pass (SizeAdjustConv2DPass ())
122
129
self .add_pass (ConvertExpandCopyToRepeatPass ())
@@ -130,6 +137,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
130
137
131
138
self .add_pass (FuseViewCopyTransform ())
132
139
self .add_pass (FuseConstantArgsPass (exported_program ))
140
+ self .add_pass (AddBiasPass (exported_program ))
133
141
134
142
self .add_pass (InsertTableOpsPass (exported_program ))
135
143
self .add_pass (FuseEqualPlaceholdersPass (exported_program ))
@@ -139,8 +147,10 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
139
147
return self ._transform (exported_program .graph_module )
140
148
141
149
def _tosa_080_MI_pipeline (self , exported_program : ExportedProgram ) -> GraphModule :
150
+ self .add_pass (DecomposeRoundPass ())
142
151
self .add_pass (DecomposeSqrtPass ())
143
152
self .add_pass (ConvertIntPowToMuls ())
153
+ self .add_pass (CastBoolToInt8Pass ())
144
154
self .add_pass (ReplaceScalarWithTensorArgPassTOSAMI ())
145
155
self .add_pass (DecomposeEmbeddingPass ())
146
156
self .add_pass (FuseQuantizedActivationPass ())
@@ -172,8 +182,10 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
172
182
self .add_pass (RetraceFoldedDtypesPass ())
173
183
self .add_pass (UnsqueezeScalarPlaceholdersPass (exported_program ))
174
184
self .add_pass (MatchArgRanksPass (exported_program ))
185
+ self .add_pass (DecomposeAvgPool2d ())
175
186
self .add_pass (ComputeConstantOpsAOT (exported_program ))
176
187
188
+ self .add_pass (DecomposeGroupedConv ())
177
189
self .add_pass (RemoveClonePass ())
178
190
self .add_pass (SizeAdjustConv2DPass ())
179
191
self .add_pass (ConvertExpandCopyToRepeatPass ())
@@ -187,6 +199,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
187
199
188
200
self .add_pass (FuseViewCopyTransform ())
189
201
self .add_pass (FuseConstantArgsPass (exported_program ))
202
+ self .add_pass (AddBiasPass (exported_program ))
190
203
self .add_pass (InsertTableOpsPass (exported_program ))
191
204
self .add_pass (FuseEqualPlaceholdersPass (exported_program ))
192
205
self .add_pass (AnnotateChannelsLastDimOrder ())
@@ -219,6 +232,8 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
219
232
self .add_pass (InsertCastForOpsWithInt64InputPass ())
220
233
self .add_pass (DecomposeEmbeddingPass ())
221
234
self .add_pass (DecomposeScaledDotProductAttention ())
235
+ self .add_pass (DecomposeRoundPass ())
236
+ self .add_pass (CastBoolToInt8Pass ())
222
237
self .add_pass (ReplaceScalarWithTensorArgPassTOSABI ())
223
238
self .add_pass (ScalarsToAttributePass ())
224
239
self .add_pass (DecomposeGroupNormPass ())
@@ -232,6 +247,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
232
247
self .add_pass (DecomposeLinearVectorNormPass ())
233
248
self .add_pass (DecomposeSqrtPass ())
234
249
self .add_pass (DecomposeSiluPass ())
250
+ self .add_pass (DecomposeAvgPool2d ())
235
251
236
252
if self .tosa_spec .is_U55_subset :
237
253
# Numerically stable softmax uses amax which is not supported on Ethos-U55
0 commit comments