diff --git a/tests/kv_transfer/disagg_test_pp.py b/tests/kv_transfer/disagg_test_pp.py
new file mode 100644
index 000000000000..b2f2e79614cb
--- /dev/null
+++ b/tests/kv_transfer/disagg_test_pp.py
@@ -0,0 +1,122 @@
+import os
+import subprocess
+import sys
+import time
+from subprocess import Popen
+
+import pytest
+import requests
+import torch
+def kill_proc(proc):
+    proc.terminate()
+    try:
+        proc.wait(8)
+    except subprocess.TimeoutExpired:
+        proc.kill()
+
+# Fixture to set up environment variables and teardown servers after tests
+@pytest.fixture(scope="module", autouse=True)
+def setup_servers():
+    if torch.cuda.device_count() < 4:
+        pytest.skip("Skipping test: fewer than 4 GPUs available")
+
+    # Set up environment variables
+    VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'",
+                                           shell=True).decode().strip()
+    os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP
+
+    # Start prefill instance, testing pipeline parallelism = 2
+    prefill_cmd = [
+        sys.executable,
+        "-m",
+        "vllm.entrypoints.openai.api_server",
+        "--model",
+        "meta-llama/Meta-Llama-3.1-8B-Instruct",
+        "--pipeline-parallel-size",
+        "2",
+        "--port",
+        "8100",
+        "--gpu-memory-utilization",
+        "0.5",
+        "--max-model-len",
+        "1000",
+        "--kv-transfer-config",
+        '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer",'\
+        '"kv_rank":0,"kv_parallel_size":2}',
+    ]
+    prefill_env = os.environ.copy()
+    prefill_env["CUDA_VISIBLE_DEVICES"] = "0,1"
+    prefill_proc = Popen(prefill_cmd, env=prefill_env)
+
+    # Start decode instance, testing pipeline parallelism = 2
+    decode_cmd = [
+        sys.executable,
+        "-m",
+        "vllm.entrypoints.openai.api_server",
+        "--model",
+        "meta-llama/Meta-Llama-3.1-8B-Instruct",
+        "--pipeline-parallel-size",
+        "2",
+        "--port",
+        "8200",
+        "--gpu-memory-utilization",
+        "0.5",
+        "--max-model-len",
+        "1000",
+        "--kv-transfer-config",
+        '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer",'\
+        '"kv_rank":1,"kv_parallel_size":2}',
+    ]
+    decode_env = os.environ.copy()
+    decode_env["CUDA_VISIBLE_DEVICES"] = "2,3"
+    decode_proc = Popen(decode_cmd, env=decode_env)
+
+    # Wait for servers to be ready
+    assert wait_for_server(8100), "Prefill server did not start in time"
+    assert wait_for_server(8200), "Decode server did not start in time"
+
+    # Yield to the test function and handle teardown after tests
+    yield
+    kill_proc(prefill_proc)
+    kill_proc(decode_proc)
+
+
+# Helper function to wait for server
+def wait_for_server(port, timeout=240):
+    start_time = time.time()
+    while time.time() - start_time < timeout:
+        try:
+            response = requests.get(f"http://localhost:{port}/v1/completions")
+            if response.status_code in [200, 405]:
+                return True
+        except requests.ConnectionError:
+            time.sleep(1)
+    return False
+
+
+# Test function to send curl requests and validate responses
+@pytest.mark.parametrize("prompt", ["San Francisco is a", "Santa Clara is a"])
+def test_disaggregated_prefilling(prompt):
+    # Send to prefill
+    response = requests.post("http://localhost:8100/v1/completions",
+                             headers={"Content-Type": "application/json"},
+                             json={
+                                 "model":
+                                 "meta-llama/Meta-Llama-3.1-8B-Instruct",
+                                 "prompt": prompt,
+                                 "max_tokens": 1,
+                                 "temperature": 0
+                             })
+    assert response.status_code == 200
+
+    # Send to decode
+    response = requests.post("http://localhost:8200/v1/completions",
                             headers={"Content-Type": "application/json"},
+                             json={
+                                 "model":
+                                 "meta-llama/Meta-Llama-3.1-8B-Instruct",
+                                 "prompt": prompt,
+                                 "max_tokens": 10,
+                                 "temperature": 0
+                             })
+    assert response.status_code == 200
diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
index 7780e2dfa317..64424bdced56 100644
--- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
@@ -137,14 +137,18 @@ def select(self, input_tokens: Optional[torch.Tensor],
             "consumer buffer before calling select."
         return self.consumer_buffer.drop_select(input_tokens, roi)
 
-    def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
-               key: torch.Tensor, value: torch.Tensor,
-               hidden: torch.Tensor) -> None:
+    def insert(self,
+               input_tokens: torch.Tensor,
+               roi: torch.Tensor,
+               key: torch.Tensor,
+               value: torch.Tensor,
+               hidden: torch.Tensor,
+               residual: Optional[torch.Tensor] = None) -> None:
 
         assert self.producer_buffer is not None, "Please initialize the "\
             "producer buffer before calling insert."
-
-        self.producer_buffer.insert(input_tokens, roi, key, value, hidden)
+        self.producer_buffer.insert(input_tokens, roi, key, value, hidden,
+                                    residual)
 
     def send_kv_caches_and_hidden_states(
         self,
@@ -190,12 +194,20 @@ def send_kv_caches_and_hidden_states(
 
             keys = torch.cat(keys, dim=0)
             values = torch.cat(values, dim=0)
-
-            self.insert(current_tokens,
-                        torch.ones_like(current_tokens,
-                                        dtype=bool), keys, values,
-                        hidden_or_intermediate_states[start_pos:end_pos])
-
+            if isinstance(hidden_or_intermediate_states, torch.Tensor):
+                self.insert(current_tokens,
+                            torch.ones_like(current_tokens,
+                                            dtype=bool), keys, values,
+                            hidden_or_intermediate_states[start_pos:end_pos])
+            elif isinstance(hidden_or_intermediate_states,
+                            IntermediateTensors):
+                self.insert(
+                    current_tokens, torch.ones_like(current_tokens,
+                                                    dtype=bool), keys, values,
+                    hidden_or_intermediate_states["hidden_states"]
+                    [start_pos:end_pos],
+                    hidden_or_intermediate_states["residual"]
+                    [start_pos:end_pos])
         logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
 
     def recv_kv_caches_and_hidden_states(
@@ -216,7 +228,7 @@ def recv_kv_caches_and_hidden_states(
         slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
 
         hidden_or_intermediate_states_for_one_req = []
-
+        pp_intermediate_tensors = IntermediateTensors({})
         input_tokens_list = []
         num_computed_tokens_list = []
         start_pos_list = []
@@ -246,6 +258,9 @@ def recv_kv_caches_and_hidden_states(
             keys: torch.Tensor = ret[2]
             values: torch.Tensor = ret[3]
             hidden: torch.Tensor = ret[4]
+            residual: Optional[torch.Tensor] = None
+            if len(ret) > 5:
+                residual = ret[5]
 
             num_computed_tokens = roi.shape[0]
             num_computed_tokens_list.append(num_computed_tokens)
@@ -279,8 +294,22 @@ def recv_kv_caches_and_hidden_states(
                     layer.self_attn.attn._k_scale,
                     layer.self_attn.attn._v_scale,
                 )
-
-            hidden_or_intermediate_states_for_one_req.append(hidden)
+            if residual is None:
+                hidden_or_intermediate_states_for_one_req.append(hidden)
+            else:
+                if "hidden_states" not in pp_intermediate_tensors.tensors:
+                    pp_intermediate_tensors.tensors["hidden_states"] = hidden
+                    pp_intermediate_tensors.tensors["residual"] = residual
+                else:
+                    pp_intermediate_tensors.tensors[
+                        "hidden_states"] = torch.cat(
+                            (pp_intermediate_tensors.tensors["hidden_states"],
+                             hidden),
+                            dim=0)
+                    pp_intermediate_tensors.tensors["residual"] = torch.cat(
+                        (pp_intermediate_tensors.tensors["residual"],
+                         residual),
+                        dim=0)
 
         if not bypass_model_exec:
             # Some of the KV cache is not retrieved
@@ -296,8 +325,11 @@ def recv_kv_caches_and_hidden_states(
             logger.debug(
                 "[rank%d]: Successfully received all KVs and hidden "
                 "states, skip model forwarding.", torch.distributed.get_rank())
forwarding.", torch.distributed.get_rank()) - hidden_or_intermediate_states = torch.cat( - hidden_or_intermediate_states_for_one_req, dim=0) + if residual is None: + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + else: + hidden_or_intermediate_states = pp_intermediate_tensors return hidden_or_intermediate_states, bypass_model_exec, model_input diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index fe8d8d7375f3..b600450f4901 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -15,6 +15,7 @@ import torch +from vllm.distributed import get_pp_group from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( KVLookupBufferBase) from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase @@ -99,9 +100,13 @@ def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): raise AssertionError(f"Unknown data type {type(data)}") - def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - hidden: torch.Tensor): + def _add_to_buffer(self, + input_tokens: torch.Tensor, + roi: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + hidden: torch.Tensor, + residual: torch.Tensor = None): if isinstance(input_tokens, torch.Tensor): input_tokens = input_tokens.clone() @@ -113,8 +118,11 @@ def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, value = value.clone() if isinstance(hidden, torch.Tensor): hidden = hidden.clone() - - buffer_item = [input_tokens, roi, key, value, hidden] + if residual is not None: + residual = residual.clone() + buffer_item = [input_tokens, roi, key, value, hidden, residual] + else: + buffer_item = [input_tokens, roi, key, value, hidden] with self.buffer_lock: for data in buffer_item: @@ -204,15 +212,22 @@ def drop_select( key = self.data_pipe.recv_tensor() value = self.data_pipe.recv_tensor() hidden = self.data_pipe.recv_tensor() + if get_pp_group().world_size > 1 and not get_pp_group().is_last_rank: + residual = self.data_pipe.recv_tensor() + return [input_tokens, roi, key, value, hidden, residual] return [input_tokens, roi, key, value, hidden] def full_handler(self): time.sleep(0.001) - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - hidden: torch.Tensor) -> None: + def insert(self, + input_tokens: torch.Tensor, + roi: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + hidden: torch.Tensor, + residual: torch.Tensor = None) -> None: if self.buffer_size > self.buffer_size_threshold: # log outside the while loop to avoid this message being logged @@ -221,7 +236,7 @@ def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, while self.buffer_size > self.buffer_size_threshold: self.full_handler() - self._add_to_buffer(input_tokens, roi, key, value, hidden) + self._add_to_buffer(input_tokens, roi, key, value, hidden, residual) # when calling the insert, the current process is a sender # need to launch the request handler and start listening to request.