@@ -14,65 +14,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from typing import Optional
 
 import torch
 import torch.distributed as dist
+from torch.distributed import ProcessGroup
+from vllm.distributed.device_communicators.base_device_communicator import \
+    DeviceCommunicatorBase
 
 
-class NPUCommunicator:
+class NPUCommunicator(DeviceCommunicatorBase):
 
-    def __init__(self, group, unique_name=""):
-        self.group = group
-        self.unique_name = unique_name
-        self.rank = dist.get_rank(group)
-        self.world_size = dist.get_world_size(self.group)
-        self.ranks = dist.get_process_group_ranks(self.group)
-        global_rank = dist.get_rank()
-        self.rank_in_group = dist.get_group_rank(self.group, global_rank)
-
-    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
-        dist.all_reduce(x, group=self.group)
-        return x
-
-    def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1):
-        # NOTE: We assume that the input tensor is on the same device across
-        # all the ranks.
-        # NOTE: `dst` is the local rank of the destination rank.
-        # Allocate output tensor.
-        if self.rank_in_group == dst:
-            gather_list = [
-                torch.empty_like(input_) for _ in range(self.world_size)
-            ]
-        else:
-            gather_list = None
-        # Gather.
-        dist.gather(input_, gather_list, dst=self.ranks[dst], group=self.group)
-        if self.rank_in_group == dst:
-            output_tensor = torch.cat(gather_list, dim=dim)
-        else:
-            output_tensor = None
-        return output_tensor
-
-    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
-        if dim < 0:
-            # Convert negative dim to positive.
-            dim += input_.dim()
-        input_size = input_.size()
-        # NOTE: we have to use concat-style all-gather here,
-        # stack-style all-gather has compatibility issues with
-        # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
-        output_size = (input_size[0] * self.world_size, ) + input_size[1:]
-        # Allocate output tensor.
-        output_tensor = torch.empty(output_size,
-                                    dtype=input_.dtype,
-                                    device=input_.device)
-        # All-gather.
-        dist.all_gather_into_tensor(output_tensor, input_, group=self.group)
-        # Reshape
-        output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
-        output_tensor = output_tensor.movedim(0, dim)
-        output_tensor = output_tensor.reshape(input_size[:dim] +
-                                              (self.world_size *
-                                               input_size[dim], ) +
-                                              input_size[dim + 1:])
-        return output_tensor
+    def __init__(self,
+                 cpu_group: ProcessGroup,
+                 device: Optional[torch.device] = None,
+                 device_group: Optional[ProcessGroup] = None,
+                 unique_name: str = ""):
+        super().__init__(cpu_group, device, device_group, unique_name)
+        # init device according to local rank
+        local_rank = dist.get_rank(device_group)
+        self.device = torch.device(f"npu:{local_rank}")
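For context on what the removed `all_gather` was doing: it ran a concat-style `dist.all_gather_into_tensor` (stack-style all-gather has `torch.compile` issues, per the linked PyTorch issue) and then rearranged the flat result so each rank's chunk ends up concatenated along an arbitrary `dim`. The sketch below reproduces only that reshape arithmetic on a fake gathered buffer, with no process group involved; `world_size`, the shapes, and `dim` are made-up values for illustration.

```python
import torch

# Made-up sizes for illustration only.
world_size = 4
input_size = (2, 3, 5)   # per-rank tensor shape
dim = 1                  # dimension to gather along

# Stand-in for the buffer dist.all_gather_into_tensor would fill:
# rank i's chunk occupies rows [i * input_size[0], (i + 1) * input_size[0]).
gathered = torch.arange(world_size * 2 * 3 * 5, dtype=torch.float32)
gathered = gathered.reshape((world_size * input_size[0], ) + input_size[1:])

# The same rearrangement the removed NPUCommunicator.all_gather applied:
# split out the world dimension, move it next to `dim`, then fold it in.
out = gathered.reshape((world_size, ) + input_size)
out = out.movedim(0, dim)
out = out.reshape(input_size[:dim] +
                  (world_size * input_size[dim], ) +
                  input_size[dim + 1:])

assert out.shape == (2, world_size * 3, 5)  # chunks concatenated along dim 1
```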
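The refactor also drops the hand-rolled group bookkeeping (`rank`, `world_size`, `ranks`, `rank_in_group`) and the collectives, on the assumption that `DeviceCommunicatorBase` now covers them; only the NPU device selection from the local rank stays here. As a rough, NPU-free illustration of that bookkeeping and of `all_reduce` over a process group, here is a standalone sketch using the CPU `gloo` backend; the worker function name and port are made up for the example, and none of this is vLLM code.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    group = dist.group.WORLD
    # Same queries the removed __init__ performed by hand.
    ranks = dist.get_process_group_ranks(group)
    rank_in_group = dist.get_group_rank(group, dist.get_rank())

    x = torch.full((4, ), float(rank + 1))
    dist.all_reduce(x, group=group)  # sums across ranks: 1 + 2 = 3
    assert torch.allclose(x, torch.full((4, ), 3.0))
    assert ranks == list(range(world_size)) and rank_in_group == rank

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2, ), nprocs=2)
```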