From fab251aa9b4e53a4fd70bd7a80049ee6e466c6e8 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Thu, 28 Aug 2025 02:25:33 +0000 Subject: [PATCH 1/9] update nixl_connector with backend option Signed-off-by: Chendi Xue --- .../kv_connector/v1/nixl_connector.py | 23 +++++++++++++------ vllm/envs.py | 4 ++++ vllm/platforms/interface.py | 15 ++++++++++++ 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index d3a08af088c1..1bf26bc1d31d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -53,6 +53,7 @@ # Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used try: from nixl._api import nixl_agent as NixlWrapper + from nixl._api import nixl_agent_config logger.info("NIXL is available") except ImportError: logger.warning("NIXL is not available") @@ -65,6 +66,8 @@ "tpu": ("cpu", ), "xpu": ("cpu", ), } +# support for oot platform by providing mapping in current_platform +_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices()) class NixlAgentMetadata( @@ -448,8 +451,10 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size + self.nixl_backend = envs.VLLM_NIXL_BACKEND # Agent. - self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) + config = nixl_agent_config(backends=[self.nixl_backend]) + self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config) # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. 
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) @@ -486,11 +491,15 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # used when device memory can not be registered under nixl self.host_xfer_buffers: dict[str, torch.Tensor] = {} self.use_host_buffer = self.kv_buffer_device == "cpu" - if self.kv_buffer_device == "cuda": - self.nixl_memory_type = "VRAM" - elif self.kv_buffer_device == "cpu": - self.nixl_memory_type = "DRAM" - else: + # support for oot platform which can't register nixl memory + # type based on kv_buffer_device + self.nixl_memory_type = current_platform.get_nixl_memory_type() + if self.nixl_memory_type is None: + if self.kv_buffer_device == "cuda": + self.nixl_memory_type = "VRAM" + elif self.kv_buffer_device == "cpu": + self.nixl_memory_type = "DRAM" + if self.nixl_memory_type is None: raise RuntimeError( f"{self.device_type} with {self.kv_buffer_device} kv_buffer " "is not supported.") @@ -766,7 +775,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): descs = self.nixl_wrapper.get_reg_descs(caches_data, self.nixl_memory_type) logger.debug("Registering descs: %s", caches_data) - self.nixl_wrapper.register_memory(descs) + self.nixl_wrapper.register_memory(descs, backends=[self.nixl_backend]) logger.debug("Done registering descs") self._registered_descs.append(descs) diff --git a/vllm/envs.py b/vllm/envs.py index eaee2f6cc771..678c5647e611 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1136,6 +1136,10 @@ def get_vllm_port() -> Optional[int]: "VLLM_NIXL_SIDE_CHANNEL_PORT": lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")), + # Backend for vllm's NIXL communication. 
+ "VLLM_NIXL_BACKEND": + lambda: os.getenv("VLLM_NIXL_BACKEND", "UCX"), + # all2all backend for vllm's expert parallel communication # Available options: # - "naive": naive all2all implementation using broadcasts diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index cad04ea14c01..bd2b9391a14f 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -618,6 +618,21 @@ def _synced_weight_loader(param, *args, **kwargs): return _synced_weight_loader + @classmethod + def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]: + """ + Returns a mapping from device_type to a tuple of supported + kv_buffer_device for nixl. + """ + return {} + + @classmethod + def get_nixl_memory_type(cls) -> Optional[str]: + """ + Returns the nixl memory type for the current platform. + """ + return None + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED From 22734d0d4578f390cc08c87421fdd47444ec0232 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Fri, 19 Sep 2025 00:52:28 +0000 Subject: [PATCH 2/9] use kv_connector_extra_config instead of env for backend setting Signed-off-by: Chendi Xue --- .../distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 4 +++- vllm/envs.py | 4 ---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 1bf26bc1d31d..06f35e864d4d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -451,7 +451,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size - self.nixl_backend = envs.VLLM_NIXL_BACKEND + self.nixl_backend = \ + vllm_config.kv_transfer_config.get_from_extra_config( + "backend", "UCX") # Agent. 
config = nixl_agent_config(backends=[self.nixl_backend]) self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config) diff --git a/vllm/envs.py b/vllm/envs.py index 678c5647e611..eaee2f6cc771 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1136,10 +1136,6 @@ def get_vllm_port() -> Optional[int]: "VLLM_NIXL_SIDE_CHANNEL_PORT": lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")), - # Backend for vllm's NIXL communication. - "VLLM_NIXL_BACKEND": - lambda: os.getenv("VLLM_NIXL_BACKEND", "UCX"), - # all2all backend for vllm's expert parallel communication # Available options: # - "naive": naive all2all implementation using broadcasts From 95540a06d39e4dade3d917bf5017106213823aa3 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Fri, 19 Sep 2025 01:04:25 +0000 Subject: [PATCH 3/9] update doc with backend option Signed-off-by: Chendi Xue --- docs/features/disagg_prefill.md | 2 +- docs/serving/expert_parallel_deployment.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/features/disagg_prefill.md b/docs/features/disagg_prefill.md index 69f70b8ff5ac..fdcbeb4afe6d 100644 --- a/docs/features/disagg_prefill.md +++ b/docs/features/disagg_prefill.md @@ -28,7 +28,7 @@ Now supports 5 types of connectors: - **MultiConnector**: take advantage of the kv_connector_extra_config: dict[str, Any] already present in KVTransferConfig to stash all the connectors we want in an ordered list of kwargs.such as: ```bash - --kv-transfer-config '{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both"},{"kv_connector":"SharedStorageConnector","kv_role":"kv_both","kv_connector_extra_config":{"shared_storage_path":"local_storage"}}]}}' + --kv-transfer-config 
'{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both"},{"kv_connector":"SharedStorageConnector","kv_role":"kv_both","kv_connector_extra_config":{"shared_storage_path":"local_storage", "backend":"UCX"}}]}}' ``` - **OffloadingConnector**: enable offloading of KV data to CPU memory, customizing the CPU block size (in tokens) and number of blocks to allocate (per worker): diff --git a/docs/serving/expert_parallel_deployment.md b/docs/serving/expert_parallel_deployment.md index 7489fc260983..151467276d63 100644 --- a/docs/serving/expert_parallel_deployment.md +++ b/docs/serving/expert_parallel_deployment.md @@ -193,7 +193,7 @@ For production deployments requiring strict SLA guarantees for time-to-first-tok 1. **Install gdrcopy/ucx/nixl**: For maximum performance, run the [install_gdrcopy.sh](gh-file:tools/install_gdrcopy.sh) script to install `gdrcopy` (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). If `gdrcopy` is not installed, things will still work with a plain `pip install nixl`, just with lower performance. `nixl` and `ucx` are installed as dependencies via pip. -2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}` +2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_connector_extra_config":{"backend":"UCX"}}` 3. **Client Orchestration**: Use the client-side script below to coordinate prefill/decode operations. We are actively working on routing solutions. 
From a8377a7d7d001a5c42cc34df7d309675df9ca68f Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Fri, 19 Sep 2025 01:54:32 +0000 Subject: [PATCH 4/9] update backend to backends Signed-off-by: Chendi Xue --- docs/features/disagg_prefill.md | 8 +++++++- docs/serving/expert_parallel_deployment.md | 2 +- .../kv_transfer/kv_connector/v1/nixl_connector.py | 8 ++++---- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/docs/features/disagg_prefill.md b/docs/features/disagg_prefill.md index fdcbeb4afe6d..cb62213cc7af 100644 --- a/docs/features/disagg_prefill.md +++ b/docs/features/disagg_prefill.md @@ -28,7 +28,13 @@ Now supports 5 types of connectors: - **MultiConnector**: take advantage of the kv_connector_extra_config: dict[str, Any] already present in KVTransferConfig to stash all the connectors we want in an ordered list of kwargs.such as: ```bash - --kv-transfer-config '{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both"},{"kv_connector":"SharedStorageConnector","kv_role":"kv_both","kv_connector_extra_config":{"shared_storage_path":"local_storage", "backend":"UCX"}}]}}' + --kv-transfer-config '{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both"},{"kv_connector":"SharedStorageConnector","kv_role":"kv_both","kv_connector_extra_config":{"shared_storage_path":"local_storage"}}]}}' + ``` + +For NixlConnector, you may also specify one or multiple NIXL_Backend. 
Such as: + + ```bash + --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_buffer_device":"cuda", "kv_connector_extra_config":{"backends":["UCX", "GDS"]}}' + ``` - **OffloadingConnector**: enable offloading of KV data to CPU memory, customizing the CPU block size (in tokens) and number of blocks to allocate (per worker): diff --git a/docs/serving/expert_parallel_deployment.md b/docs/serving/expert_parallel_deployment.md index 151467276d63..f823d33df80e 100644 --- a/docs/serving/expert_parallel_deployment.md +++ b/docs/serving/expert_parallel_deployment.md @@ -193,7 +193,7 @@ For production deployments requiring strict SLA guarantees for time-to-first-tok 1. **Install gdrcopy/ucx/nixl**: For maximum performance, run the [install_gdrcopy.sh](gh-file:tools/install_gdrcopy.sh) script to install `gdrcopy` (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). If `gdrcopy` is not installed, things will still work with a plain `pip install nixl`, just with lower performance. `nixl` and `ucx` are installed as dependencies via pip. -2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_connector_extra_config":{"backend":"UCX"}}` +2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}`. Note: you may also specify one or multiple NIXL backends. Such as: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_connector_extra_config":{"backends":["UCX", "GDS"]}}'` 3. **Client Orchestration**: Use the client-side script below to coordinate prefill/decode operations. We are actively working on routing solutions. 
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 06f35e864d4d..2e7b3a144030 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -451,11 +451,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size - self.nixl_backend = \ + self.nixl_backends = \ vllm_config.kv_transfer_config.get_from_extra_config( - "backend", "UCX") + "backends", ["UCX"]) # Agent. - config = nixl_agent_config(backends=[self.nixl_backend]) + config = nixl_agent_config(backends=self.nixl_backends) self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config) # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) @@ -777,7 +777,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): descs = self.nixl_wrapper.get_reg_descs(caches_data, self.nixl_memory_type) logger.debug("Registering descs: %s", caches_data) - self.nixl_wrapper.register_memory(descs, backends=[self.nixl_backend]) + self.nixl_wrapper.register_memory(descs, backends=self.nixl_backends) logger.debug("Done registering descs") self._registered_descs.append(descs) From fe500025e81a294326a9ae191dc767546b5ca3f6 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Fri, 19 Sep 2025 03:10:06 +0000 Subject: [PATCH 5/9] Add UT to test kv_buffer_device to nixl_memory_type mapping Signed-off-by: Chendi Xue --- .../kv_connector/unit/test_nixl_connector.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 6e58d158c3f4..077e49586653 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ 
-27,6 +27,7 @@ KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata, NixlConnectorWorker, NixlKVConnectorStats) from vllm.forward_context import ForwardContext +from vllm.platforms.interface import Platform from vllm.sampling_params import SamplingParams from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput @@ -855,3 +856,52 @@ def test_register_kv_caches(dist_init): assert block_len == expected_block_len, \ f"Block entry {i}: Expected block len {expected_block_len}, " \ f"got {block_len}" + + +class FakePlatform(Platform): + device_type: str = "oot" + + @classmethod + def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]: + """ + Returns a mapping from device_type to a tuple of supported + kv_buffer_device for nixl. + """ + return {'oot': ('oot')} + + @classmethod + def get_nixl_memory_type(cls) -> Optional[str]: + """ + Returns the nixl memory type for the current platform. + """ + return 'VRAM' + + +@pytest.mark.parametrize("kv_buffer_device, nixl_memory_type", [ + ("oot", "VRAM"), +]) +def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device, + nixl_memory_type): + """ + Test that register_kv_caches() passes the correct memory types from the + config to the nixl_wrapper. 
+ """ + vllm_config = create_vllm_config() + # Override the default memory types in the config + vllm_config.kv_transfer_config.kv_buffer_device = kv_buffer_device + from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + _NIXL_SUPPORTED_DEVICE) + _NIXL_SUPPORTED_DEVICE.update(FakePlatform.get_nixl_supported_devices()) + + with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"), \ + patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \ + patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"), \ + patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform", FakePlatform), \ + patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector._NIXL_SUPPORTED_DEVICE", _NIXL_SUPPORTED_DEVICE): # noqa: E501 + + # Create connector and replace its worker with a fake one for isolation + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + + # Verify get_reg_descs was called with the correct memory_type + assert connector.connector_worker.kv_buffer_device == kv_buffer_device + assert connector.connector_worker.nixl_memory_type == nixl_memory_type From 7e467cc95fcf305a315212b545e541b7a7c17e4a Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Fri, 19 Sep 2025 03:33:55 +0000 Subject: [PATCH 6/9] Fix mypy Signed-off-by: Chendi Xue --- tests/v1/kv_connector/unit/test_nixl_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 077e49586653..2c74f317bb05 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -867,7 +867,7 @@ def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]: Returns a mapping from device_type to a tuple of supported kv_buffer_device for nixl. 
""" - return {'oot': ('oot')} + return {'oot': ('oot', )} @classmethod def get_nixl_memory_type(cls) -> Optional[str]: From 77147bc0a2ce07b7692d8fd888e7d8a06d71fa05 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Mon, 22 Sep 2025 19:54:42 +0000 Subject: [PATCH 7/9] retrigger CI Signed-off-by: Chendi Xue From 5a78f4888d1e80c46e927701908fd49aff7fcc64 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Mon, 22 Sep 2025 22:55:34 +0000 Subject: [PATCH 8/9] fix CI for missing of nixl_agent_config Signed-off-by: Chendi Xue --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 2e7b3a144030..82b483447e33 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -53,12 +53,17 @@ # Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used try: from nixl._api import nixl_agent as NixlWrapper - from nixl._api import nixl_agent_config logger.info("NIXL is available") except ImportError: logger.warning("NIXL is not available") NixlWrapper = None +try: + from nixl._api import nixl_agent_config +except ImportError: + nixl_agent_config = None + logger.warning("NIXL agent config is not available") + # Supported platforms and types of kv transfer buffer. # {device: tuple of supported kv buffer types} _NIXL_SUPPORTED_DEVICE = { @@ -455,7 +460,10 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): vllm_config.kv_transfer_config.get_from_extra_config( "backends", ["UCX"]) # Agent. 
- config = nixl_agent_config(backends=self.nixl_backends) + non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"] + config = nixl_agent_config(backends=self.nixl_backends) if len( + non_ucx_backends) > 0 and nixl_agent_config is not None else None + self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config) # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) From 5b0745c1ddeba08e61ce34270f65a871a1a84378 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Tue, 23 Sep 2025 02:03:07 +0000 Subject: [PATCH 9/9] Fix CI Signed-off-by: Chendi Xue --- tests/v1/kv_connector/unit/test_nixl_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 2c74f317bb05..fa698a2eabd9 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -57,7 +57,7 @@ def __init__(self, agent_name: str, *args, **kwargs): def get_reg_descs(self, caches_data, memory_type: str) -> list: return [str(uuid.uuid4()) for _ in caches_data] - def register_memory(self, descs) -> None: + def register_memory(self, descs, backends) -> None: pass def get_xfer_descs(self, blocks_data, memory_type: str) -> list: