21
21
)
22
22
from executorch .exir .dialects ._ops import ops as exir_ops
23
23
from executorch .exir .graph_module import get_control_flow_submodules
24
+ from torch ._export .utils import is_buffer , is_lifted_tensor_constant , is_param
24
25
from torch .export .exported_program import ExportedProgram
25
26
from torch .fx .passes .operator_support import any_chain , OperatorSupportBase
26
27
27
28
29
+ def is_param_node (exp_prog : ExportedProgram , node : torch .fx .Node ) -> bool :
30
+ return (
31
+ is_param (exp_prog , node )
32
+ or is_buffer (exp_prog , node )
33
+ or is_lifted_tensor_constant (exp_prog , node )
34
+ )
35
+
36
+
37
+ def get_total_num_ops_in_ep (edge_programs , supported_ops ):
38
+ total_number_of_ops = 0
39
+ for edge_program in edge_programs .values ():
40
+ for partitioned_program in edge_program :
41
+ for node in partitioned_program .graph .nodes :
42
+ if node .op == "call_function" :
43
+ if node .target in supported_ops :
44
+ total_number_of_ops += 1
45
+ return total_number_of_ops
46
+
47
+
28
48
def _preprocess_multimethod (
29
49
edge_programs : Dict [str , List [ExportedProgram ]],
30
50
compile_specs : Dict [str , List [List [CompileSpec ]]],
@@ -37,13 +57,7 @@ def _preprocess_multimethod(
37
57
in testing for a partitioner which tags different partitions for different backends
38
58
to be lowered to
39
59
"""
40
- total_number_of_ops = 0
41
- for edge_program in edge_programs .values ():
42
- for partitioned_program in edge_program :
43
- for node in partitioned_program .graph .nodes :
44
- if node .op == "call_function" :
45
- if node .target in supported_ops :
46
- total_number_of_ops += 1
60
+ total_number_of_ops = get_total_num_ops_in_ep (edge_programs , supported_ops )
47
61
all_processed_results = {key : [] for key in edge_programs .keys ()}
48
62
49
63
for method_name , partitioned_programs in edge_programs .items ():
@@ -67,6 +81,8 @@ def _preprocess_multimethod(
67
81
raise RuntimeError (
68
82
f"{ node .op } { node .target .__name__ } is not supported in backend { backend_name } "
69
83
)
84
+ if is_param_node (partitioned_program , node ):
85
+ processed_bytes += f"CONST{ node .name } :"
70
86
71
87
processed_bytes += "#"
72
88
for cs in compile_spec_for_partition :
@@ -171,14 +187,30 @@ def preprocess_multimethod(
171
187
172
188
173
189
class AddSinOperatorSupport (OperatorSupportBase ):
190
+ def __init__ (self , original_program ):
191
+ self .original_program = original_program
192
+ super ().__init__ ()
193
+
174
194
def is_node_supported (self , submodules , node : torch .fx .Node ) -> bool :
175
- return node . op == "call_function" and node . target in [
195
+ supported_targets = [
176
196
exir_ops .edge .aten .add .Tensor ,
177
197
exir_ops .edge .aten .sin .default ,
178
198
]
199
+ if node .op == "call_function" and node .target in supported_targets :
200
+ return True
201
+
202
+ if node .op == "placeholder" and is_param_node (self .original_program , node ):
203
+ for user in node .users .keys ():
204
+ if user .target in supported_targets :
205
+ return True
206
+ return False
179
207
180
208
181
209
class SubCosOperatorSupport (OperatorSupportBase ):
210
+ def __init__ (self , original_program ):
211
+ self .original_program = original_program
212
+ super ().__init__ ()
213
+
182
214
def is_node_supported (self , submodules , node : torch .fx .Node ) -> bool :
183
215
return node .op == "call_function" and node .target in [
184
216
exir_ops .edge .aten .sub .Tensor ,
@@ -199,11 +231,8 @@ class BackendWithPreprocessAllPartitioner(Partitioner):
199
231
"""
200
232
201
233
def __init__ (self ) -> None :
202
- self .add_sin_support = any_chain (AddSinOperatorSupport ())
203
- self .add_sin_backend_id = FirstBackendWithPreprocessAll .__name__
204
-
205
- self .sub_cos_support = any_chain (SubCosOperatorSupport ())
206
234
self .sub_cos_backend_id = SecondBackendWithPreprocessAll .__name__
235
+ self .add_sin_backend_id = FirstBackendWithPreprocessAll .__name__
207
236
208
237
def _partition_graph_module (
209
238
self ,
@@ -260,6 +289,8 @@ def _partition_graph_module(
260
289
return partition_tags , start_idx_for_submodules
261
290
262
291
def partition (self , exported_program : ExportedProgram ) -> PartitionResult :
292
+ self .add_sin_support = any_chain (AddSinOperatorSupport (exported_program ))
293
+ self .sub_cos_support = any_chain (SubCosOperatorSupport (exported_program ))
263
294
partition_tags , _ = self ._partition_graph_module (exported_program .graph_module )
264
295
return PartitionResult (
265
296
tagged_exported_program = exported_program , partition_tags = partition_tags
0 commit comments