import torch
import torch.nn as nn
import torch_tensorrt
import torch.distributed as dist
from torch.distributed._tensor import Shard
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    dynamo_tensorrt_converter,
)
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch.fx.node import Target, Argument
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
from torch_tensorrt.dynamo.types import TRTTensor
import numpy as np
from torch_tensorrt.fx.converters.converter_utils import (
    set_layer_name,
)
import tensorrt as trt
import tensorrt_llm
import ctypes
import logging
"""
This example adapts code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
and maps the all_gather collective produced by tensor parallelism onto the
TensorRT-LLM AllGather plugin, so the parallelized model can be lowered with torch_tensorrt.
"""

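# For context, the tensor-parallel setup in the referenced example follows this
# rough pattern (a sketch only; the actual mesh and model code appears later in
# this file, and the module names "in_proj"/"out_proj" come from that example):
#
#     device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
#     tp_model = parallelize_module(
#         module=model,
#         device_mesh=device_mesh,
#         parallelize_plan={
#             "in_proj": ColwiseParallel(input_layouts=Shard(0)),
#             "out_proj": RowwiseParallel(output_layouts=Shard(0)),
#         },
#     )
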
# Load the TensorRT-LLM plugin shared library so its plugins register with TensorRT.
plugin_lib_path = "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
try:
    ctypes.CDLL(plugin_lib_path)
    print("plugin loaded successfully")
except OSError as e:
    print(f"unsuccessful load: {e}")
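
# The absolute path above is environment-specific. A more portable lookup is
# sketched below, assuming the tensorrt_llm wheel ships the shared library in
# its "libs" directory (the helper name is illustrative, not part of the example):
def find_tensorrt_llm_plugin_lib() -> str:
    import os

    return os.path.join(
        os.path.dirname(tensorrt_llm.__file__),
        "libs",
        "libnvinfer_plugin_tensorrt_llm.so",
    )
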
logger = trt.Logger(trt.Logger.VERBOSE)
trt.init_libnvinfer_plugins(None, "")

# Iterate over all registered plugin creators to confirm the AllGather plugin is visible.
plugin_registry = trt.get_plugin_registry()
for plugin_creator in plugin_registry.plugin_creator_list:
    print(
        f"Plugin Name: {plugin_creator.name}, "
        f"Namespace: {plugin_creator.plugin_namespace}, "
        f"Version: {plugin_creator.plugin_version}"
    )


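# Register a custom converter so the dynamo frontend lowers the functional
# all_gather collective to the TensorRT-LLM AllGather plugin instead of
# falling back to PyTorch execution.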
@dynamo_tensorrt_converter(torch.ops._c10d_functional.all_gather_into_tensor.default)
def insert_gather_op(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    plug_inputs = [args[0]]
    # Look up the AllGather plugin creator registered by the tensorrt_llm library.
    allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator(
        "AllGather", "1", "tensorrt_llm"
    )
    assert allgather_plg_creator is not None
    # Gather across every rank in the default process group.
    world_size = dist.get_world_size()
    group = list(range(world_size))
    group = trt.PluginField(
        "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
    )
    # The plugin is configured for fp16 tensors.
    p_dtype = trt.float16
    pf_type = trt.PluginField(
        "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
    )
    pfc = trt.PluginFieldCollection([group, pf_type])
    allgather = allgather_plg_creator.create_plugin("allgather", pfc)
    layer = ctx.net.add_plugin_v2(plug_inputs, allgather)
    set_layer_name(layer, target, name)
    return layer.get_output(0)
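
# With this converter registered, compiling the parallelized module through the
# dynamo frontend lowers each all_gather_into_tensor call to the plugin. A
# minimal usage sketch (assuming `tp_model` and `inp` are set up as in the
# referenced tensor_parallel_example.py; the options shown are illustrative):
#
#     compiled = torch_tensorrt.dynamo.compile(
#         tp_model, inputs=[inp], enabled_precisions={torch.float16}
#     )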


class ToyModel(nn.Module):
    """MLP based model"""