diff --git a/examples/online_serving/disaggregated_prefill_v1.sh b/examples/online_serving/disaggregated_prefill_v1.sh new file mode 100644 index 000000000000..0b21a6c5eb8c --- /dev/null +++ b/examples/online_serving/disaggregated_prefill_v1.sh @@ -0,0 +1,111 @@ +#!/bin/bash +# This file demonstrates the example usage of disaggregated prefilling +# We will launch 2 vllm instances (1 for prefill and 1 for decode), +# and then transfer the KV cache between them. + +set -xe + +echo "🚧🚧 Warning: The usage of disaggregated prefill is experimental and subject to change 🚧🚧" +sleep 1 + +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'cleanup' INT + +# Cleanup function +cleanup() { + echo "Caught Ctrl+C, cleaning up..." + # Cleanup commands + pgrep python | xargs kill -9 + pkill -f python + echo "Cleanup complete. Exiting." + exit 0 +} + +export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + +# install quart first -- required for disagg prefill proxy serve +if python3 -c "import quart" &> /dev/null; then + echo "Quart is already installed." +else + echo "Quart is not installed. Installing..." + python3 -m pip install quart +fi + +# a function that waits vLLM server to start +wait_for_server() { + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + + +# You can also adjust --kv-ip and --kv-port for distributed inference. + +# prefilling instance, which is the KV producer +CUDA_VISIBLE_DEVICES=0 VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \ + --port 8100 \ + --max-model-len 100 \ + --gpu-memory-utilization 0.8 \ + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' & + +# decoding instance, which is the KV consumer +CUDA_VISIBLE_DEVICES=1 VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \ + --port 8200 \ + --max-model-len 100 \ + --gpu-memory-utilization 0.8 \ + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' & + +# wait until prefill and decode instances are ready +wait_for_server 8100 +wait_for_server 8200 + +# launch a proxy server that opens the service at port 8000 +# the workflow of this proxy: +# - send the request to prefill vLLM instance (port 8100), change max_tokens +# to 1 +# - after the prefill vLLM finishes prefill, send the request to decode vLLM +# instance +# NOTE: the usage of this API is subject to change --- in the future we will +# introduce "vllm connect" to connect between prefill and decode instances +python3 ../../benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py & +sleep 1 + +# serve two example requests +output1=$(curl -X POST -s http://localhost:8000/v1/completions \ +-H "Content-Type: application/json" \ +-d '{ +"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", +"prompt": "San Francisco is a", +"max_tokens": 10, +"temperature": 0 +}') + +output2=$(curl -X POST -s http://localhost:8000/v1/completions \ +-H "Content-Type: application/json" \ +-d '{ +"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", +"prompt": "Santa Clara is a", +"max_tokens": 10, +"temperature": 0 +}') + + +# Cleanup commands +pgrep python | xargs kill -9 +pkill -f python + +echo "" + +sleep 1 + +# Print the outputs of the curl requests +echo "" +echo "Output of first request: $output1" +echo "Output of second request: $output2" + +echo "🎉🎉 Successfully finished 2 test requests! 
🎉🎉" +echo "" diff --git a/tests/kv_transfer/disagg_test_v1.py b/tests/kv_transfer/disagg_test_v1.py new file mode 100644 index 000000000000..d6db1651687f --- /dev/null +++ b/tests/kv_transfer/disagg_test_v1.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +import subprocess +import sys +import time +from subprocess import Popen + +import pytest +import requests +import torch + + +# Fixture to set up environment variables and teardown servers after tests +@pytest.fixture(scope="module", autouse=True) +def setup_servers(): + if torch.cuda.device_count() < 4: + pytest.skip("Skipping test: fewer than 4 GPUs available") + + # Set up environment variables + VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'", + shell=True).decode().strip() + os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP + os.environ["VLLM_USE_V1"] = "1" + # Start prefill instance + prefill_cmd = [ + sys.executable, + "-m", + "vllm.entrypoints.openai.api_server", + "--model", + "meta-llama/Llama-3.2-1B-Instruct", + "--port", + "8100", + "--gpu-memory-utilization", + "0.5", + "--max-model-len", + "1000", + "--kv-transfer-config", + '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer",'\ + '"kv_rank":0,"kv_parallel_size":2}', + ] + prefill_env = os.environ.copy() + prefill_env["CUDA_VISIBLE_DEVICES"] = "0" + prefill_proc = Popen(prefill_cmd, env=prefill_env) + + os.environ["VLLM_USE_V1"] = "1" + # Start decode instance + decode_cmd = [ + sys.executable, + "-m", + "vllm.entrypoints.openai.api_server", + "--model", + "meta-llama/Llama-3.2-1B-Instruct", + "--port", + "8200", + "--gpu-memory-utilization", + "0.5", + "--max-model-len", + "1000", + "--kv-transfer-config", + '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer",'\ + '"kv_rank":1,"kv_parallel_size":2}', + ] + decode_env = os.environ.copy() + decode_env["CUDA_VISIBLE_DEVICES"] = "1" + decode_proc = Popen(decode_cmd, env=decode_env) + + # Wait for servers to be ready + assert wait_for_server(8100), "Prefill server did not start in time" + assert wait_for_server(8200), "Decode server did not start in time" + + # Yield to the test function and handle teardown after tests + yield + + # Cleanup: kill the processes + prefill_proc.terminate() + decode_proc.terminate() + + # Additional cleanup if needed + prefill_proc.wait() + decode_proc.wait() + + +# Helper function to wait for server +def wait_for_server(port, timeout=240): + start_time = time.time() + while time.time() - start_time < timeout: + try: + response = requests.get(f"http://localhost:{port}/v1/completions") + if response.status_code in [200, 405]: + return True + except requests.ConnectionError: + time.sleep(1) + return False + + +# Test function to send curl requests and validate responses +@pytest.mark.parametrize("prompt", ["San Francisco is a", "Santa Clara is a"]) +def test_disaggregated_prefilling(prompt): + # Send to prefill + response = requests.post("http://localhost:8100/v1/completions", + headers={"Content-Type": "application/json"}, + json={ + "model": "meta-llama/Llama-3.2-1B-Instruct", + "prompt": prompt, + "max_tokens": 1, + "temperature": 0 + }) + assert response.status_code == 200 + + # Send to decode + response = requests.post("http://localhost:8200/v1/completions", + headers={"Content-Type": "application/json"}, + json={ + "model": "meta-llama/Llama-3.2-1B-Instruct", + "prompt": prompt, + "max_tokens": 10, + "temperature": 0 + }) + assert response.status_code == 200 diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py 
index 781f870a756c..cccfb4c97561 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -37,8 +37,13 @@ import torch.distributed from torch.distributed import Backend, ProcessGroup -import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer import vllm.envs as envs + +if envs.VLLM_USE_V1: + import vllm.v1.distributed.kv_transfer.kv_transfer_agent as kv_transfer +else: + import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer + from vllm.distributed.device_communicators.base_device_communicator import ( DeviceCommunicatorBase) from vllm.distributed.utils import StatelessProcessGroup @@ -918,10 +923,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: if vllm_config.kv_transfer_config is None: return - if all([ - vllm_config.kv_transfer_config.need_kv_parallel_group, _KV_TRANSFER - is None - ]): + if all([vllm_config.kv_transfer_config.kv_connector, _KV_TRANSFER + is None]): _KV_TRANSFER = kv_transfer.KVTransferAgent( rank=get_world_group().rank, local_rank=get_world_group().local_rank, diff --git a/vllm/v1/distributed/kv_transfer/kv_connector/__init__.py b/vllm/v1/distributed/kv_transfer/kv_connector/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/v1/distributed/kv_transfer/kv_connector/base.py b/vllm/v1/distributed/kv_transfer/kv_connector/base.py new file mode 100644 index 000000000000..aea272c0d430 --- /dev/null +++ b/vllm/v1/distributed/kv_transfer/kv_connector/base.py @@ -0,0 +1,124 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +KVConnectorBase Class for Distributed KV Cache & Hidden State communication + +The class provides two primary abstract methods: +1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states +2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, List, Tuple, Union + +import torch + +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.config import VllmConfig + + +class KVConnectorBase(ABC): + """ + Abstract base class for a KV connector. + + The class provides two primary abstract methods: + 1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states + 2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states + """ + + @abstractmethod + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ): + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Close the buffer and release resources. + + This method is responsible for cleaning up resources related to the + connector when it is no longer needed. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def send_kv_caches_and_hidden_states( + self, + model: torch.nn.Module, + input_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + attn_metadata, + ) -> None: + """ + Send KV caches and hidden states to the connector. + + This method processes the input tokens, KV caches, and + hidden/intermediate states for a given model and sends the data to the + decode instance. + + Args: + model_executable (torch.nn.Module): The model executable containing + start and end layer information. + model_input (ModelInputForGPUWithSamplingMetadata): The input + metadata from vLLM. 
+ kv_caches (List[torch.Tensor]): List of KV caches (keys and values) + for each layer. + hidden_or_intermediate_states (Union[torch.Tensor, + IntermediateTensors]): + The hidden or intermediate states associated with the tokens. + + Returns: + None + + """ + + raise NotImplementedError + + @abstractmethod + def recv_kv_caches_and_hidden_states( + self, + model: torch.nn.Module, + input_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata, + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool]: + """ + Receive KV caches and hidden states from the connector. + + This method attempts to retrieve KV caches and hidden states for input + tokens. If all required KV caches and hidden states are received, it + will bypass model input, else it will fall back to normal vLLM model + forwarding. + + Args: + model_executable (torch.nn.Module): + The model executable from vLLM modelrunner. + model_input (ModelInputForGPUWithSamplingMetadata): + The model input from vLLM modelrunner. + kv_caches (List[torch.Tensor]): + List of KV caches for each layer. + + Returns: + - hidden_or_intermediate_states (torch.Tensor or + IntermediateTensors): + Concatenated hidden states if all required data is retrieved, + otherwise `None`. + - bypass_model_exec (bool): + Indicates whether the model execution can be skipped (True) or + needs to be redone (False). + - model_input (ModelInputForGPUWithSamplingMetadata): + Optionally adjusted input metadata for re-execution when + `bypass_model_exec=False`. + + """ + + raise NotImplementedError diff --git a/vllm/v1/distributed/kv_transfer/kv_connector/factory.py b/vllm/v1/distributed/kv_transfer/kv_connector/factory.py new file mode 100644 index 000000000000..28012278e6e0 --- /dev/null +++ b/vllm/v1/distributed/kv_transfer/kv_connector/factory.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 + +import importlib +from typing import TYPE_CHECKING, Callable, Dict, Type + +from .base import KVConnectorBase + +if TYPE_CHECKING: + from vllm.config import VllmConfig + + +class KVConnectorFactory: + _registry: Dict[str, Callable[[], Type[KVConnectorBase]]] = {} + + @classmethod + def register_connector(cls, name: str, module_path: str, + class_name: str) -> None: + """Register a connector with a lazy-loading module and class name.""" + if name in cls._registry: + raise ValueError(f"Connector '{name}' is already registered.") + + def loader() -> Type[KVConnectorBase]: + module = importlib.import_module(module_path) + return getattr(module, class_name) + + cls._registry[name] = loader + + @classmethod + def create_connector(cls, rank: int, local_rank: int, + config: "VllmConfig") -> KVConnectorBase: + connector_name = config.kv_transfer_config.kv_connector + if connector_name not in cls._registry: + raise ValueError(f"Unsupported connector type: {connector_name}") + + connector_cls = cls._registry[connector_name]() + return connector_cls(rank, local_rank, config) + + +# Register various connectors here. +# The registration should not be done in each individual file, as we want to +# only load the files corresponding to the current connector. 
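+#
+# As a sketch, an out-of-tree connector could be registered the same way
+# before engine start-up (the module and class names below are hypothetical,
+# shown only to illustrate the register_connector() signature):
+#
+#   KVConnectorFactory.register_connector(
+#       "MyCustomConnector",
+#       "my_package.my_custom_connector",
+#       "MyCustomConnector")
+#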
+KVConnectorFactory.register_connector( + "PyNcclConnector", + "vllm.v1.distributed.kv_transfer.kv_connector.simple_connector", + "SimpleConnector") diff --git a/vllm/v1/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/v1/distributed/kv_transfer/kv_connector/simple_connector.py new file mode 100644 index 000000000000..af7dc1e94fdf --- /dev/null +++ b/vllm/v1/distributed/kv_transfer/kv_connector/simple_connector.py @@ -0,0 +1,289 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Simple KV Cache Connector for Distributed Machine Learning Inference + +The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache +producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or +MooncakePipe. + +But the logic can be extended to support other pipe and lookup buffer. +""" +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import torch + +from vllm import _custom_ops as ops +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( + SimpleBuffer) +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors +from vllm.v1.distributed.kv_transfer.kv_connector.base import KVConnectorBase + +if TYPE_CHECKING: + pass + +logger = init_logger(__name__) + + +class SimpleConnector(KVConnectorBase): + + def __init__( + self, + rank: int, + local_rank: int, + config: VllmConfig, + ): + + self.config = config.kv_transfer_config + self.tp_size = config.parallel_config.tensor_parallel_size + + if self.config.kv_connector == "PyNcclConnector": + from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import ( + PyNcclPipe) + logger.info( + "Initializing PyNcclConfig under kv_transfer_config %s", + self.config) + else: + raise NotImplementedError( + "Only PyNcclConnector is supported for now.") + self.lookup_buffer_size = self.config.kv_buffer_size + + self.producer_buffer: Optional[SimpleBuffer] = None + self.consumer_buffer: Optional[SimpleBuffer] = None + + self.producer_data_pipe: PyNcclPipe + self.consumer_data_pipe: PyNcclPipe + self.consumer_signal_pipe: PyNcclPipe + self.producer_signal_pipe: PyNcclPipe + + # 2 pipes for every rank in the world + port_offset_base = 2 * rank + + # In disaggregated prefill, the prefill vLLM only uses send pipe + # and the decode vLLM only uses recv pipe + if self.config.is_kv_producer: + + if self.config.kv_connector == "PyNcclConnector": + self.producer_data_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base, + ) + self.producer_signal_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base + 1, + device="cpu", + ) + else: + raise NotImplementedError( + "Only PyNcclConnector is supported for producer.") + self.producer_buffer = SimpleBuffer(self.producer_signal_pipe, + self.producer_data_pipe, + self.config.kv_buffer_size) + else: + + # the current vLLM instance is KV consumer, so it needs to connect + # its recv pipe to the send pipe of KV producder + if self.config.kv_connector == "PyNcclConnector": + self.consumer_data_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base, + ) + self.consumer_signal_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base + 1, + device="cpu", + ) + else: + raise NotImplementedError( + "Only PyNcclConnector is supported for consumer.") + + self.consumer_buffer = SimpleBuffer( + self.consumer_signal_pipe, + self.consumer_data_pipe, + 
self.config.kv_buffer_size, + ) + + def select(self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + + assert self.consumer_buffer is not None, "Please initialize the "\ + "consumer buffer before calling select." + return self.consumer_buffer.drop_select(input_tokens, roi) + + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + + assert self.producer_buffer is not None, "Please initialize the "\ + "producer buffer before calling insert." + + self.producer_buffer.insert(input_tokens, roi, key, value, hidden) + + def send_kv_caches_and_hidden_states( + self, + model: torch.nn.Module, + input_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + attn_metadata, + ) -> None: + + input_tokens_tensor = input_ids + seq_lens = attn_metadata.seq_lens + slot_mapping_flat = attn_metadata.slot_mapping.flatten() + start_layer = model.model.start_layer + end_layer = model.model.end_layer + + model_config = model.config + num_heads = int(model_config.num_key_value_heads / self.tp_size) + hidden_size = model_config.hidden_size + num_attention_heads = model_config.num_attention_heads + head_size = int(hidden_size / num_attention_heads) + # query_lens contains new KV caches that are added to vLLM. + # so we will send them to decode instance + # FIXME: This assume that all requests are prefill. + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + current_tokens = input_tokens_tensor[start_pos:end_pos] + + keys, values = [], [] + + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] + + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + + keys.append(key_cache[current_slot_mapping].unsqueeze(0)) + values.append(value_cache[current_slot_mapping].unsqueeze(0)) + + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) + + self.insert(current_tokens, + torch.ones_like(current_tokens, + dtype=bool), keys, values, + hidden_or_intermediate_states[start_pos:end_pos]) + + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + + def recv_kv_caches_and_hidden_states( + self, + model: torch.nn.Module, + input_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata, + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool]: + + # When bypass_model_exec is set to False, it means that at least for one + # request its corresponding KV cache or hidden state is missing. + # In this case we need to do prefilling to recompute missing KV cache + # and hidden states. + bypass_model_exec = True + + input_tokens_tensor = input_ids + seq_lens = attn_metadata.seq_lens + slot_mapping = attn_metadata.slot_mapping.flatten() + + hidden_or_intermediate_states_for_one_req = [] + + input_tokens_list = [] + num_computed_tokens_list = [] + start_pos_list = [] + + # enumerate different requests + # FIXME: This impl assumes that all requests are prefill. 
+ for idx, slen in enumerate(seq_lens): + + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + current_tokens = input_tokens_tensor[start_pos:end_pos] + num_tokens = slen + + # collecting data for rebuilding the input + input_tokens_list.append(current_tokens) + start_pos_list.append(start_pos) + + ret = self.select(current_tokens, + torch.ones_like(current_tokens, dtype=bool)) + if ret[0] is None: + # didn't find any match. + bypass_model_exec = False + num_computed_tokens_list.append(0) + continue + + roi: torch.Tensor = ret[1] + keys: torch.Tensor = ret[2] + values: torch.Tensor = ret[3] + hidden: torch.Tensor = ret[4] + + num_computed_tokens = roi.shape[0] + num_computed_tokens_list.append(num_computed_tokens) + + # check if both KV cache and the hidden states are received + # If not, need to redo the forwarding to compute missing states + if not all([(num_computed_tokens == num_tokens), hidden is not None + ]): + bypass_model_exec = False + + # update the end position based on how many tokens are cached. + end_pos = start_pos + num_computed_tokens + + # put received KV caches into paged memory + for i in range(model.model.start_layer, model.model.end_layer): + + kv_cache = kv_caches[i - model.model.start_layer] + layer = model.model.layers[i] + + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + keys[i - model.model.start_layer].to(key_cache.device), + values[i - model.model.start_layer].to(value_cache.device), + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) + + hidden_or_intermediate_states_for_one_req.append(hidden) + + if not bypass_model_exec: + # Some of the KV cache is not retrieved + # Here we will fall back to normal model forwarding + # But optionally you can adjust model_input so that you only do + # prefilling on those tokens that are missing KV caches. + logger.debug( + "[rank%d]: Failed to receive all KVs and hidden " + "states, redo model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = None + + else: + logger.debug( + "[rank%d]: Successfully received all KVs and hidden " + "states, skip model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + + return hidden_or_intermediate_states, bypass_model_exec + + def close(self): + self.producer_data_pipe.close() + self.consumer_data_pipe.close() + if self.config.kv_connector == "PyNcclConnector": + self.producer_signal_pipe.close() + self.consumer_signal_pipe.close() + else: + raise NotImplementedError( + "Only PyNcclConnector is supported for now.") + logger.info("SimpleConnector closed.") diff --git a/vllm/v1/distributed/kv_transfer/kv_transfer_agent.py b/vllm/v1/distributed/kv_transfer/kv_transfer_agent.py new file mode 100644 index 000000000000..7d0718677df6 --- /dev/null +++ b/vllm/v1/distributed/kv_transfer/kv_transfer_agent.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 +"""A centralized entrypoint to perform distributed KV cache transfer. + +This implementation is a shim wrapper on two APIs exposed by `kv_connector`: +1. `send_kv_caches_and_hidden_states` +2. 
`recv_kv_caches_and_hidden_states +""" +from typing import TYPE_CHECKING, List, Tuple, Union + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + from vllm.config import VllmConfig + +import torch + +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors +from vllm.v1.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) + +logger = init_logger(__name__) + + +class KVTransferAgent: + """ + A class designated for distributed KV transfer + + Target use cases: + 1. Disaggregated prefill + 2. Remote KV cache storage + """ + + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ): + + self.config = config + + if config.kv_transfer_config is None: + raise ValueError("KVTransferConfig is not set in the VllmConfig," + " cannot initialize KVConnector.") + + + assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\ + "TransferAgent should only be used when kv_connector is set." + self.connector = KVConnectorFactory.create_connector( + rank, local_rank, config) + + def send_kv_caches_and_hidden_states( + self, model: torch.nn.Module, + input_ids: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + attn_metadata) -> None: + self.connector.send_kv_caches_and_hidden_states( + model, input_ids, kv_caches, hidden_or_intermediate_states, + attn_metadata) + + def close(self) -> None: + self.connector.close() + + def recv_kv_caches_and_hidden_states( + self, + model: torch.nn.Module, + input_ids: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + attn_metadata, + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool]: + return self.connector.recv_kv_caches_and_hidden_states( + model, input_ids, kv_caches, attn_metadata) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f1212c3554b6..4685ef0d6ad1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -12,7 +12,8 @@ from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig -from vllm.distributed.parallel_state import get_pp_group, graph_capture +from vllm.distributed.parallel_state import (get_kv_transfer_group, + get_pp_group, graph_capture) from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger @@ -867,6 +868,50 @@ def _gather_encoder_outputs( def get_model(self) -> nn.Module: return self.model + def need_recv_kv(self, attn_metadata, kv_caches) -> bool: + """Check if we need to receive kv-cache from the other worker. + We need to receive KV when + 1. current vLLM instance is KV cache consumer/decode vLLM instance + 2. this batch is not a profiling run + 3. this batch is a prefill run + + Args: + model_input: input to the model executable + kv_caches: vLLM's paged memory + """ + if self.vllm_config.kv_transfer_config is None: + return False + + is_prefill_run = attn_metadata.num_actual_tokens > 1 + + # # check if the current run is profiling + is_profile_run = (kv_caches[0].numel() == 0) + return (self.vllm_config.kv_transfer_config.is_kv_consumer + and not is_profile_run) and is_prefill_run + + def need_send_kv(self, attn_metadata, kv_caches) -> bool: + """Check if we need to send kv-cache to the other worker. + We need to send KV when + 1. 
current vLLM instance is KV cache producer/prefill vLLM instance + 2. this batch is not a profiling run + 3. this batch is a prefill run + + Args: + model_input: input to the model executable + kv_caches: vLLM's paged memory + """ + + if self.vllm_config.kv_transfer_config is None: + return False + + # check if the current run is profiling + is_profile_run = (kv_caches[0].numel() == 0) + # check if the current run is prefill + is_prefill_run = attn_metadata.num_actual_tokens > 1 + + return (self.vllm_config.kv_transfer_config.is_kv_producer + and not is_profile_run) and is_prefill_run + @torch.inference_mode() def execute_model( self, @@ -929,22 +974,41 @@ def execute_model( k: v[:num_input_tokens] for k, v in self.intermediate_tensors.items() }) - + bypass_model_exec = False + if self.need_recv_kv(attn_metadata, self.kv_caches): + hidden_states, bypass_model_exec = \ + get_kv_transfer_group().recv_kv_caches_and_hidden_states( + self.model, + input_ids, + kv_caches=self.kv_caches, + attn_metadata=attn_metadata, + ) # Run the decoder. # Use persistent buffers for CUDA graphs. - with set_forward_context(attn_metadata, self.vllm_config): - hidden_states = self.model( - input_ids=input_ids, - positions=positions, - kv_caches=self.kv_caches, - attn_metadata=None, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) + if bypass_model_exec is False: + with set_forward_context(attn_metadata, self.vllm_config): + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=self.kv_caches, + attn_metadata=None, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) if not get_pp_group().is_last_rank: # For mid-pipeline stages, return the hidden states. return hidden_states - + if self.need_send_kv(attn_metadata, self.kv_caches): + get_kv_transfer_group().send_kv_caches_and_hidden_states( + # model is used to know which layer the current + # worker is working on, so that we can send KV for only those + # layers. + self.model, + input_ids, + self.kv_caches, + hidden_states, + attn_metadata, + ) hidden_states = hidden_states[:num_scheduled_tokens] sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 10154a752393..587116d3f2f8 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -11,7 +11,8 @@ import vllm.envs as envs from vllm.config import ParallelConfig, VllmConfig from vllm.device_allocator.cumem import CuMemAllocator -from vllm.distributed import (ensure_model_parallel_initialized, +from vllm.distributed import (ensure_kv_transfer_initialized, + ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) from vllm.logger import init_logger @@ -109,7 +110,8 @@ def init_device(self): raise RuntimeError( f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. - init_worker_distributed_environment(self.parallel_config, self.rank, + init_worker_distributed_environment(self.vllm_config, + self.parallel_config, self.rank, self.distributed_init_method, self.local_rank) # Set random seed. 
@@ -244,6 +246,7 @@ def check_health(self) -> None: def init_worker_distributed_environment( + vllm_config: VllmConfig, parallel_config: ParallelConfig, rank: int, distributed_init_method: Optional[str] = None, @@ -257,6 +260,7 @@ def init_worker_distributed_environment( ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) + ensure_kv_transfer_initialized(vllm_config) def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
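For reference, the proxy flow exercised by examples/online_serving/disaggregated_prefill_v1.sh can also be driven from Python. The sketch below mirrors the curl requests in that script (same endpoint, model name, and payload) and assumes the prefill/decode instances and the proxy on port 8000 are already running.

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    headers={"Content-Type": "application/json"},
    json={
        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "prompt": "San Francisco is a",
        "max_tokens": 10,
        "temperature": 0,
    },
    timeout=60,
)
# The proxy first sends the request to the prefill instance with max_tokens
# forced to 1, then forwards it to the decode instance for the full completion.
print(resp.json())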