Commit 195b1c4

using nccl ops from TRT-LLM namespace
1 parent 1bb044f commit 195b1c4

File tree

3 files changed: +126 -1 lines changed

examples/distributed_inference/tensor_parallel_simple_example.py

Lines changed: 56 additions & 1 deletion
@@ -5,18 +5,73 @@
 import torch
 import torch.nn as nn
 import torch_tensorrt
+import torch.distributed as dist
 from torch.distributed._tensor import Shard
 from torch.distributed._tensor.device_mesh import init_device_mesh
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     RowwiseParallel,
     parallelize_module,
 )
-
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+    dynamo_tensorrt_converter,
+)
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch.fx.node import Target, Argument
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
+from torch_tensorrt.dynamo.types import TRTTensor
+import numpy as np
+from torch_tensorrt.fx.converters.converter_utils import (
+    set_layer_name,
+)
+import tensorrt as trt
+import tensorrt_llm
+import ctypes
+import logging
 """
 This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
 """
 
+# Load the TRT-LLM plugin library so its NCCL-backed collective plugins
+# register themselves with TensorRT.
+plugin_lib_path = "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
+try:
+    ctypes.CDLL(plugin_lib_path)
+    print("plugin loaded successfully")
+except OSError as e:
+    print(f"unsuccessful load : {e}")
+logger = trt.Logger(trt.Logger.VERBOSE)
+trt.init_libnvinfer_plugins(None, "")
+# Iterate over all registered plugin creators to confirm the TRT-LLM
+# plugins are visible in the registry.
+plugin_registry = trt.get_plugin_registry()
+for plugin_creator in plugin_registry.plugin_creator_list:
+    print(f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}")
+
+
+@dynamo_tensorrt_converter(torch.ops._c10d_functional.all_gather_into_tensor.default)
+def insert_gather_op(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    plug_inputs = [args[0]]
+    # Fetch the AllGather plugin creator from the tensorrt_llm namespace.
+    allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator(
+        "AllGather", "1", "tensorrt_llm"
+    )
+    assert allgather_plg_creator is not None
+    # The plugin is configured with the participating ranks and the dtype.
+    world_size = dist.get_world_size()
+    group = list(range(world_size))
+    group = trt.PluginField("group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32)
+    p_dtype = trt.float16
+    pf_type = trt.PluginField(
+        "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
+    )
+    pfc = trt.PluginFieldCollection([group, pf_type])
+    allgather = allgather_plg_creator.create_plugin("allgather", pfc)
+    layer = ctx.net.add_plugin_v2(plug_inputs, allgather)
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
 
 class ToyModel(nn.Module):
     """MLP based model"""
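For orientation, a minimal sketch of how the rest of this example typically drives the converter registered above, following the upstream tensor_parallel_example it copies from; the two-GPU mesh, the in_proj/out_proj module names, the input sizes, and the compile options are assumptions for illustration, not part of this diff:

import torch

# Assumed 2-GPU run, launched via `torchrun --nproc_per_node=2 ...`;
# init_device_mesh also initializes the process group that
# dist.get_world_size() in the converter relies on.
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(2,))
rank = device_mesh.get_rank()

# Shard the first linear column-wise and the second row-wise; the chosen
# sharding plan determines which functional collectives
# (all_gather_into_tensor, all_reduce, ...) appear in the traced graph.
model = parallelize_module(
    ToyModel().to(f"cuda:{rank}"),
    device_mesh,
    {"in_proj": ColwiseParallel(), "out_proj": RowwiseParallel()},
)

# Compile through the Torch-TensorRT dynamo backend so the converter
# registered above handles any all_gather_into_tensor node.
compiled = torch.compile(model, backend="torch_tensorrt", dynamic=False)
out = compiled(torch.randn(8, 10, device=f"cuda:{rank}"))
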
py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 from .accumulate_fp32_matmul import accumulate_fp32_matmul
 from .constant_folding import constant_fold
 from .fuse_prims_broadcast import fuse_prims_broadcast
+from .fuse_distributed_ops import fuse_distributed_ops
 from .lower_linear import lower_linear
 from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
 from .pass_manager import DynamoPassManager
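
The new pass still has to appear in the default pass list that this module assembles. A hypothetical sketch of that registration, assuming the list is built with DynamoPassManager.build_from_passlist as the imports above suggest; the exact membership and ordering are not shown in this diff:

# Illustrative only: just the fuse_distributed_ops entry corresponds to
# this commit; the surrounding passes stand in for the existing list.
ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
    [
        constant_fold,
        fuse_prims_broadcast,
        fuse_distributed_ops,
        lower_linear,
        lower_scaled_dot_product_attention,
    ]
)
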
py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+import logging
+from typing import Sequence
+
+import torch
+
+# dead-code elimination, linting, and recompilation for graph, in-place
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def custom_fused_all_gather_op(args0, args1, args2):
+    return torch.ops._c10d_functional.wait_tensor.default(
+        torch.ops._c10d_functional.all_gather_into_tensor.default(args0, args1, args2)
+    )
+
+
+def custom_fused_reduce_scatter_op(args0, args1, args2, args3):
+    return torch.ops._c10d_functional.wait_tensor.default(
+        torch.ops._c10d_functional.reduce_scatter_tensor.default(
+            args0, args1, args2, args3
+        )
+    )
+
+
+def fuse_distributed_ops(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    modified_graph = False
+    for node in gm.graph.nodes:
+        # Match a collective op whose single user is the matching wait_tensor.
+        if (
+            node.target
+            in (
+                torch.ops._c10d_functional.all_gather_into_tensor.default,
+                torch.ops._c10d_functional.reduce_scatter_tensor.default,
+            )
+            and len(node.users) == 1
+            and list(node.users)[0].target
+            == torch.ops._c10d_functional.wait_tensor.default
+        ):
+            wait_tensor_node = list(node.users)[0]
+            fused_op = None
+            if node.target == torch.ops._c10d_functional.all_gather_into_tensor.default:
+                fused_op = custom_fused_all_gather_op
+                fused_op_args = (node.args[0], node.args[1], node.args[2])
+            else:
+                fused_op = custom_fused_reduce_scatter_op
+                fused_op_args = (node.args[0], node.args[1], node.args[2], node.args[3])
+            with gm.graph.inserting_after(wait_tensor_node):
+                fused_node = gm.graph.create_node(
+                    op="call_function",
+                    target=fused_op,  # the fused wrapper defined above
+                    args=fused_op_args,
+                )
+
+            wait_tensor_node.replace_all_uses_with(fused_node)
+            fused_node.meta.update(node.meta)
+            modified_graph = True
+            gm.graph.erase_node(wait_tensor_node)
+            gm.graph.erase_node(node)
+
+    # If the graph was modified, clean it up.
+    if modified_graph:
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(
+            f"Graph after fusing wait_tensor and distributed op tensor:\n{gm.graph}"
+        )
+
+    return gm
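
To see the pass in isolation, one can trace the unfused pattern with torch.fx and run it through fuse_distributed_ops. A minimal sketch: the group size 2 and group name "0" are placeholder values, and the import path assumes the new file is named fuse_distributed_ops.py, as the registry import above implies:

import torch
import torch.fx

# Assumed module path for the new pass.
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
    fuse_distributed_ops,
)


def collective(x: torch.Tensor) -> torch.Tensor:
    # Unfused pattern: a functional all_gather followed by its wait_tensor.
    gathered = torch.ops._c10d_functional.all_gather_into_tensor.default(x, 2, "0")
    return torch.ops._c10d_functional.wait_tensor.default(gathered)


gm = torch.fx.symbolic_trace(collective)
print(gm.graph)  # two nodes: all_gather_into_tensor -> wait_tensor

gm = fuse_distributed_ops(gm)
print(gm.graph)  # one call_function node targeting custom_fused_all_gather_op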
