This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 5bde5ba

youkaichao authored and Robert Shaw committed
[Core][Distributed] improve p2p access check (vllm-project#4992)
1 parent 420c4ff commit 5bde5ba

3 files changed, +189 −90 lines changed


vllm/distributed/device_communicators/custom_all_reduce.py

Lines changed: 2 additions & 1 deletion
@@ -6,6 +6,8 @@
 from torch.distributed import ProcessGroup
 
 import vllm.envs as envs
+from vllm.distributed.device_communicators.custom_all_reduce_utils import (
+    gpu_p2p_access_check)
 from vllm.distributed.parallel_state import (
     get_local_rank, get_tensor_model_parallel_cpu_group)
 from vllm.logger import init_logger
@@ -65,7 +67,6 @@ def _is_full_nvlink(device_ids: List[int]) -> bool:
 
 
 def _can_p2p(rank: int, world_size: int) -> bool:
-    from vllm.distributed.utils import gpu_p2p_access_check
     for i in range(world_size):
         if i == rank:
             continue
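The hunk above cuts off at the top of `_can_p2p`, so the rest of its body is not shown in this diff. Purely as a hedged illustration of how the relocated `gpu_p2p_access_check(i, j)` is meant to be consumed (a sketch, not the actual vLLM code; the helper name below is hypothetical):

# Hypothetical sketch: a per-rank capability check in the spirit of _can_p2p.
# gpu_p2p_access_check(src, dst) returns True when GPU src can access GPU dst.
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
    gpu_p2p_access_check)


def all_peers_reachable(rank: int, world_size: int) -> bool:
    # require working P2P access from this rank's GPU to every other GPU
    for i in range(world_size):
        if i == rank:
            continue
        if not gpu_p2p_access_check(rank, i):
            return False
    return True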
vllm/distributed/device_communicators/custom_all_reduce_utils.py

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
+import json
+import os
+import sys
+import tempfile
+import time
+from contextlib import contextmanager
+from typing import Callable, Dict, List, Optional
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+import vllm.envs as envs
+from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@contextmanager
+def mute_output():
+    with open(os.devnull, "w") as f:
+        sys.stderr = f
+        sys.stdout = f
+        yield
+
+
+def producer(i: int,
+             init_method: str,
+             cuda_visible_devices: Optional[str] = None):
+    if cuda_visible_devices is not None:
+        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
+    with mute_output():
+        dist.init_process_group(
+            backend="gloo",
+            init_method=init_method,
+            world_size=2,
+            rank=0,
+        )
+        # produce a tensor in GPU i
+        data = torch.zeros((128, ), device=f"cuda:{i}")
+        # get the information to reconstruct the shared tensor
+        func, args = torch.multiprocessing.reductions.reduce_tensor(data)
+        args = list(args)
+        dist.broadcast_object_list([(func, args)], src=0)
+        dist.barrier()
+        torch.cuda.synchronize()
+        assert torch.all(data == 1).item()
+
+
+def consumer(j: int,
+             init_method: str,
+             cuda_visible_devices: Optional[str] = None):
+    if cuda_visible_devices is not None:
+        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
+    with mute_output():
+        dist.init_process_group(
+            backend="gloo",
+            init_method=init_method,
+            world_size=2,
+            rank=1,
+        )
+        torch.cuda.set_device(j)
+        recv = [None]
+        dist.broadcast_object_list(recv, src=0)
+        func: Callable
+        args: List
+        func, args = recv[0]  # type: ignore
+        # `args[6]` is the device id
+        # by default pytorch will use `i` from the producer
+        # here we need to set it to `j` to test P2P access
+        args[6] = j
+        data = func(*args)
+        data += 1
+        dist.barrier()
+        torch.cuda.synchronize()
+        assert torch.all(data == 1).item()
+
+
+def can_actually_p2p(i, j):
+    """
+    Usually, checking if P2P access is enabled can be done by
+    `torch.cuda.can_device_access_peer(i, j)`. However, sometimes
+    the driver might be broken, and `torch.cuda.can_device_access_peer(i, j)`
+    returns `True` even if P2P access is not actually possible.
+    See https://github.com/vllm-project/vllm/issues/2728 and
+    https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
+    Therefore, we have to perform a real P2P access to check if it is actually
+    possible.
+
+    Note on p2p and cuda IPC:
+    Usually, one process uses one GPU:
+    GPU i --> cuda context i --> tensor i --> process i
+
+    We need to combine p2p and cuda IPC, so that:
+    GPU i --> cuda context i --> tensor i --> process i
+                                 |shared|
+    GPU j --> cuda context j --> tensor j --> process j
+    That is to say, process i creates a tensor in GPU i, passes IPC handle to
+    process j, and process j accesses the tensor in GPU j. Any operation on the
+    tensor in process j will be reflected in the tensor in process i, because
+    they are the same memory segment.
+    It is important to note that process j accesses the tensor in GPU j, not
+    GPU i. That's why we need p2p access. # noqa
+    """
+    cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
+    # pass the CUDA_VISIBLE_DEVICES to the child process
+    # to make sure they see the same set of GPUs
+
+    # make sure the temp file is not the same across different calls
+    temp_path = tempfile.mktemp() + str(time.time())
+    # create an empty file
+    with open(temp_path, "w"):
+        pass
+    init_method = f"file://{temp_path}"
+
+    # make sure the processes are spawned
+    smp = mp.get_context("spawn")
+    pi = smp.Process(target=producer,
+                     args=(i, init_method, cuda_visible_devices))
+    pj = smp.Process(target=consumer,
+                     args=(j, init_method, cuda_visible_devices))
+    pi.start()
+    pj.start()
+    pi.join()
+    pj.join()
+    return pi.exitcode == 0 and pj.exitcode == 0
+
+
+# why do we need this cache?
+# we are testing peer-to-peer (p2p) access between GPUs,across processes.
+# if we test it every time, it will be very slow, because we need to create
+# N * N * 2 processes, where N is the world size. This is very slow.
+# to reduce the time, we use a cache file to store the p2p access status.
+# the cache file is generated by the master process if it does not exist.
+# then all the processes can read the cache file to check the p2p access status.
+# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
+# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
+# e.g. used by different vllm engines. The device id in the cache file is a
+# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
+# of visible devices in the vllm engine.
+_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
+
+
+def gpu_p2p_access_check(i: int, j: int) -> bool:
+    """Check if GPU i can access GPU j."""
+
+    # if the cache variable is already calculated,
+    # read from the cache instead of checking it again
+    global _gpu_p2p_access_cache
+    if _gpu_p2p_access_cache is not None:
+        return _gpu_p2p_access_cache[f"{i}->{j}"]
+
+    is_distributed = dist.is_initialized()
+
+    num_dev = torch.cuda.device_count()
+    cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
+    if cuda_visible_devices is None:
+        cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
+    VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT
+    path = os.path.expanduser(
+        f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
+    )
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    if ((not is_distributed or get_local_rank() == 0)
+            and (not os.path.exists(path))):
+        # only the local master process (with local_rank == 0) can
+        # enter this block to calculate the cache
+        logger.info("generating GPU P2P access cache for in %s", path)
+        cache = {}
+        for _i in range(num_dev):
+            for _j in range(num_dev):
+                cache[f"{_i}->{_j}"] = can_actually_p2p(_i, _j)
+        with open(path, "w") as f:
+            json.dump(cache, f, indent=4)
+    if is_distributed:
+        cpu_world_group = get_cpu_world_group()
+        dist.barrier(cpu_world_group)
+    logger.info("reading GPU P2P access cache from %s", path)
+    with open(path, "r") as f:
+        cache = json.load(f)
+    _gpu_p2p_access_cache = cache
+    return _gpu_p2p_access_cache[f"{i}->{j}"]
+
+
+__all__ = ["gpu_p2p_access_check"]
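For anyone debugging a stale or incorrect cache entry, here is a hedged sketch (not part of the commit) of the JSON file this module produces and how it is read back. The path layout and the "i->j" key format come from the code above; the concrete values, the two-GPU setting, and the scratch directory are assumptions made for the example.

# Illustrative only: write and read a cache file shaped like the one
# gpu_p2p_access_check generates for CUDA_VISIBLE_DEVICES="0,1".
import json
import os
import tempfile

# use a scratch directory so the example cannot clobber a real cache file;
# the layout mirrors {VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{devices}.json
config_root = tempfile.mkdtemp()
path = f"{config_root}/vllm/gpu_p2p_access_cache_for_0,1.json"

example_cache = {
    "0->0": True, "0->1": True,  # hypothetical results
    "1->0": True, "1->1": True,
}
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "w") as f:  # what the local master process does once
    json.dump(example_cache, f, indent=4)

with open(path) as f:  # what every process does afterwards
    cache = json.load(f)
print(cache["0->1"])  # "can GPU 0 access GPU 1?"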

vllm/distributed/utils.py

Lines changed: 1 addition & 89 deletions
@@ -2,19 +2,9 @@
 # Adapted from
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-import json
-import os
-from typing import Dict, Optional, Sequence
+from typing import Sequence
 
 import torch
-import torch.distributed as dist
-
-import vllm.envs as envs
-from vllm.logger import init_logger
-
-from .parallel_state import get_cpu_world_group, get_local_rank
-
-logger = init_logger(__name__)
 
 
 def ensure_divisibility(numerator, denominator):
@@ -56,81 +46,3 @@ def split_tensor_along_last_dim(
         return tuple(chunk.contiguous() for chunk in tensor_list)
 
     return tensor_list
-
-
-# code partly borrowed from
-# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
-# License: MIT
-def _can_actually_p2p(idx_a, idx_b):
-    dev_i = f"cuda:{idx_a}"
-    dev_j = f"cuda:{idx_b}"
-    a = torch.randn(5, device=dev_i) + 123.0
-    b = a.to(dev_j)
-    c = b.to(dev_i)
-    return torch.all(a == c).cpu().item()
-
-
-# why do we need this cache?
-# 1. we can have runtime checks for P2P access, where every process checks
-#    P2P access to all other GPUs. Unfortunately, the test might cost many
-#    (world_size * world_size) cuda context, and reduce the memory available
-#    for the model. see https://github.com/vllm-project/vllm/issues/3821
-# 2. alternatively, we can have a p2p map that is generated by the master
-#    process and broadcasted to all other processes. This still requires
-#    #world_size of cuda context, belonging to the master process, on each GPU.
-# 3. we can have a cache file, that records the p2p access status. The first
-#    time the master process checks the p2p access, it will generate the cache
-#    file, at the cost of #world_size of cuda context. Later on, all processes
-#    can read the cache file to check the p2p access status without any cost of
-#    additional cuda context.
-# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
-# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
-# e.g. used by different vllm engines. The device id in the cache file is a
-# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
-# of visible devices in the vllm engine.
-_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
-
-
-def gpu_p2p_access_check(i: int, j: int) -> bool:
-    """Check if GPU i can access GPU j."""
-
-    # if the cache variable is already calculated,
-    # read from the cache instead of checking it again
-    global _gpu_p2p_access_cache
-    if _gpu_p2p_access_cache is not None:
-        return _gpu_p2p_access_cache[f"{i}->{j}"]
-
-    is_distributed = dist.is_initialized()
-
-    num_dev = torch.cuda.device_count()
-    cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
-    if cuda_visible_devices is None:
-        cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
-    VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT
-    path = os.path.expanduser(
-        f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
-    )
-    os.makedirs(os.path.dirname(path), exist_ok=True)
-    if (not is_distributed or get_local_rank() == 0) \
-            and (not os.path.exists(path)):
-        # only the local master process (with local_rank == 0) can
-        # enter this block to calculate the cache
-        logger.info("generating GPU P2P access cache for in %s", path)
-        cache = {}
-        for _i in range(num_dev):
-            for _j in range(num_dev):
-                # on some platforms, P2P support might be buggy and we need
-                # additional checks. See also:
-                # https://github.com/vllm-project/vllm/issues/2728
-                cache[f"{_i}->{_j}"] = torch.cuda.can_device_access_peer(
-                    _i, _j) and _can_actually_p2p(_i, _j)
-        with open(path, "w") as f:
-            json.dump(cache, f, indent=4)
-    if is_distributed:
-        cpu_world_group = get_cpu_world_group()
-        dist.barrier(cpu_world_group)
-    logger.info("reading GPU P2P access cache from %s", path)
-    with open(path, "r") as f:
-        cache = json.load(f)
-    _gpu_p2p_access_cache = cache
-    return _gpu_p2p_access_cache[f"{i}->{j}"]
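For downstream code, the practical effect of this file's changes is that `gpu_p2p_access_check` no longer lives in `vllm.distributed.utils`. A minimal before/after import sketch (the standalone caller shown is hypothetical; within this commit the only in-tree user is custom_all_reduce.py):

# old import, removed by this commit:
#   from vllm.distributed.utils import gpu_p2p_access_check
# new import, matching the one added to custom_all_reduce.py:
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
    gpu_p2p_access_check)

# device ids are local ids, i.e. relative to CUDA_VISIBLE_DEVICES
if gpu_p2p_access_check(0, 1):
    print("GPU 0 can really access GPU 1 over P2P")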
