import json
import os
import sys
import tempfile
import time
from contextlib import contextmanager
from typing import Callable, Dict, List, Optional

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import vllm.envs as envs
from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
from vllm.logger import init_logger

logger = init_logger(__name__)


@contextmanager
def mute_output():
    # note that stdout/stderr are not restored on exit: the child process
    # that enters this context stays muted for the rest of its lifetime
    with open(os.devnull, "w") as f:
        sys.stderr = f
        sys.stdout = f
        yield


def producer(i: int,
             init_method: str,
             cuda_visible_devices: Optional[str] = None):
    if cuda_visible_devices is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
    with mute_output():
        dist.init_process_group(
            backend="gloo",
            init_method=init_method,
            world_size=2,
            rank=0,
        )
        # produce a tensor in GPU i
        data = torch.zeros((128, ), device=f"cuda:{i}")
        # get the information to reconstruct the shared tensor
        func, args = torch.multiprocessing.reductions.reduce_tensor(data)
        args = list(args)
        dist.broadcast_object_list([(func, args)], src=0)
        dist.barrier()
        torch.cuda.synchronize()
        # by now the consumer should have incremented the shared buffer to 1
        assert torch.all(data == 1).item()


def consumer(j: int,
             init_method: str,
             cuda_visible_devices: Optional[str] = None):
    if cuda_visible_devices is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
    with mute_output():
        dist.init_process_group(
            backend="gloo",
            init_method=init_method,
            world_size=2,
            rank=1,
        )
        torch.cuda.set_device(j)
        recv = [None]
        dist.broadcast_object_list(recv, src=0)
        func: Callable
        args: List
        func, args = recv[0]  # type: ignore
        # `args[6]` is the device id
        # by default pytorch will use `i` from the producer
        # here we need to set it to `j` to test P2P access
        args[6] = j
        data = func(*args)
        data += 1
        dist.barrier()
        torch.cuda.synchronize()
        # the buffer started at 0 and we added 1, so reading it back through
        # GPU j must give 1 if P2P access really works
        assert torch.all(data == 1).item()


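# A minimal sketch (not part of the original file) of what `reduce_tensor`
# returns for a CUDA tensor: `func` is the rebuild callable and `args` is a
# flat tuple of metadata (sizes, strides, dtype, device index, IPC handles,
# ...). The consumer above assumes the device index sits at position 6, which
# is an internal PyTorch detail. The helper name below is ours; it requires a
# CUDA device to run.
def _inspect_ipc_handle_sketch() -> None:
    tensor = torch.zeros((128, ), device="cuda:0")
    func, args = torch.multiprocessing.reductions.reduce_tensor(tensor)
    # `args[6]` is the slot the consumer above overwrites with its device id
    print(func, len(args), args[6])

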
def can_actually_p2p(i: int, j: int) -> bool:
    """
    Usually, checking if P2P access is enabled can be done by
    `torch.cuda.can_device_access_peer(i, j)`. However, sometimes
    the driver might be broken, and `torch.cuda.can_device_access_peer(i, j)`
    returns `True` even if P2P access is not actually possible.
    See https://github.com/vllm-project/vllm/issues/2728 and
    https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10 # noqa
    Therefore, we have to perform a real P2P access to check if it is actually
    possible.

    Note on p2p and cuda IPC:
    Usually, one process uses one GPU:
     GPU i --> cuda context i --> tensor i --> process i

    We need to combine p2p and cuda IPC, so that:
     GPU i --> cuda context i --> tensor i --> process i
                                     |shared|
     GPU j --> cuda context j --> tensor j --> process j
    That is to say, process i creates a tensor in GPU i, passes the IPC handle
    to process j, and process j accesses the tensor in GPU j. Any operation on
    the tensor in process j will be reflected in the tensor in process i,
    because they are the same memory segment.
    It is important to note that process j accesses the tensor in GPU j, not
    GPU i. That's why we need p2p access.
    """
    cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
    # pass the CUDA_VISIBLE_DEVICES to the child processes
    # to make sure they see the same set of GPUs

    # make sure the temp file is not the same across different calls
    temp_path = tempfile.mktemp() + str(time.time())
    # create an empty file
    with open(temp_path, "w"):
        pass
    init_method = f"file://{temp_path}"

    # make sure the processes are spawned
    smp = mp.get_context("spawn")
    pi = smp.Process(target=producer,
                     args=(i, init_method, cuda_visible_devices))
    pj = smp.Process(target=consumer,
                     args=(j, init_method, cuda_visible_devices))
    pi.start()
    pj.start()
    pi.join()
    pj.join()
    return pi.exitcode == 0 and pj.exitcode == 0


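# A minimal sketch (not part of the original file): probing a single GPU pair
# directly with `can_actually_p2p`, next to the driver's own (sometimes overly
# optimistic) answer. The helper name is ours; it requires at least two
# visible CUDA devices.
def _probe_pair_sketch() -> None:
    if torch.cuda.device_count() >= 2:
        print("driver claims P2P:", torch.cuda.can_device_access_peer(0, 1))
        print("P2P actually works:", can_actually_p2p(0, 1))

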
# why do we need this cache?
# we are testing peer-to-peer (p2p) access between GPUs, across processes.
# if we tested it on every startup, it would be very slow, because we need to
# create N * N * 2 processes, where N is the number of visible devices.
# to reduce the time, we use a cache file to store the p2p access status.
# the cache file is generated by the local master process if it does not
# exist. then all the processes can read the cache file to check the p2p
# access status.
# Note that the cache file is suffixed by CUDA_VISIBLE_DEVICES, so that we
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
# e.g. used by different vllm engines. The device id in the cache file is a
# **local** device id, i.e. from 0 to num_dev - 1, where num_dev is the number
# of visible devices in the vllm engine.
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None


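# For illustration (not part of the original file), a cache file for
# CUDA_VISIBLE_DEVICES="0,1" (two visible devices) would look like this,
# with values depending on the actual hardware:
#
#   {
#       "0->0": true,
#       "0->1": true,
#       "1->0": true,
#       "1->1": true
#   }

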
def gpu_p2p_access_check(i: int, j: int) -> bool:
    """Check if GPU i can access GPU j."""

    # if the cache variable is already calculated,
    # read from the cache instead of checking it again
    global _gpu_p2p_access_cache
    if _gpu_p2p_access_cache is not None:
        return _gpu_p2p_access_cache[f"{i}->{j}"]

    is_distributed = dist.is_initialized()

    num_dev = torch.cuda.device_count()
    cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
    if cuda_visible_devices is None:
        cuda_visible_devices = ",".join(str(idx) for idx in range(num_dev))
    VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT
    path = os.path.expanduser(
        f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
    )
    os.makedirs(os.path.dirname(path), exist_ok=True)
    if ((not is_distributed or get_local_rank() == 0)
            and (not os.path.exists(path))):
        # only the local master process (with local_rank == 0) can
        # enter this block to calculate the cache
        logger.info("generating GPU P2P access cache in %s", path)
        cache: Dict[str, bool] = {}
        for _i in range(num_dev):
            for _j in range(num_dev):
                cache[f"{_i}->{_j}"] = can_actually_p2p(_i, _j)
        with open(path, "w") as f:
            json.dump(cache, f, indent=4)
    if is_distributed:
        cpu_world_group = get_cpu_world_group()
        dist.barrier(cpu_world_group)
    logger.info("reading GPU P2P access cache from %s", path)
    with open(path, "r") as f:
        cache = json.load(f)
    _gpu_p2p_access_cache = cache
    return _gpu_p2p_access_cache[f"{i}->{j}"]


__all__ = ["gpu_p2p_access_check"]
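

# A usage sketch (not part of the original file): without torch.distributed
# initialized, the calling process builds the cache file itself on the first
# call; subsequent calls are plain dictionary lookups.
if __name__ == "__main__":
    if torch.cuda.device_count() >= 2:
        print("GPU 0 -> GPU 1:", gpu_p2p_access_check(0, 1))
        print("GPU 1 -> GPU 0:", gpu_p2p_access_check(1, 0))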