Commit f516cd2

[dist] revert communicator patch (vllm-project#66)
### What this PR does / why we need it?
Revert the communicator patch, as vllm-project/vllm#13208 has been merged.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
Tested locally; see vllm-project#30 (comment).

Signed-off-by: MengqingCao <[email protected]>
1 parent 54a2a8b commit f516cd2

File tree: 4 files changed, +14 -145 lines

vllm_ascend/communicator.py

Lines changed: 14 additions & 56 deletions
@@ -14,65 +14,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from typing import Optional

 import torch
 import torch.distributed as dist
+from torch.distributed import ProcessGroup
+from vllm.distributed.device_communicators.base_device_communicator import \
+    DeviceCommunicatorBase


-class NPUCommunicator:
+class NPUCommunicator(DeviceCommunicatorBase):

-    def __init__(self, group, unique_name=""):
-        self.group = group
-        self.unique_name = unique_name
-        self.rank = dist.get_rank(group)
-        self.world_size = dist.get_world_size(self.group)
-        self.ranks = dist.get_process_group_ranks(self.group)
-        global_rank = dist.get_rank()
-        self.rank_in_group = dist.get_group_rank(self.group, global_rank)
-
-    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
-        dist.all_reduce(x, group=self.group)
-        return x
-
-    def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1):
-        # NOTE: We assume that the input tensor is on the same device across
-        # all the ranks.
-        # NOTE: `dst` is the local rank of the destination rank.
-        # Allocate output tensor.
-        if self.rank_in_group == dst:
-            gather_list = [
-                torch.empty_like(input_) for _ in range(self.world_size)
-            ]
-        else:
-            gather_list = None
-        # Gather.
-        dist.gather(input_, gather_list, dst=self.ranks[dst], group=self.group)
-        if self.rank_in_group == dst:
-            output_tensor = torch.cat(gather_list, dim=dim)
-        else:
-            output_tensor = None
-        return output_tensor
-
-    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
-        if dim < 0:
-            # Convert negative dim to positive.
-            dim += input_.dim()
-        input_size = input_.size()
-        # NOTE: we have to use concat-style all-gather here,
-        # stack-style all-gather has compatibility issues with
-        # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
-        output_size = (input_size[0] * self.world_size, ) + input_size[1:]
-        # Allocate output tensor.
-        output_tensor = torch.empty(output_size,
-                                    dtype=input_.dtype,
-                                    device=input_.device)
-        # All-gather.
-        dist.all_gather_into_tensor(output_tensor, input_, group=self.group)
-        # Reshape
-        output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
-        output_tensor = output_tensor.movedim(0, dim)
-        output_tensor = output_tensor.reshape(input_size[:dim] +
-                                              (self.world_size *
-                                               input_size[dim], ) +
-                                              input_size[dim + 1:])
-        return output_tensor
+    def __init__(self,
+                 cpu_group: ProcessGroup,
+                 device: Optional[torch.device] = None,
+                 device_group: Optional[ProcessGroup] = None,
+                 unique_name: str = ""):
+        super().__init__(cpu_group, device, device_group, unique_name)
+        # init device according to local rank
+        local_rank = dist.get_rank(device_group)
+        self.device = torch.device(f"npu:{local_rank}")
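
For reference, here is a minimal usage sketch of the refactored class (illustration only, not part of this commit). It assumes torch_npu is installed, an Ascend NPU is visible, and torch.distributed is already initialized with the HCCL backend, as init_worker_distributed_environment in vllm_ascend/worker.py does. The collectives deleted above (all_reduce, gather, all_gather) are now inherited from vLLM's DeviceCommunicatorBase, so only the NPU device selection stays in this file.

import torch
import torch.distributed as dist

from vllm_ascend.communicator import NPUCommunicator

# In vLLM these process groups are built by its parallel_state setup; the
# explicit new_group calls here are purely illustrative.
cpu_group = dist.new_group(backend="gloo")     # CPU/metadata group
device_group = dist.new_group(backend="hccl")  # NPU collective group

comm = NPUCommunicator(cpu_group, device_group=device_group, unique_name="tp")

x = torch.ones(4, 8, device=comm.device)
y = comm.all_reduce(x)         # inherited from DeviceCommunicatorBase
z = comm.all_gather(x, dim=0)  # replaces the hand-rolled version removed above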

vllm_ascend/patch/__init__.py

Lines changed: 0 additions & 18 deletions
This file was deleted.

vllm_ascend/patch/patch_commnicator.py

Lines changed: 0 additions & 69 deletions
This file was deleted.

vllm_ascend/worker.py

Lines changed: 0 additions & 2 deletions
@@ -457,8 +457,6 @@ def init_worker_distributed_environment(
         backend: str = "hccl") -> None:
     """Initialize the distributed environment."""
     set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
-    # register communicator patch before init dist env
-    from vllm_ascend import patch  # noqa: F401

     init_distributed_environment(parallel_config.world_size, rank,
                                  distributed_init_method, local_rank, backend)
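
With the monkey patch removed, the communicator is expected to reach vLLM through the platform hook added by vllm-project/vllm#13208 rather than an import-time patch. The sketch below is an assumption for illustration only (the actual NPUPlatform class is not part of this diff): a platform subclass advertises its communicator by dotted path, and vLLM resolves that path lazily when it builds each parallel group.

from vllm.platforms.interface import Platform, PlatformEnum


class NPUPlatform(Platform):
    # Hypothetical sketch: the real vllm-ascend platform class lives
    # elsewhere in the repo and is not shown in this commit.
    _enum = PlatformEnum.OOT  # out-of-tree platform plugin
    device_name: str = "npu"
    device_type: str = "npu"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        # vLLM imports this dotted path on demand, so no patching is needed
        # before init_distributed_environment runs.
        return "vllm_ascend.communicator.NPUCommunicator"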
