@@ -5,7 +5,8 @@
 
 # pyre-unsafe
 
-from typing import Callable, Dict
+from itertools import chain
+from typing import Callable, cast, Dict, Iterator, Set
 
 import torch
 from executorch.backends.arm._passes.arm_pass_utils import create_node
@@ -17,7 +18,7 @@
 
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.fx import GraphModule
-
+from torch.fx.node import Node
 from torch.library import impl, Library
 
 lib = Library("tosa", "DEF")
@@ -32,15 +33,13 @@ def _table_impl(*args, **kwargs): # pyre-ignore
     return args[0].to(dtype=torch.int32)
 
 
-class InsertTableOpsPass(ExportPass):
+class TableOps:
     """
-    For ops in self.table_ops they need to be serialized as a TOSA TABLE. This pass replaces these
-    edge ops with a tosa._table(input: Tensor, target_str: str) where target_str == str(node.target).
-    When lowering the _table node target_str will be used to find the corresponding torch operator
-    which will be used to produce the table values in operators/op_table.py.
+    Helper class for finding the corresponding table operator for a given Node.
     """
 
-    table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = {
+    # Targets that follow a straightforward one-to-one mapping to their table op
+    unary_table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = {
         exir_ops.edge.aten.ceil.default: torch.ceil,
         exir_ops.edge.aten.exp.default: torch.exp,
         exir_ops.edge.aten.floor.default: torch.floor,
@@ -53,9 +52,52 @@ class InsertTableOpsPass(ExportPass):
         exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
     }
 
+    # Targets that must be treated explicitly
+    special_table_ops: Set[EdgeOpOverload] = {
+        exir_ops.edge.aten.pow.Tensor_Scalar,
+    }
+
+    def __init__(self, exported_program: ExportedProgram):
+        self.exported_program = exported_program
+
+    def __contains__(self, node: Node) -> bool:
+        return (
+            node.target in self.unary_table_ops or node.target in self.special_table_ops
+        )
+
+    def __getitem__(self, node: Node):
+        target = cast(EdgeOpOverload, node.target)
+        if target in self.unary_table_ops:
+            return self.unary_table_ops[target]
+        elif target in self.special_table_ops:
+            match target:
+                case exir_ops.edge.aten.pow.Tensor_Scalar:
+                    # Exponent is a constant. Embed it into a lambda.
+                    exp = cast(int, node.args[1])
+                    return lambda x: torch.pow(x, exp).flatten()
+                case _:
+                    # Any op in self.special_table_ops must be handled by a case above.
+                    raise AssertionError("Unhandled table operation")
+        else:
+            raise KeyError(f"Table op for {target} does not exist")
+
+    @staticmethod
+    def included_ops() -> Iterator[EdgeOpOverload]:
+        return chain(TableOps.unary_table_ops, TableOps.special_table_ops)
+
+
+class InsertTableOpsPass(ExportPass):
+    """
+    Ops in self.table_ops need to be serialized as TOSA TABLEs. This pass replaces these
+    edge ops with a tosa._table(input: Tensor, target_str: str) where target_str == str(node.target).
+    When lowering the _table node, target_str is used to find the corresponding torch operator,
+    which is then used to produce the table values in operators/op_table.py.
+    """
+
     def __init__(self, exported_program: ExportedProgram) -> None:
         super().__init__()
         self.exported_program = exported_program
+        self.table_ops = TableOps(exported_program)
 
     def register_buffer(self, buffer_name: str, buffer: torch.Tensor) -> None:
         """
@@ -166,7 +208,7 @@ def generate_table_values(
     def call(self, graph_module: GraphModule) -> PassResult:
         modified = False
         for node in graph_module.graph.nodes:
-            if node.op != "call_function" or node.target not in self.table_ops:
+            if node.op != "call_function" or node not in self.table_ops:
                 continue
             input_qparams = node.meta["input_qparams"]
             output_qparams = node.meta["output_qparams"]
@@ -186,7 +228,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
 
             # Generate table buffer and how much to lshift the table output.
             buffer, lshift = self.generate_table_values(
-                torch_op=self.table_ops[node.target],
+                torch_op=self.table_ops[node],
                 in_quantargs=input_qparams[0],
                 out_quantargs=output_qparams[0],
             )
@@ -207,7 +249,9 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 output_node = rescale_node
 
             node.replace_all_uses_with(output_node)
+
             graph_module.graph.erase_node(node)
+
             output_node.meta["input_qparams"] = input_qparams
             output_node.meta["output_qparams"] = output_qparams
             modified = True
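As a note for reviewers, the standalone sketch below illustrates the lookup pattern that the new `TableOps` class implements: a dict for targets that map one-to-one to a torch callable, plus special-cased targets whose constant argument is baked into a lambda. The string targets and the `FakeNode` stand-in are illustrative assumptions only; the real class keys on `EdgeOpOverload` targets of `torch.fx` nodes and is not reproduced here.

```python
# Simplified sketch of the TableOps lookup pattern (illustrative only; not executorch API).
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Tuple

import torch


@dataclass
class FakeNode:
    """Stand-in for torch.fx.Node: just a target and its args."""

    target: str
    args: Tuple[Any, ...] = field(default_factory=tuple)


# One-to-one "unary" targets: the table op is simply the torch callable.
UNARY_TABLE_OPS: Dict[str, Callable[[torch.Tensor], torch.Tensor]] = {
    "aten.ceil.default": torch.ceil,
    "aten.exp.default": torch.exp,
    "aten.floor.default": torch.floor,
}

# Targets that need explicit handling because they carry an extra constant argument.
SPECIAL_TABLE_OPS = {"aten.pow.Tensor_Scalar"}


def table_op_for(node: FakeNode) -> Callable[[torch.Tensor], torch.Tensor]:
    if node.target in UNARY_TABLE_OPS:
        return UNARY_TABLE_OPS[node.target]
    if node.target in SPECIAL_TABLE_OPS:
        # The exponent is a constant taken from the node's args, so it can be
        # embedded into a unary lambda that a lookup table can represent.
        exp = node.args[1]
        return lambda x: torch.pow(x, exp)
    raise KeyError(f"Table op for {node.target} does not exist")


if __name__ == "__main__":
    x = torch.linspace(-2.0, 2.0, steps=5)
    exp_node = FakeNode("aten.exp.default")
    pow_node = FakeNode("aten.pow.Tensor_Scalar", args=(None, 2))
    print(table_op_for(exp_node)(x))  # element-wise exp
    print(table_op_for(pow_node)(x))  # element-wise x**2
```

The `__contains__`/`__getitem__` pair keeps the lookup symmetric with the old dict, which is why `call()` only changes from `node.target not in self.table_ops` to `node not in self.table_ops` and from `self.table_ops[node.target]` to `self.table_ops[node]`.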