111 changes: 111 additions & 0 deletions examples/online_serving/disaggregated_prefill_v1.sh
@@ -0,0 +1,111 @@
#!/bin/bash
# This file demonstrates example usage of disaggregated prefill.
# We launch two vLLM instances (one for prefill and one for decode),
# and then transfer the KV cache between them.

set -xe

echo "🚧🚧 Warning: The usage of disaggregated prefill is experimental and subject to change 🚧🚧"
sleep 1

# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'cleanup' INT

# Cleanup function
cleanup() {
echo "Caught Ctrl+C, cleaning up..."
# Cleanup commands
pgrep python | xargs kill -9
pkill -f python
echo "Cleanup complete. Exiting."
exit 0
}

export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')

# install quart first -- required for the disaggregated prefill proxy server
if python3 -c "import quart" &> /dev/null; then
echo "Quart is already installed."
else
echo "Quart is not installed. Installing..."
python3 -m pip install quart
fi

# a function that waits for a vLLM server to start
wait_for_server() {
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}


# You can also adjust --kv-ip and --kv-port for distributed inference.
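# For example (a hypothetical sketch -- the exact field names and defaults may
# differ), a two-node setup could add the producer's address to the config:
#   --kv-transfer-config \
#   '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,
#     "kv_parallel_size":2,"kv_ip":"192.0.2.10","kv_port":14579}'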

# prefilling instance, which is the KV producer
CUDA_VISIBLE_DEVICES=0 VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \
--port 8100 \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' &

# decoding instance, which is the KV consumer
CUDA_VISIBLE_DEVICES=1 VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \
--port 8200 \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' &

# wait until prefill and decode instances are ready
wait_for_server 8100
wait_for_server 8200

# launch a proxy server that exposes the service at port 8000
# the workflow of this proxy:
# - send the request to the prefill vLLM instance (port 8100), with max_tokens
#   changed to 1
# - after the prefill instance finishes prefill, send the request to the
#   decode vLLM instance (port 8200)
# (a minimal sketch of such a proxy is shown after this script)
# NOTE: the usage of this API is subject to change --- in the future we will
# introduce "vllm connect" to connect between prefill and decode instances
python3 ../../benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py &
sleep 1

# send two example requests through the proxy
output1=$(curl -X POST -s http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"prompt": "San Francisco is a",
"max_tokens": 10,
"temperature": 0
}')

output2=$(curl -X POST -s http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"prompt": "Santa Clara is a",
"max_tokens": 10,
"temperature": 0
}')


# Cleanup commands
pgrep python | xargs kill -9
pkill -f python

echo ""

sleep 1

# Print the outputs of the curl requests
echo ""
echo "Output of first request: $output1"
echo "Output of second request: $output2"

echo "🎉🎉 Successfully finished 2 test requests! 🎉🎉"
echo ""
120 changes: 120 additions & 0 deletions tests/kv_transfer/disagg_test_v1.py
@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0

import os
import subprocess
import sys
import time
from subprocess import Popen

import pytest
import requests
import torch


# Fixture to set up environment variables and teardown servers after tests
@pytest.fixture(scope="module", autouse=True)
def setup_servers():
if torch.cuda.device_count() < 4:
pytest.skip("Skipping test: fewer than 4 GPUs available")

# Set up environment variables
VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'",
shell=True).decode().strip()
os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP
os.environ["VLLM_USE_V1"] = "1"
# Start prefill instance
prefill_cmd = [
sys.executable,
"-m",
"vllm.entrypoints.openai.api_server",
"--model",
"meta-llama/Llama-3.2-1B-Instruct",
"--port",
"8100",
"--gpu-memory-utilization",
"0.5",
"--max-model-len",
"1000",
"--kv-transfer-config",
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer",'\
'"kv_rank":0,"kv_parallel_size":2}',
]
prefill_env = os.environ.copy()
prefill_env["CUDA_VISIBLE_DEVICES"] = "0"
prefill_proc = Popen(prefill_cmd, env=prefill_env)

os.environ["VLLM_USE_V1"] = "1"
# Start decode instance
decode_cmd = [
sys.executable,
"-m",
"vllm.entrypoints.openai.api_server",
"--model",
"meta-llama/Llama-3.2-1B-Instruct",
"--port",
"8200",
"--gpu-memory-utilization",
"0.5",
"--max-model-len",
"1000",
"--kv-transfer-config",
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer",'\
'"kv_rank":1,"kv_parallel_size":2}',
]
decode_env = os.environ.copy()
decode_env["CUDA_VISIBLE_DEVICES"] = "1"
decode_proc = Popen(decode_cmd, env=decode_env)

# Wait for servers to be ready
assert wait_for_server(8100), "Prefill server did not start in time"
assert wait_for_server(8200), "Decode server did not start in time"

# Yield to the test function and handle teardown after tests
yield

# Cleanup: kill the processes
prefill_proc.terminate()
decode_proc.terminate()

# Additional cleanup if needed
prefill_proc.wait()
decode_proc.wait()


# Helper function to wait for server
def wait_for_server(port, timeout=240):
start_time = time.time()
while time.time() - start_time < timeout:
try:
response = requests.get(f"http://localhost:{port}/v1/completions")
if response.status_code in [200, 405]:
return True
except requests.ConnectionError:
time.sleep(1)
return False


# Test function that sends completion requests and validates responses
@pytest.mark.parametrize("prompt", ["San Francisco is a", "Santa Clara is a"])
def test_disaggregated_prefilling(prompt):
# Send to prefill
response = requests.post("http://localhost:8100/v1/completions",
headers={"Content-Type": "application/json"},
json={
"model": "meta-llama/Llama-3.2-1B-Instruct",
"prompt": prompt,
"max_tokens": 1,
"temperature": 0
})
assert response.status_code == 200

# Send to decode
response = requests.post("http://localhost:8200/v1/completions",
headers={"Content-Type": "application/json"},
json={
"model": "meta-llama/Llama-3.2-1B-Instruct",
"prompt": prompt,
"max_tokens": 10,
"temperature": 0
})
assert response.status_code == 200
13 changes: 8 additions & 5 deletions vllm/distributed/parallel_state.py
@@ -37,8 +37,13 @@
import torch.distributed
from torch.distributed import Backend, ProcessGroup

import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
import vllm.envs as envs

if envs.VLLM_USE_V1:
import vllm.v1.distributed.kv_transfer.kv_transfer_agent as kv_transfer
else:
import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer

from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase)
from vllm.distributed.utils import StatelessProcessGroup
@@ -918,10 +923,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
if vllm_config.kv_transfer_config is None:
return

if all([
vllm_config.kv_transfer_config.need_kv_parallel_group, _KV_TRANSFER
is None
]):
if all([vllm_config.kv_transfer_config.kv_connector, _KV_TRANSFER
is None]):
_KV_TRANSFER = kv_transfer.KVTransferAgent(
rank=get_world_group().rank,
local_rank=get_world_group().local_rank,
124 changes: 124 additions & 0 deletions vllm/v1/distributed/kv_transfer/kv_connector/base.py
@@ -0,0 +1,124 @@
# SPDX-License-Identifier: Apache-2.0
"""
KVConnectorBase Class for Distributed KV Cache & Hidden State communication

The class provides two primary abstract methods:
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
"""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Tuple, Union

import torch

from vllm.sequence import IntermediateTensors

if TYPE_CHECKING:
from vllm.config import VllmConfig


class KVConnectorBase(ABC):
"""
Abstract base class for a KV connector.

The class provides two primary abstract methods:
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
"""

@abstractmethod
def __init__(
self,
rank: int,
local_rank: int,
config: "VllmConfig",
):
raise NotImplementedError

@abstractmethod
def close(self) -> None:
"""Close the buffer and release resources.

This method is responsible for cleaning up resources related to the
connector when it is no longer needed.

Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError

@abstractmethod
def send_kv_caches_and_hidden_states(
self,
model: torch.nn.Module,
input_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
attn_metadata,
) -> None:
"""
Send KV caches and hidden states to the connector.

This method processes the input tokens, KV caches, and
hidden/intermediate states for a given model and sends the data to the
decode instance.

Args:
model (torch.nn.Module): The model whose KV caches and hidden
states are being sent.
input_ids (torch.Tensor): The input token IDs of the request.
kv_caches (List[torch.Tensor]): List of KV caches (keys and values)
for each layer.
hidden_or_intermediate_states (Union[torch.Tensor,
IntermediateTensors]):
The hidden or intermediate states associated with the tokens.
attn_metadata: The attention metadata for the batch.

Returns:
None

"""

raise NotImplementedError

@abstractmethod
def recv_kv_caches_and_hidden_states(
self,
model: torch.nn.Module,
input_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata,
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool]:
"""
Receive KV caches and hidden states from the connector.

This method attempts to retrieve KV caches and hidden states for input
tokens. If all required KV caches and hidden states are received, it
will bypass model input, else it will fall back to normal vLLM model
forwarding.

Args:
model (torch.nn.Module): The model whose KV caches and hidden
states are being received.
input_ids (torch.Tensor): The input token IDs of the request.
kv_caches (List[torch.Tensor]):
List of KV caches for each layer.
attn_metadata: The attention metadata for the batch.

Returns:
- hidden_or_intermediate_states (torch.Tensor or
IntermediateTensors):
Concatenated hidden states if all required data is retrieved,
otherwise `None`.
- bypass_model_exec (bool):
Indicates whether the model execution can be skipped (True) or
needs to be redone (False).

"""

raise NotImplementedError
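
For orientation, a hypothetical subclass of the KVConnectorBase defined above might look like the sketch below. It is illustrative only (a no-op connector that never transfers anything), assumes the import path of this new file, and is not part of this PR:

# Hypothetical illustration only -- not part of this PR.
from typing import TYPE_CHECKING, List, Tuple, Union

import torch

from vllm.sequence import IntermediateTensors
from vllm.v1.distributed.kv_transfer.kv_connector.base import KVConnectorBase

if TYPE_CHECKING:
    from vllm.config import VllmConfig


class NoOpConnector(KVConnectorBase):
    """A connector that transfers nothing; the decode side always falls
    back to normal model execution."""

    def __init__(self, rank: int, local_rank: int, config: "VllmConfig"):
        self.rank = rank
        self.local_rank = local_rank
        self.config = config

    def close(self) -> None:
        # Nothing to release in this sketch.
        pass

    def send_kv_caches_and_hidden_states(
        self,
        model: torch.nn.Module,
        input_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
        attn_metadata,
    ) -> None:
        # A real connector would push kv_caches and
        # hidden_or_intermediate_states to the decode instance here
        # (e.g. over NCCL for the PyNcclConnector).
        pass

    def recv_kv_caches_and_hidden_states(
        self,
        model: torch.nn.Module,
        input_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata,
    ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool]:
        # Nothing was received, so tell the caller to run the model normally.
        return None, False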