122 changes: 122 additions & 0 deletions tests/kv_transfer/disagg_test_pp.py
@@ -0,0 +1,122 @@
import os
import subprocess
import sys
import time
from subprocess import Popen

import pytest
import requests
import torch


def kill_proc(proc):
proc.terminate()
try:
proc.wait(8)
except subprocess.TimeoutExpired:
proc.kill()


# Fixture to set up environment variables and tear down servers after tests
@pytest.fixture(scope="module", autouse=True)
def setup_servers():
if torch.cuda.device_count() < 4:
pytest.skip("Skipping test: fewer than 4 GPUs available")

# Set up environment variables
VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'",
shell=True).decode().strip()
os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP

    # Start prefill instance, testing pipeline parallelism = 2
prefill_cmd = [
sys.executable,
"-m",
"vllm.entrypoints.openai.api_server",
"--model",
"meta-llama/Meta-Llama-3.1-8B-Instruct",
"--pipeline-parallel-size",
"2",
"--port",
"8100",
"--gpu-memory-utilization",
"0.5",
"--max-model-len",
"1000",
"--kv-transfer-config",
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer",'\
'"kv_rank":0,"kv_parallel_size":2}',
]
prefill_env = os.environ.copy()
prefill_env["CUDA_VISIBLE_DEVICES"] = "0,1"
prefill_proc = Popen(prefill_cmd, env=prefill_env)

    # Start decode instance, testing pipeline parallelism = 2
decode_cmd = [
sys.executable,
"-m",
"vllm.entrypoints.openai.api_server",
"--model",
"meta-llama/Meta-Llama-3.1-8B-Instruct",
"--pipeline-parallel-size",
"2",
"--port",
"8200",
"--gpu-memory-utilization",
"0.5",
"--max-model-len",
"1000",
"--kv-transfer-config",
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer",'\
'"kv_rank":1,"kv_parallel_size":2}',
]
decode_env = os.environ.copy()
decode_env["CUDA_VISIBLE_DEVICES"] = "2,3"
decode_proc = Popen(decode_cmd, env=decode_env)

# Wait for servers to be ready
assert wait_for_server(8100), "Prefill server did not start in time"
assert wait_for_server(8200), "Decode server did not start in time"

# Yield to the test function and handle teardown after tests
yield
kill_proc(prefill_proc)
kill_proc(decode_proc)


# Helper function to wait for server
def wait_for_server(port, timeout=240):
    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            response = requests.get(f"http://localhost:{port}/v1/completions")
            # A ready server answers GET on this POST-only route with 405.
            if response.status_code in [200, 405]:
                return True
        except requests.ConnectionError:
            pass
        time.sleep(1)
    return False


# Test function to send completion requests and validate responses
@pytest.mark.parametrize("prompt", ["San Francisco is a", "Santa Clara is a"])
def test_disaggregated_prefilling(prompt):
# Send to prefill
response = requests.post("http://localhost:8100/v1/completions",
headers={"Content-Type": "application/json"},
json={
"model":
"meta-llama/Meta-Llama-3.1-8B-Instruct",
"prompt": prompt,
"max_tokens": 1,
"temperature": 0
})
assert response.status_code == 200

# Send to decode
response = requests.post("http://localhost:8200/v1/completions",
headers={"Content-Type": "application/json"},
json={
"model":
"meta-llama/Meta-Llama-3.1-8B-Instruct",
"prompt": prompt,
"max_tokens": 10,
"temperature": 0
})
assert response.status_code == 200
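
# Note: the flow this test exercises can also be driven by hand. The prompt is
# first sent to the prefill instance on port 8100 with max_tokens=1 so that it
# computes and publishes the KV cache (and, under pipeline parallelism, the
# intermediate hidden/residual tensors); the same prompt is then sent to the
# decode instance on port 8200, which consumes the transferred cache and
# generates the completion. Below is a minimal sketch of that two-step client,
# with the ports and model name taken from the test above; the disagg_generate
# helper is illustrative only and not part of this PR.
import requests

MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"


def disagg_generate(prompt: str, max_tokens: int = 10) -> str:
    # Step 1: prefill instance computes the prompt's KV cache and hands it to
    # the KV connector (max_tokens=1 keeps decode work on this instance minimal).
    prefill = requests.post("http://localhost:8100/v1/completions",
                            json={
                                "model": MODEL,
                                "prompt": prompt,
                                "max_tokens": 1,
                                "temperature": 0
                            })
    prefill.raise_for_status()

    # Step 2: decode instance pulls the transferred KV cache and continues
    # generation without recomputing the prefill.
    decode = requests.post("http://localhost:8200/v1/completions",
                           json={
                               "model": MODEL,
                               "prompt": prompt,
                               "max_tokens": max_tokens,
                               "temperature": 0
                           })
    decode.raise_for_status()
    return decode.json()["choices"][0]["text"]


print(disagg_generate("San Francisco is a"))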
64 changes: 48 additions & 16 deletions vllm/distributed/kv_transfer/kv_connector/simple_connector.py
@@ -137,14 +137,18 @@ def select(self, input_tokens: Optional[torch.Tensor],
"consumer buffer before calling select."
return self.consumer_buffer.drop_select(input_tokens, roi)

def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
def insert(self,
input_tokens: torch.Tensor,
roi: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
hidden: torch.Tensor,
residual: torch.Tensor = None) -> None:

assert self.producer_buffer is not None, "Please initialize the "\
"producer buffer before calling insert."

self.producer_buffer.insert(input_tokens, roi, key, value, hidden)
self.producer_buffer.insert(input_tokens, roi, key, value, hidden,
residual)

def send_kv_caches_and_hidden_states(
self,
@@ -190,12 +194,20 @@ def send_kv_caches_and_hidden_states(

keys = torch.cat(keys, dim=0)
values = torch.cat(values, dim=0)

self.insert(current_tokens,
torch.ones_like(current_tokens,
dtype=bool), keys, values,
hidden_or_intermediate_states[start_pos:end_pos])

if isinstance(hidden_or_intermediate_states, torch.Tensor):
self.insert(current_tokens,
torch.ones_like(current_tokens,
dtype=bool), keys, values,
hidden_or_intermediate_states[start_pos:end_pos])
elif isinstance(hidden_or_intermediate_states,
IntermediateTensors):
self.insert(
current_tokens, torch.ones_like(current_tokens,
dtype=bool), keys, values,
hidden_or_intermediate_states["hidden_states"]
[start_pos:end_pos],
hidden_or_intermediate_states["residual"]
[start_pos:end_pos])
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())

def recv_kv_caches_and_hidden_states(
@@ -216,7 +228,7 @@
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()

hidden_or_intermediate_states_for_one_req = []

pp_intermediate_tensors = IntermediateTensors({})
input_tokens_list = []
num_computed_tokens_list = []
start_pos_list = []
Expand Down Expand Up @@ -246,6 +258,9 @@ def recv_kv_caches_and_hidden_states(
keys: torch.Tensor = ret[2]
values: torch.Tensor = ret[3]
hidden: torch.Tensor = ret[4]
residual: torch.Tensor = None
if len(ret) > 5:
residual = ret[5]

num_computed_tokens = roi.shape[0]
num_computed_tokens_list.append(num_computed_tokens)
Expand Down Expand Up @@ -279,8 +294,22 @@ def recv_kv_caches_and_hidden_states(
layer.self_attn.attn._k_scale,
layer.self_attn.attn._v_scale,
)

hidden_or_intermediate_states_for_one_req.append(hidden)
if residual is None:
hidden_or_intermediate_states_for_one_req.append(hidden)
else:
if "hidden_states" not in pp_intermediate_tensors.tensors:
pp_intermediate_tensors.tensors["hidden_states"] = hidden
pp_intermediate_tensors.tensors["residual"] = residual
else:
pp_intermediate_tensors.tensors[
"hidden_states"] = torch.cat(
(pp_intermediate_tensors.tensors["hidden_states"],
hidden),
dim=1)
pp_intermediate_tensors.tensors["residual"] = torch.cat(
(pp_intermediate_tensors.tensors["residual"],
residual),
dim=1)

if not bypass_model_exec:
# Some of the KV cache is not retrieved
@@ -296,8 +325,11 @@
logger.debug(
"[rank%d]: Successfully received all KVs and hidden "
"states, skip model forwarding.", torch.distributed.get_rank())
hidden_or_intermediate_states = torch.cat(
hidden_or_intermediate_states_for_one_req, dim=0)
if residual is None:
hidden_or_intermediate_states = torch.cat(
hidden_or_intermediate_states_for_one_req, dim=0)
else:
hidden_or_intermediate_states = pp_intermediate_tensors

return hidden_or_intermediate_states, bypass_model_exec, model_input

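# In short, send_kv_caches_and_hidden_states now branches on the type of the
# model output: the last pipeline-parallel rank still produces a plain
# hidden-states Tensor, while intermediate ranks produce IntermediateTensors
# that also carry a residual, and both the hidden_states and residual slices
# for the request are forwarded to the buffer. Below is a toy sketch of that
# dispatch; fake_insert stands in for SimpleConnector.insert, the tensor shapes
# are arbitrary, and the import path for IntermediateTensors is assumed.
import torch

from vllm.sequence import IntermediateTensors  # assumed import path


def fake_insert(tokens, roi, key, value, hidden, residual=None):
    # Stand-in for SimpleConnector.insert: just report what would be sent.
    print("hidden", tuple(hidden.shape), "residual",
          None if residual is None else tuple(residual.shape))


tokens = torch.arange(6)
roi = torch.ones_like(tokens, dtype=bool)
key = value = torch.zeros(6, 8)

# Last PP rank: model output is a plain Tensor of hidden states.
output = torch.zeros(6, 16)
if isinstance(output, torch.Tensor):
    fake_insert(tokens, roi, key, value, output)

# Intermediate PP rank: model output is IntermediateTensors with a residual.
output = IntermediateTensors({
    "hidden_states": torch.zeros(6, 16),
    "residual": torch.zeros(6, 16),
})
if isinstance(output, IntermediateTensors):
    fake_insert(tokens, roi, key, value, output["hidden_states"],
                output["residual"])
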
33 changes: 24 additions & 9 deletions vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py
@@ -15,6 +15,7 @@

import torch

from vllm.distributed import get_pp_group
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import (
KVLookupBufferBase)
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
@@ -99,9 +100,13 @@ def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]):

raise AssertionError(f"Unknown data type {type(data)}")

def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor):
def _add_to_buffer(self,
input_tokens: torch.Tensor,
roi: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
hidden: torch.Tensor,
residual: torch.Tensor = None):

if isinstance(input_tokens, torch.Tensor):
input_tokens = input_tokens.clone()
@@ -113,8 +118,11 @@ def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor,
value = value.clone()
if isinstance(hidden, torch.Tensor):
hidden = hidden.clone()

buffer_item = [input_tokens, roi, key, value, hidden]
if residual is not None:
residual = residual.clone()
buffer_item = [input_tokens, roi, key, value, hidden, residual]
else:
buffer_item = [input_tokens, roi, key, value, hidden]

with self.buffer_lock:
for data in buffer_item:
@@ -204,15 +212,22 @@ def drop_select(
key = self.data_pipe.recv_tensor()
value = self.data_pipe.recv_tensor()
hidden = self.data_pipe.recv_tensor()
if get_pp_group().world_size > 1 and not get_pp_group().is_last_rank:
residual = self.data_pipe.recv_tensor()
return [input_tokens, roi, key, value, hidden, residual]

return [input_tokens, roi, key, value, hidden]

def full_handler(self):
time.sleep(0.001)

def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
def insert(self,
input_tokens: torch.Tensor,
roi: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
hidden: torch.Tensor,
residual: torch.Tensor = None) -> None:

if self.buffer_size > self.buffer_size_threshold:
# log outside the while loop to avoid this message being logged
@@ -221,7 +236,7 @@ def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
while self.buffer_size > self.buffer_size_threshold:
self.full_handler()

self._add_to_buffer(input_tokens, roi, key, value, hidden)
self._add_to_buffer(input_tokens, roi, key, value, hidden, residual)

# when calling the insert, the current process is a sender
# need to launch the request handler and start listening to request.
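# The receive side mirrors the insert change: on a non-last pipeline-parallel
# rank, drop_select pulls one extra tensor off the data pipe, so the returned
# list has six entries instead of five, and SimpleConnector checks
# len(ret) > 5 before reading the residual. A small sketch of unpacking either
# shape follows; unpack and the ret list are illustrative, not the buffer API.
def unpack(ret):
    # Entries are normally torch.Tensors (or None when the lookup misses).
    input_tokens, roi, key, value, hidden = ret[:5]
    residual = ret[5] if len(ret) > 5 else None
    return input_tokens, roi, key, value, hidden, residual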