import ctypes
import logging
import os
import site
import sys
import time
from enum import IntEnum, IntFlag, auto
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import tensorrt as trt
import tensorrt_llm
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_tensorrt
from torch.distributed._tensor import Shard
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)
from torch.fx import GraphModule, Node
from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    dynamo_tensorrt_converter,
)
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
    custom_fused_all_gather_op,
    custom_fused_reduce_scatter_op,
)
from torch_tensorrt.dynamo.types import TRTTensor
from torch_tensorrt.fx.converters.converter_utils import set_layer_name


# This is required for env initialization since we use mpirun
def initialize(rank=0, world_size=1, port=29500):
    local_rank = int(
        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
    )
    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))

    # Set up the environment variables to run with mpirun
    os.environ["RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(port)

    # Necessary to assign a device to each rank.
    torch.cuda.set_device(local_rank)

    # We use the nccl backend
    dist.init_process_group("nccl")

    # Set a manual seed for reproducibility
    torch.manual_seed(1111)

    return local_rank, world_size


initialize()
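# NOTE: initialize() above relies on the OMPI_COMM_WORLD_* variables that mpirun sets, so this
# script is meant to be launched with mpirun, e.g. (assuming 2 GPUs on one node and that this
# file is saved as tensor_parallel_simple_example.py):
#   mpirun -n 2 python tensor_parallel_simple_example.py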
# Create a device mesh based on the given world_size.
_world_size = int(os.environ["WORLD_SIZE"])

device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
_rank = device_mesh.get_rank()
device_id = _rank % torch.cuda.device_count()  # Ensure each rank gets a unique device
torch.cuda.set_device(device_id)


logger = logging.getLogger()
logger.setLevel(logging.INFO)
fh = logging.FileHandler(f"./tensor_parallel_simple_example_{_rank}.log", mode="w")
fh.setLevel(logging.INFO)
logger.addHandler(fh)


# TensorRT-LLM ships the NCCL collective plugins (AllGather/ReduceScatter) as a shared library.
# tensorrt_llm.__file__ points at the package's __init__.py, so take its directory first.
tensorrt_llm_lib_path = os.path.dirname(tensorrt_llm.__file__)
plugin_lib_path = os.path.join(
    tensorrt_llm_lib_path, "libs", "libnvinfer_plugin_tensorrt_llm.so"
)
try:
    ctypes.CDLL(plugin_lib_path)
    logger.info("plugin loaded successfully")
except OSError as e:
    logger.info(f"unsuccessful load : {e}")
trt.init_libnvinfer_plugins(None, "")
# Iterate over all registered plugin creators
plugin_registry = trt.get_plugin_registry()
for plugin_creator in plugin_registry.plugin_creator_list:
    logger.info(
        f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}"
    )
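# Note: if the CDLL load above failed, the "tensorrt_llm" plugin namespace will not be
# populated, the get_plugin_creator lookups in the converters below will return None,
# and their asserts will fail.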


# Configuration enums mirroring the TensorRT-LLM AllReduce plugin settings
class AllReduceStrategy(IntEnum):
    """Warning: actual definition is in kernels/customAllReduceKernels.h.

    They must be kept in sync.
    """

    NCCL = 0
    ONESHOT = 1
    TWOSHOT = 2
    AUTO = 3


class AllReduceConfig(IntFlag):
    """Warning: actual definition is in kernels/customAllReduceKernels.h.

    They must be kept in sync.
    """

    USE_MEMCPY = auto()
    PUSH_MODE = auto()


@dynamo_tensorrt_converter(custom_fused_all_gather_op)
def insert_nccl_gather_op(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    # Map the fused all_gather node onto the TensorRT-LLM "AllGather" plugin.
    plug_inputs = [args[0]]
    allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator(
        "AllGather", "1", "tensorrt_llm"
    )
    assert allgather_plg_creator is not None
    # The "group" field lists the ranks participating in the collective.
    world_size = dist.get_world_size()
    group = list(range(world_size))
    group = trt.PluginField(
        "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
    )
    # The plugin also expects the element dtype (fp16 here) as a field.
    p_dtype = trt.float16
    pf_type = trt.PluginField(
        "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
    )
    pfc = trt.PluginFieldCollection([group, pf_type])
    allgather = allgather_plg_creator.create_plugin("allgather", pfc)
    layer = ctx.net.add_plugin_v2(plug_inputs, allgather)
    set_layer_name(layer, target, name)
    return layer.get_output(0)
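# The ReduceScatter converter below follows the same pattern as the AllGather one above,
# but the TensorRT-LLM plugin additionally expects strategy, config and counter fields
# (the first two mirror the enums defined earlier).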


@dynamo_tensorrt_converter(custom_fused_reduce_scatter_op)
def insert_nccl_reduce_scatter_plugin(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    # Map the fused reduce_scatter node onto the TensorRT-LLM "ReduceScatter" plugin.
    plug_inputs = [args[0]]
    allreduce_plg_creator = trt.get_plugin_registry().get_plugin_creator(
        "ReduceScatter", "1", "tensorrt_llm"
    )

    assert allreduce_plg_creator is not None

    counter = 0
    strategy = AllReduceStrategy.NCCL
    config = AllReduceConfig(0)

    # The "group" field lists the ranks participating in the collective.
    world_size = dist.get_world_size()
    group = list(range(world_size))
    group = trt.PluginField(
        "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
    )

    p_dtype = trt.float16
    pf_dtype = trt.PluginField(
        "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
    )
    pfc = [group, pf_dtype]
    p_strategy = trt.PluginField(
        "strategy", np.array([int(strategy)], np.int8), trt.PluginFieldType.INT8
    )
    pfc.append(p_strategy)
    p_config = trt.PluginField(
        "config", np.array([int(config)], np.int8), trt.PluginFieldType.INT8
    )
    pfc.append(p_config)
    p_counter = trt.PluginField(
        "counter", np.array([counter], np.int32), trt.PluginFieldType.INT32
    )
    pfc.append(p_counter)

    pfc = trt.PluginFieldCollection(pfc)
    ar_plug = allreduce_plg_creator.create_plugin("allreduce", pfc)

    layer = ctx.net.add_plugin_v2(plug_inputs, ar_plug)
    set_layer_name(layer, target, name)
    return layer.get_output(0)
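# With the two converters above registered via @dynamo_tensorrt_converter, the fused
# all_gather / reduce_scatter ops produced by the fuse_distributed_ops lowering pass are
# lowered to the TensorRT-LLM NCCL plugins when the sharded module is compiled with
# Torch-TensorRT (e.g. via torch.compile with the "torch_tensorrt" backend).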


"""
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
@@ -36,13 +220,6 @@ def forward(self, x):
        return x


-# create a device mesh based on the given world_size.
-_world_size = int(os.environ["WORLD_SIZE"])
-
-device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
-_rank = device_mesh.get_rank()
-
-
print(f"Starting PyTorch TP example on rank {_rank}.")
assert (
    _world_size % 2 == 0