From 901c3f2b282f7b0705c2cf3fa113c42803546ed7 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Fri, 27 Sep 2024 10:38:43 -0700
Subject: [PATCH 1/2] enable skip

---
 .../distributed/device_communicators/custom_all_reduce.py | 2 ++
 vllm/envs.py                                               | 8 ++++++++
 2 files changed, 10 insertions(+)

diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index d239d645edc1..8c98ca89b284 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -28,6 +28,8 @@ def _can_p2p(rank: int, world_size: int) -> bool:
     for i in range(world_size):
         if i == rank:
             continue
+        if envs.VLLM_SKIP_P2P_CHECK:
+            return torch.cuda.can_device_access_peer(rank, i)
         if not gpu_p2p_access_check(rank, i):
             return False
     return True
diff --git a/vllm/envs.py b/vllm/envs.py
index 705d858e71a6..7cbffc83a625 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -63,6 +63,7 @@
     VLLM_USE_TRITON_AWQ: bool = False
     VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
     VLLM_ALLOW_DEPRECATED_BEAM_SEARCH: bool = False
+    VLLM_SKIP_P2P_CHECK: bool = False
 
 
 def get_default_cache_root():
@@ -423,6 +424,13 @@ def get_default_config_root():
     lambda:
     (os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() in
      ("1", "true")),
+
+    # By default, vLLM will check the peer-to-peer capability itself,
+    # in case of broken drivers. See https://github.com/vllm-project/vllm/blob/a9b15c606fea67a072416ea0ea115261a2756058/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L101-L108 for details. # noqa
+    # If this env var is set to 1, vLLM will skip the peer-to-peer check,
+    # and trust the driver's peer-to-peer capability report.
+    "VLLM_SKIP_P2P_CHECK":
+    lambda: os.getenv("VLLM_SKIP_P2P_CHECK", "0") == "1",
 }
 
 # end-env-vars-definition

From a854d5d903aac056362735d1e3ae77c145f7f806 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Fri, 27 Sep 2024 10:41:38 -0700
Subject: [PATCH 2/2] add logging

---
 vllm/distributed/device_communicators/custom_all_reduce.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index 8c98ca89b284..c95192a5a1bc 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -29,6 +29,8 @@ def _can_p2p(rank: int, world_size: int) -> bool:
         if i == rank:
             continue
         if envs.VLLM_SKIP_P2P_CHECK:
+            logger.info(
+                "Skipping P2P check and trusting the driver's P2P report.")
             return torch.cuda.can_device_access_peer(rank, i)
         if not gpu_p2p_access_check(rank, i):
             return False
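
Note: a minimal standalone sketch of how the patched _can_p2p helper behaves
once both commits are applied. It is not the vLLM module itself: logging, os,
and a simplified gpu_p2p_access_check stand in for vLLM's logger, envs, and
the cached multi-process P2P correctness check in custom_all_reduce_utils.py.

    import logging
    import os

    import torch

    logger = logging.getLogger(__name__)


    def gpu_p2p_access_check(src: int, tgt: int) -> bool:
        # Simplified stand-in: the real vLLM helper verifies P2P correctness
        # in a separate process and caches the result, to guard against
        # drivers that report P2P support but produce wrong results.
        return torch.cuda.can_device_access_peer(src, tgt)


    def _can_p2p(rank: int, world_size: int) -> bool:
        for i in range(world_size):
            if i == rank:
                continue
            if os.getenv("VLLM_SKIP_P2P_CHECK", "0") == "1":
                # Trust the driver's P2P capability report directly.
                logger.info(
                    "Skipping P2P check and trusting the driver's P2P report.")
                return torch.cuda.can_device_access_peer(rank, i)
            if not gpu_p2p_access_check(rank, i):
                return False
        return True

Typical usage is to set the variable before launching a tensor-parallel
server, e.g. (assuming the OpenAI-compatible entrypoint):

    VLLM_SKIP_P2P_CHECK=1 python -m vllm.entrypoints.openai.api_server \
        --model <model> --tensor-parallel-size 2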