Commit a380ed5

fix failures, changing getitem evaluator and adding lowering pass test case
1 parent f0c23c9 commit a380ed5

5 files changed: +152 -49 lines changed


examples/distributed_inference/rotary_embedding.py

Lines changed: 6 additions & 10 deletions
@@ -88,25 +88,21 @@ def __init__(self, dim: int, seq_len: int):
         self.wo = nn.Linear(dim, dim)
         self.seq_len = seq_len
         self.n_parallel = 1
-        self.register_buffer(
-            "freqs_cis",
-            self._precompute_freqs_cis(),
-            persistent=True,
+        theta = 10000.0
+        self.freqs_cis = precompute_freqs_cis(
+            self.dim, self.seq_len, theta, self.n_parallel
         )
+        self.register_buffer("freqs_cis", self.freqs_cis, persistent=True)
         self.init_weights()

     def init_weights(self):
         with torch.device(self.freqs_cis.device):
-            self.freqs_cis = self._precompute_freqs_cis()
-
-    def _precompute_freqs_cis(self) -> torch.Tensor:
-        theta = 10000.0
-        return precompute_freqs_cis(self.dim, self.seq_len, theta, self.n_parallel)
+            self.freqs_cis = self.freqs_cis

     def forward(self, x):
         q = self.wq(x)
         k = self.wk(x)
         # calculate rotary embedding
-        freqs_cis = self._precompute_freqs_cis().to(q.device)
+        freqs_cis = self.freqs_cis.to(q.device)
         q, k = rotary_embedding(q, k, self.dim, freqs_cis=freqs_cis)
         return self.wo(q)
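
Note: the precompute_freqs_cis helper used above is assumed to follow the usual Llama-style RoPE table construction, returning a complex-valued (seq_len x dim/2) tensor that is now registered once as a buffer instead of being recomputed in forward. The sketch below is only an approximation of such a helper; the n_parallel slicing shown is an assumption, not the exact implementation imported by the example.

# Hedged sketch of a Llama-style precompute_freqs_cis; the real helper used by
# this example may differ, in particular in how n_parallel is handled.
import torch

def precompute_freqs_cis_sketch(
    dim: int, seq_len: int, theta: float = 10000.0, n_parallel: int = 1
) -> torch.Tensor:
    local_dim = dim // n_parallel  # assumed: head dim split across TP ranks
    freqs = 1.0 / (theta ** (torch.arange(0, local_dim, 2).float() / local_dim))
    t = torch.arange(seq_len)
    angles = torch.outer(t, freqs)  # (seq_len, local_dim // 2) rotation angles
    # polar(1, angle) = cos(angle) + i*sin(angle), i.e. the complex freqs_cis table
    return torch.polar(torch.ones_like(angles), angles)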

examples/distributed_inference/tensor_parallel_rotary_embedding.py

Lines changed: 0 additions & 6 deletions
@@ -34,13 +34,7 @@

 logger.info("Torch-tensorrt compilation for rotary embedding")

-# Compile the model
-# for single GPU let us first try without this optiob
-
 model = torch.compile(model, backend="torch_tensorrt", options={"debug": True})
-# model = torch_tensorrt.compile(model, target_ir="torch_compile", options={
-#     "debug": True,
-# })

 for i in range(15):
     # seeding with dp_rank to ensure identical inputs for TP groups

py/torch_tensorrt/dynamo/conversion/ops_evaluators.py

Lines changed: 8 additions & 2 deletions
@@ -23,7 +23,10 @@ def getitem_validator(getitem_node: Node, settings: CompilationSettings = None)
     from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS

     # Getitem nodes can only be converted if their parent node also can
-    return getitem_node.args[0] in DYNAMO_CONVERTERS
+    return (
+        getitem_node.args[0] in DYNAMO_CONVERTERS
+        or getitem_node.args[0].op == "get_attr"
+    )


 # TODO: Subsequent evaluators should be registered here with their own validators
@@ -43,7 +46,10 @@ def generic_evaluator(
     _LOGGER.debug(
         f"Evaluating {ConverterRegistry.qualified_name_or_str(target)} on object with name: {name}"
     )
-    return target(*args)
+    from torch._subclasses.fake_tensor import unset_fake_temporarily
+
+    with unset_fake_temporarily():
+        return target(*args)


 def rand_validator(rand_node: Node, settings: CompilationSettings = None) -> bool:
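
These two changes work together: getitem_validator now also accepts getitem nodes whose parent is a get_attr node (for example a frozen buffer produced by the complex-graph rewrite below), and generic_evaluator briefly disables fake-tensor mode so the evaluated target returns concrete values rather than FakeTensors. A rough standalone illustration of the latter, using only the public FakeTensorMode / unset_fake_temporarily context managers (not code from this commit):

# Illustration only: under FakeTensorMode, ops produce FakeTensors that carry shape
# and dtype but no real data; unset_fake_temporarily restores real execution inside
# the block, which is what lets the evaluator compute an actual value.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily

with FakeTensorMode():
    fake = torch.ops.aten.arange.default(4)      # FakeTensor: metadata only
    with unset_fake_temporarily():
        real = torch.ops.aten.arange.default(4)  # real tensor([0, 1, 2, 3])

print(type(fake).__name__)  # FakeTensor
print(real)                 # tensor([0, 1, 2, 3])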

py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py

Lines changed: 46 additions & 31 deletions
@@ -1,4 +1,5 @@
 import logging
+import operator
 from typing import Callable, List, Optional, Set, Tuple

 import torch
@@ -33,8 +34,7 @@ def __repr__(self):


 class ComplexOpDetector:
-    def __init__(self, logger):
-        self.logger = logger
+    def __init__(self):
         pass

     def is_complex_dtype(self, node: Node) -> bool:
@@ -45,15 +45,13 @@ def is_complex_dtype(self, node: Node) -> bool:
         if hasattr(val, "dtype"):
             dtype = val.dtype

-        self.logger.debug(f"dtype of node: {dtype}")
+        logger.debug(f"dtype of node: {dtype}")
         return dtype in {torch.complex64, torch.complex128}

     def node_include_in_subgraph(self, node: Node) -> bool:
         # Include only call_function ops on complex tensors
-        self.logger.debug(f"node.op: {node.op}, node name: {node.name}")
-        self.logger.debug(f"is_complex_dtype: {self.is_complex_dtype(node)}")
         if node.op == "call_function" and self.is_complex_dtype(node):
-            self.logger.debug(
+            logger.debug(
                 f"node.op is added to subgraph: {node.op}, node name: {node.name} is complex"
             )
         return node.op == "call_function" and self.is_complex_dtype(node)
@@ -67,13 +65,11 @@ def subgraph_from_anchor(self, anchor_node: Node) -> ComplexSubGraphInfo:
             if n in subgraph_nodes:
                 continue
             subgraph_nodes.add(n)
-            self.logger.debug(f"node {n.name} is added to subgraph")
+            logger.debug(f"node {n.name} is added to subgraph")
             for inp in n.all_input_nodes:
                 if self.node_include_in_subgraph(inp):
-                    print("node inp is added to stack:", inp.name)
                     stack.append(inp)
                 else:
-                    print("node inp is not added to stack BUT INP:", inp.name)
                     input_nodes.add(inp)
         return ComplexSubGraphInfo(
             [anchor_node], list(subgraph_nodes), list(input_nodes)
@@ -85,13 +81,12 @@ def find_complex_op_subgraphs(
         complex_op_subgraphs: List[ComplexSubGraphInfo] = []
         for node in gm.graph.nodes:
             if node.target == anchor_target:
-                self.logger.debug(f"node.target {node.target} node.name: {node.name}")
                 new_sub = self.subgraph_from_anchor(node)
                 # if any intersecting nodes between seen and sub.subgraph_nodes they should be merged
                 merged = False
                 for existing_sub in complex_op_subgraphs:
                     if set(existing_sub.subgraph_nodes) & set(new_sub.subgraph_nodes):
-                        self.logger.debug(f"merging subgraphs {existing_sub} {new_sub}")
+                        logger.debug(f"merging subgraphs {existing_sub} {new_sub}")
                         # merge the two subgraphs
                         existing_sub.subgraph_nodes = list(
                             set(existing_sub.subgraph_nodes)
@@ -113,7 +108,7 @@ def find_complex_op_subgraphs(
 def complex_graph_detection(
     gm: GraphModule, settings: CompilationSettings
 ) -> List[ComplexSubGraphInfo]:
-    complex_op_detector = ComplexOpDetector(logger)
+    complex_op_detector = ComplexOpDetector()
     complex_subgraphs = complex_op_detector.find_complex_op_subgraphs(
         gm, anchor_target=torch.ops.aten.view_as_real.default
     )
@@ -174,17 +169,24 @@ def replace_input_node(self, input_node):

         elif input_node.op == "get_attr":
             new_attr_name = input_node.target + "_reshaped"
-            original_tensor = self.get_attr_tensor(input_node.target)
-            stacked_tensor = torch.stack(
-                [original_tensor.real, original_tensor.imag], dim=-1
-            )
-            self.gm.register_buffer(new_attr_name, stacked_tensor)
+            from torch._subclasses.fake_tensor import unset_fake_temporarily
+
+            with unset_fake_temporarily():
+                original_tensor = self.get_attr_tensor(input_node.target)
+                stacked_tensor = torch.stack(
+                    [original_tensor.real, original_tensor.imag], dim=-1
+                )
+                self.gm.register_buffer(new_attr_name, stacked_tensor)
             with self.gm.graph.inserting_after(input_node):
                 new_node = self.gm.graph.get_attr(new_attr_name)

         else:
-            logger.debug(f"Unsupported node type: {input_node.op}")
-            logger.debug("This node type does not need to replaced")
+            logger.debug(
+                f"Unsupported node type in replacement of input node: {input_node.op}"
+            )
+            logger.debug(
+                "This complex subgraph inputnode type does not need to replaced"
+            )

         input_node.replace_all_uses_with(new_node)
         self.gm.graph.erase_node(input_node)
@@ -211,6 +213,8 @@ def rewrite_subgraph_nodes(self, subgraphs):

                     def match_complex_mul(
                         match: torch.fx.subgraph_rewriter.Match,
+                        original_graph,
+                        pattern_graph,
                     ) -> bool:
                         for original_node in match.nodes_map.values():
                             if original_node.name == node.name:
@@ -230,10 +234,9 @@ def match_complex_mul(
                     self.gm.graph.erase_node(node)
                 else:
                     logger.debug(f"Unsupported node target: {node.target}")
-                    logger.debug(f"This node type does not need to replaced")
-        if modified:
-            self.gm.graph.lint()
-            self.gm.recompile()
+                    logger.debug(
+                        "This complex subgraphnode type does not need to replaced"
+                    )

         if modified:
             self.gm.graph.lint()
@@ -256,16 +259,28 @@ def complex_mul_replacement() -> Tuple[

     # Original pattern: torch.mul for complex tensors
     def original_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        return torch.mul(x, y)
+        return torch.ops.aten.mul.Tensor(x, y)

     # Replacement function: manual complex multiplication on real/imag stacked tensors
     def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        x_real, x_imag = x[..., 0], x[..., 1]
-        y_real, y_imag = y[..., 0], y[..., 1]
-
-        real = x_real * y_real - x_imag * y_imag
-        imag = x_real * y_imag + x_imag * y_real
-
-        return torch.stack((real, imag), dim=-1)
+        x_real = torch.ops.aten.select.int(x, -1, 0)
+        x_imag = torch.ops.aten.select.int(x, -1, 1)  # x is reshape tensor
+        y_real, y_imag = y[..., 0], y[..., 1]  # y is frozen param
+
+        real_part1 = torch.ops.aten.mul.Tensor(x_real, y_real)
+        real_part2 = torch.ops.aten.mul.Tensor(x_imag, y_imag)
+        real = torch.ops.aten.sub.Tensor(real_part1, real_part2)
+
+        imag_part1 = torch.ops.aten.mul.Tensor(x_real, y_imag)
+        imag_part2 = torch.ops.aten.mul.Tensor(x_imag, y_real)
+        imag = torch.ops.aten.add.Tensor(imag_part1, imag_part2)
+
+        return torch.ops.aten.cat.default(
+            [
+                torch.ops.aten.unsqueeze.default(real, -1),
+                torch.ops.aten.unsqueeze.default(imag, -1),
+            ],
+            dim=-1,
+        )

     return (original_mul, replacement)
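
The rewritten replacement pattern relies on the identity (a + bi)(c + di) = (ac - bd) + (ad + bc)i, applied to tensors whose trailing dimension stacks [real, imag]. A small self-contained check of that equivalence (illustration only, independent of the lowering pass):

# Illustration only: verify the real/imag-stacked multiplication used by the
# replacement pattern agrees with complex multiplication via torch.mul.
import torch

x = torch.randn(2, 3, dtype=torch.complex64)
y = torch.randn(2, 3, dtype=torch.complex64)

# Stack real and imaginary parts along a trailing dimension, as the pass does.
xs = torch.stack((x.real, x.imag), dim=-1)
ys = torch.stack((y.real, y.imag), dim=-1)

x_real, x_imag = xs[..., 0], xs[..., 1]
y_real, y_imag = ys[..., 0], ys[..., 1]
real = x_real * y_real - x_imag * y_imag
imag = x_real * y_imag + x_imag * y_real
stacked = torch.stack((real, imag), dim=-1)

expected = torch.view_as_real(x * y)
print(torch.allclose(stacked, expected, atol=1e-6))  # True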

tests/py/dynamo/lowering/test_aten_lowering_passes.py

Lines changed: 92 additions & 0 deletions
@@ -237,5 +237,97 @@ def forward(self, input, mat1, mat2):
         torch._dynamo.reset()


+class TestComplexSubgraph(TestCase):
+    def test_complex_subgraph(self):
+        def rotary_embedding(x, dim, freqs_cis=None):
+            x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+            x_out_flatten = torch.view_as_real(x_ * freqs_cis).flatten(3)
+            return x_out_flatten.type_as(x_)
+
+        def _freqs_ex_tensor():
+            real = torch.tensor([[[[1.0000]], [[2.0000]]]], device="cuda")
+            imag = torch.tensor([[[[0.0000]], [[3.0000]]]], device="cuda")
+
+            z = torch.complex(real, imag)
+            return z
+
+        class RotaryAttention(torch.nn.Module):
+            def __init__(self, dim, seq_len):
+                super().__init__()
+                self.dim = dim
+                self.wq = torch.nn.Linear(dim, dim)
+                self.seq_len = seq_len
+                self._freqs_ex_tensor = _freqs_ex_tensor()
+
+                self.register_buffer(
+                    "freqs_ex_tensor",
+                    self._freqs_ex_tensor,
+                    persistent=True,
+                )
+
+            def forward(self, x):
+                q = self.wq(x)
+                freqs_cis = self._freqs_ex_tensor.to(q.device)
+                q_out = rotary_embedding(q, self.dim, freqs_cis)
+                return q_out
+
+        BATCH = 1
+        SEQ_LEN = 2
+        HEADS = 1
+        DIM = 2
+
+        inputs = [torch.randn(BATCH, SEQ_LEN, HEADS, DIM).cuda()]
+
+        fx_graph = torch.fx.symbolic_trace(RotaryAttention(DIM, SEQ_LEN))
+        expected_ops = {torch.ops.aten.mul.Tensor}
+        unexpected_ops = {
+            torch.ops.aten.view_as_complex.default,
+            torch.ops.aten.view_as_real.default,
+        }
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs)[0].detach().cpu()
+        torch_model_results = fx_graph(*inputs)[0].detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg=f"ComplexSubgraph TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+
 if __name__ == "__main__":
     run_tests()
