1
1
import logging
2
+ import operator
2
3
from typing import Callable , List , Optional , Set , Tuple
3
4
4
5
import torch
@@ -33,8 +34,7 @@ def __repr__(self):
33
34
34
35
35
36
class ComplexOpDetector :
36
- def __init__ (self , logger ):
37
- self .logger = logger
37
+ def __init__ (self ):
38
38
pass
39
39
40
40
def is_complex_dtype (self , node : Node ) -> bool :
@@ -45,15 +45,13 @@ def is_complex_dtype(self, node: Node) -> bool:
45
45
if hasattr (val , "dtype" ):
46
46
dtype = val .dtype
47
47
48
- self . logger .debug (f"dtype of node: { dtype } " )
48
+ logger .debug (f"dtype of node: { dtype } " )
49
49
return dtype in {torch .complex64 , torch .complex128 }
50
50
51
51
def node_include_in_subgraph (self , node : Node ) -> bool :
52
52
# Include only call_function ops on complex tensors
53
- self .logger .debug (f"node.op: { node .op } , node name: { node .name } " )
54
- self .logger .debug (f"is_complex_dtype: { self .is_complex_dtype (node )} " )
55
53
if node .op == "call_function" and self .is_complex_dtype (node ):
56
- self . logger .debug (
54
+ logger .debug (
57
55
f"node.op is added to subgraph: { node .op } , node name: { node .name } is complex"
58
56
)
59
57
return node .op == "call_function" and self .is_complex_dtype (node )
@@ -67,7 +65,7 @@ def subgraph_from_anchor(self, anchor_node: Node) -> ComplexSubGraphInfo:
67
65
if n in subgraph_nodes :
68
66
continue
69
67
subgraph_nodes .add (n )
70
- self . logger .debug (f"node { n .name } is added to subgraph" )
68
+ logger .debug (f"node { n .name } is added to subgraph" )
71
69
for inp in n .all_input_nodes :
72
70
if self .node_include_in_subgraph (inp ):
73
71
print ("node inp is added to stack:" , inp .name )
@@ -85,13 +83,12 @@ def find_complex_op_subgraphs(
85
83
complex_op_subgraphs : List [ComplexSubGraphInfo ] = []
86
84
for node in gm .graph .nodes :
87
85
if node .target == anchor_target :
88
- self .logger .debug (f"node.target { node .target } node.name: { node .name } " )
89
86
new_sub = self .subgraph_from_anchor (node )
90
87
# if any intersecting nodes between seen and sub.subgraph_nodes they should be merged
91
88
merged = False
92
89
for existing_sub in complex_op_subgraphs :
93
90
if set (existing_sub .subgraph_nodes ) & set (new_sub .subgraph_nodes ):
94
- self . logger .debug (f"merging subgraphs { existing_sub } { new_sub } " )
91
+ logger .debug (f"merging subgraphs { existing_sub } { new_sub } " )
95
92
# merge the two subgraphs
96
93
existing_sub .subgraph_nodes = list (
97
94
set (existing_sub .subgraph_nodes )
@@ -113,7 +110,7 @@ def find_complex_op_subgraphs(
113
110
def complex_graph_detection (
114
111
gm : GraphModule , settings : CompilationSettings
115
112
) -> List [ComplexSubGraphInfo ]:
116
- complex_op_detector = ComplexOpDetector (logger )
113
+ complex_op_detector = ComplexOpDetector ()
117
114
complex_subgraphs = complex_op_detector .find_complex_op_subgraphs (
118
115
gm , anchor_target = torch .ops .aten .view_as_real .default
119
116
)
@@ -174,17 +171,24 @@ def replace_input_node(self, input_node):
174
171
175
172
elif input_node .op == "get_attr" :
176
173
new_attr_name = input_node .target + "_reshaped"
177
- original_tensor = self .get_attr_tensor (input_node .target )
178
- stacked_tensor = torch .stack (
179
- [original_tensor .real , original_tensor .imag ], dim = - 1
180
- )
181
- self .gm .register_buffer (new_attr_name , stacked_tensor )
174
+ from torch ._subclasses .fake_tensor import unset_fake_temporarily
175
+
176
+ with unset_fake_temporarily ():
177
+ original_tensor = self .get_attr_tensor (input_node .target )
178
+ stacked_tensor = torch .stack (
179
+ [original_tensor .real , original_tensor .imag ], dim = - 1
180
+ )
181
+ self .gm .register_buffer (new_attr_name , stacked_tensor )
182
182
with self .gm .graph .inserting_after (input_node ):
183
183
new_node = self .gm .graph .get_attr (new_attr_name )
184
184
185
185
else :
186
- logger .debug (f"Unsupported node type: { input_node .op } " )
187
- logger .debug ("This node type does not need to replaced" )
186
+ logger .debug (
187
+ f"Unsupported node type in replacement of input node: { input_node .op } "
188
+ )
189
+ logger .debug (
190
+ "This complex subgraph inputnode type does not need to replaced"
191
+ )
188
192
189
193
input_node .replace_all_uses_with (new_node )
190
194
self .gm .graph .erase_node (input_node )
@@ -211,6 +215,8 @@ def rewrite_subgraph_nodes(self, subgraphs):
211
215
212
216
def match_complex_mul (
213
217
match : torch .fx .subgraph_rewriter .Match ,
218
+ original_graph ,
219
+ pattern_graph ,
214
220
) -> bool :
215
221
for original_node in match .nodes_map .values ():
216
222
if original_node .name == node .name :
@@ -230,10 +236,9 @@ def match_complex_mul(
230
236
self .gm .graph .erase_node (node )
231
237
else :
232
238
logger .debug (f"Unsupported node target: { node .target } " )
233
- logger .debug (f"This node type does not need to replaced" )
234
- if modified :
235
- self .gm .graph .lint ()
236
- self .gm .recompile ()
239
+ logger .debug (
240
+ "This complex subgraphnode type does not need to replaced"
241
+ )
237
242
238
243
if modified :
239
244
self .gm .graph .lint ()
@@ -256,16 +261,28 @@ def complex_mul_replacement() -> Tuple[
256
261
257
262
# Original pattern: torch.mul for complex tensors
258
263
def original_mul (x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
259
- return torch .mul (x , y )
264
+ return torch .ops . aten . mul . Tensor (x , y )
260
265
261
266
# Replacement function: manual complex multiplication on real/imag stacked tensors
262
267
def replacement (x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
263
- x_real , x_imag = x [..., 0 ], x [..., 1 ]
264
- y_real , y_imag = y [..., 0 ], y [..., 1 ]
265
-
266
- real = x_real * y_real - x_imag * y_imag
267
- imag = x_real * y_imag + x_imag * y_real
268
-
269
- return torch .stack ((real , imag ), dim = - 1 )
268
+ x_real = torch .ops .aten .select .int (x , - 1 , 0 )
269
+ x_imag = torch .ops .aten .select .int (x , - 1 , 1 ) # x is reshape tensor
270
+ y_real , y_imag = y [..., 0 ], y [..., 1 ] # y is frozen param
271
+
272
+ real_part1 = torch .ops .aten .mul .Tensor (x_real , y_real )
273
+ real_part2 = torch .ops .aten .mul .Tensor (x_imag , y_imag )
274
+ real = torch .ops .aten .sub .Tensor (real_part1 , real_part2 )
275
+
276
+ imag_part1 = torch .ops .aten .mul .Tensor (x_real , y_imag )
277
+ imag_part2 = torch .ops .aten .mul .Tensor (x_imag , y_real )
278
+ imag = torch .ops .aten .add .Tensor (imag_part1 , imag_part2 )
279
+
280
+ return torch .ops .aten .cat .default (
281
+ [
282
+ torch .ops .aten .unsqueeze .default (real , - 1 ),
283
+ torch .ops .aten .unsqueeze .default (imag , - 1 ),
284
+ ],
285
+ dim = - 1 ,
286
+ )
270
287
271
288
return (original_mul , replacement )
0 commit comments