1
1
import logging
2
+ import operator
2
3
from typing import Callable , List , Optional , Set , Tuple
3
4
4
5
import torch
@@ -33,8 +34,7 @@ def __repr__(self):
33
34
34
35
35
36
class ComplexOpDetector :
36
- def __init__ (self , logger ):
37
- self .logger = logger
37
+ def __init__ (self ):
38
38
pass
39
39
40
40
def is_complex_dtype (self , node : Node ) -> bool :
@@ -45,15 +45,13 @@ def is_complex_dtype(self, node: Node) -> bool:
45
45
if hasattr (val , "dtype" ):
46
46
dtype = val .dtype
47
47
48
- self . logger .debug (f"dtype of node: { dtype } " )
48
+ logger .debug (f"dtype of node: { dtype } " )
49
49
return dtype in {torch .complex64 , torch .complex128 }
50
50
51
51
def node_include_in_subgraph (self , node : Node ) -> bool :
52
52
# Include only call_function ops on complex tensors
53
- self .logger .debug (f"node.op: { node .op } , node name: { node .name } " )
54
- self .logger .debug (f"is_complex_dtype: { self .is_complex_dtype (node )} " )
55
53
if node .op == "call_function" and self .is_complex_dtype (node ):
56
- self . logger .debug (
54
+ logger .debug (
57
55
f"node.op is added to subgraph: { node .op } , node name: { node .name } is complex"
58
56
)
59
57
return node .op == "call_function" and self .is_complex_dtype (node )
@@ -67,13 +65,11 @@ def subgraph_from_anchor(self, anchor_node: Node) -> ComplexSubGraphInfo:
67
65
if n in subgraph_nodes :
68
66
continue
69
67
subgraph_nodes .add (n )
70
- self . logger .debug (f"node { n .name } is added to subgraph" )
68
+ logger .debug (f"node { n .name } is added to subgraph" )
71
69
for inp in n .all_input_nodes :
72
70
if self .node_include_in_subgraph (inp ):
73
- print ("node inp is added to stack:" , inp .name )
74
71
stack .append (inp )
75
72
else :
76
- print ("node inp is not added to stack BUT INP:" , inp .name )
77
73
input_nodes .add (inp )
78
74
return ComplexSubGraphInfo (
79
75
[anchor_node ], list (subgraph_nodes ), list (input_nodes )
@@ -85,13 +81,12 @@ def find_complex_op_subgraphs(
85
81
complex_op_subgraphs : List [ComplexSubGraphInfo ] = []
86
82
for node in gm .graph .nodes :
87
83
if node .target == anchor_target :
88
- self .logger .debug (f"node.target { node .target } node.name: { node .name } " )
89
84
new_sub = self .subgraph_from_anchor (node )
90
85
# if any intersecting nodes between seen and sub.subgraph_nodes they should be merged
91
86
merged = False
92
87
for existing_sub in complex_op_subgraphs :
93
88
if set (existing_sub .subgraph_nodes ) & set (new_sub .subgraph_nodes ):
94
- self . logger .debug (f"merging subgraphs { existing_sub } { new_sub } " )
89
+ logger .debug (f"merging subgraphs { existing_sub } { new_sub } " )
95
90
# merge the two subgraphs
96
91
existing_sub .subgraph_nodes = list (
97
92
set (existing_sub .subgraph_nodes )
@@ -113,7 +108,7 @@ def find_complex_op_subgraphs(
113
108
def complex_graph_detection (
114
109
gm : GraphModule , settings : CompilationSettings
115
110
) -> List [ComplexSubGraphInfo ]:
116
- complex_op_detector = ComplexOpDetector (logger )
111
+ complex_op_detector = ComplexOpDetector ()
117
112
complex_subgraphs = complex_op_detector .find_complex_op_subgraphs (
118
113
gm , anchor_target = torch .ops .aten .view_as_real .default
119
114
)
@@ -174,17 +169,24 @@ def replace_input_node(self, input_node):
174
169
175
170
elif input_node .op == "get_attr" :
176
171
new_attr_name = input_node .target + "_reshaped"
177
- original_tensor = self .get_attr_tensor (input_node .target )
178
- stacked_tensor = torch .stack (
179
- [original_tensor .real , original_tensor .imag ], dim = - 1
180
- )
181
- self .gm .register_buffer (new_attr_name , stacked_tensor )
172
+ from torch ._subclasses .fake_tensor import unset_fake_temporarily
173
+
174
+ with unset_fake_temporarily ():
175
+ original_tensor = self .get_attr_tensor (input_node .target )
176
+ stacked_tensor = torch .stack (
177
+ [original_tensor .real , original_tensor .imag ], dim = - 1
178
+ )
179
+ self .gm .register_buffer (new_attr_name , stacked_tensor )
182
180
with self .gm .graph .inserting_after (input_node ):
183
181
new_node = self .gm .graph .get_attr (new_attr_name )
184
182
185
183
else :
186
- logger .debug (f"Unsupported node type: { input_node .op } " )
187
- logger .debug ("This node type does not need to replaced" )
184
+ logger .debug (
185
+ f"Unsupported node type in replacement of input node: { input_node .op } "
186
+ )
187
+ logger .debug (
188
+ "This complex subgraph inputnode type does not need to replaced"
189
+ )
188
190
189
191
input_node .replace_all_uses_with (new_node )
190
192
self .gm .graph .erase_node (input_node )
@@ -211,6 +213,8 @@ def rewrite_subgraph_nodes(self, subgraphs):
211
213
212
214
def match_complex_mul (
213
215
match : torch .fx .subgraph_rewriter .Match ,
216
+ original_graph ,
217
+ pattern_graph ,
214
218
) -> bool :
215
219
for original_node in match .nodes_map .values ():
216
220
if original_node .name == node .name :
@@ -230,10 +234,9 @@ def match_complex_mul(
230
234
self .gm .graph .erase_node (node )
231
235
else :
232
236
logger .debug (f"Unsupported node target: { node .target } " )
233
- logger .debug (f"This node type does not need to replaced" )
234
- if modified :
235
- self .gm .graph .lint ()
236
- self .gm .recompile ()
237
+ logger .debug (
238
+ "This complex subgraphnode type does not need to replaced"
239
+ )
237
240
238
241
if modified :
239
242
self .gm .graph .lint ()
@@ -256,16 +259,28 @@ def complex_mul_replacement() -> Tuple[
256
259
257
260
# Original pattern: torch.mul for complex tensors
258
261
def original_mul (x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
259
- return torch .mul (x , y )
262
+ return torch .ops . aten . mul . Tensor (x , y )
260
263
261
264
# Replacement function: manual complex multiplication on real/imag stacked tensors
262
265
def replacement (x : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
263
- x_real , x_imag = x [..., 0 ], x [..., 1 ]
264
- y_real , y_imag = y [..., 0 ], y [..., 1 ]
265
-
266
- real = x_real * y_real - x_imag * y_imag
267
- imag = x_real * y_imag + x_imag * y_real
268
-
269
- return torch .stack ((real , imag ), dim = - 1 )
266
+ x_real = torch .ops .aten .select .int (x , - 1 , 0 )
267
+ x_imag = torch .ops .aten .select .int (x , - 1 , 1 ) # x is reshape tensor
268
+ y_real , y_imag = y [..., 0 ], y [..., 1 ] # y is frozen param
269
+
270
+ real_part1 = torch .ops .aten .mul .Tensor (x_real , y_real )
271
+ real_part2 = torch .ops .aten .mul .Tensor (x_imag , y_imag )
272
+ real = torch .ops .aten .sub .Tensor (real_part1 , real_part2 )
273
+
274
+ imag_part1 = torch .ops .aten .mul .Tensor (x_real , y_imag )
275
+ imag_part2 = torch .ops .aten .mul .Tensor (x_imag , y_real )
276
+ imag = torch .ops .aten .add .Tensor (imag_part1 , imag_part2 )
277
+
278
+ return torch .ops .aten .cat .default (
279
+ [
280
+ torch .ops .aten .unsqueeze .default (real , - 1 ),
281
+ torch .ops .aten .unsqueeze .default (imag , - 1 ),
282
+ ],
283
+ dim = - 1 ,
284
+ )
270
285
271
286
return (original_mul , replacement )
0 commit comments