diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index fc0dbf77f16b..148cb9e9b080 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -125,9 +125,9 @@ def sample_sonnet_requests( prefix_len: int, tokenizer: PreTrainedTokenizerBase, ) -> List[Tuple[str, str, int, int]]: - assert ( - input_len > prefix_len - ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'." + assert input_len >= prefix_len, ( + "'args.sonnet-input-len' must be greater than or equal to " + "'args.prefix-input-len'.") # Load the dataset. with open(dataset_path) as f: diff --git a/benchmarks/disagg_benchmarks/analyze_benchmark_results.py b/benchmarks/disagg_benchmarks/analyze_benchmark_results.py new file mode 100644 index 000000000000..4b675c675d25 --- /dev/null +++ b/benchmarks/disagg_benchmarks/analyze_benchmark_results.py @@ -0,0 +1,48 @@ + +import argparse +import json +import yaml +import os +from pathlib import Path + +def load(path): + + with open(str(path), 'r') as f: + return json.loads(f.read()) + +def main(args): + + results = Path(args.results_folder) + + chunk = load(results / "chunked_prefill_tp4.json") + prefill = load(results / "disagg_prefill_tp4.json") + decode = load(results / "disagg_decode_tp4.json") + + ttft_ratio = chunk["mean_ttft_ms"] / prefill["mean_ttft_ms"] + itl_ratio = chunk["mean_itl_ms"] / decode["mean_itl_ms"] + prefill_decode_ratio = prefill["mean_ttft_ms"] / (decode["mean_itl_ms"] * args.output_len) + + with open(results / args.output_file, 'a') as f: + f.write(yaml.dump([{ + 'qps': args.qps, + 'output_len': args.output_len, + 'prefill_decode_ratio': prefill_decode_ratio, + 'ttft_ratio': ttft_ratio, + 'itl_ratio': itl_ratio, + "chunk_ttft": chunk["mean_ttft_ms"], + "chunk_itl": chunk["mean_itl_ms"], + "disagg_ttft": prefill["mean_ttft_ms"], + "disagg_itl": decode["mean_itl_ms"] + }])) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Analyze benchmark results") + parser.add_argument("--results-folder", required=True, help="Path to the results folder") + parser.add_argument("--output-len", type=int, required=True, help="Target output length") + parser.add_argument("--qps", type=int, required=True, help="Target QPS") + parser.add_argument("--output-file", type=str, default="chunk_vs_disagg.yaml") + + args = parser.parse_args() + main(args) + \ No newline at end of file diff --git a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh new file mode 100644 index 000000000000..12f5150cadda --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh @@ -0,0 +1,148 @@ +#!/bin/bash + +# Requirement: 8x H100 GPUs. + + +# Model: neuralmagic/Meta-Llama-3-70B-Instruct-FP8-KV +# Query: 2048 input tokens, 11 output tokens, QPS 4, 500 requests +# Resource: 8x H100 +# Approaches: +# 1. Chunked prefill: 1 vllm instance with tp=8 +# 2. Chunked prefill: 2 vllm instance with tp=4, equivalent to 1 tp=4 instance with QPS 4 +# 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance +# Prefilling instance: max_output_token=1 +# Decoding instance: force the input tokens be the same across requests to bypass prefilling + +set -ex + +kill_gpu_processes() { + # kill all processes on GPU. + pkill pt_main_thread + sleep 10 + + # remove vllm config file + rm -rf ~/.config/vllm + + # Print the GPU memory usage + # so that we know if all GPU processes are killed. 
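+  # (nvidia-smi --query-gpu=memory.used with nounits reports the value in MiB)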
+ gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) + # The memory usage should be 0 MB. + echo "GPU 0 Memory Usage: $gpu_memory_usage MB" +} + +wait_for_server() { + # wait for vllm server to start + # return 1 if vllm server crashes + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + + +benchmark() { + + export VLLM_LOGGING_LEVEL=DEBUG + export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + export VLLM_PORT=12345 + + # compare chunked prefill with disaggregated prefill + + results_folder="./results" + model="meta-llama/Meta-Llama-3.1-70B-Instruct" + dataset_name="sonnet" + dataset_path="../sonnet_4x.txt" + num_prompts=50 + qps=$1 + prefix_len=50 + input_len=2048 + output_len=$2 + + # large model + VLLM_RPC_PORT=5570 VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8100 \ + -tp 4 \ + --max-model-len 30000 \ + --gpu-memory-utilization 0.8 & + VLLM_RPC_PORT=5580 VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8200 \ + -tp 4 \ + --max-model-len 30000 \ + --gpu-memory-utilization 0.8 & + + wait_for_server 8100 + wait_for_server 8200 + + # let the prefill instance finish prefill + python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len $output_len \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8100 \ + --save-result \ + --result-dir $results_folder \ + --result-filename disagg_prefill_2xtp4.json \ + --request-rate $qps + + + # send the request to decode. + # The TTFT of this command will be the overhead of disagg prefill impl. + python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len $output_len \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8200 \ + --save-result \ + --result-dir $results_folder \ + --result-filename disagg_prefill_2xtp4.json \ + --request-rate $qps + kill_gpu_processes + +} + + +main() { + + (which wget && which curl) || (apt-get update && apt-get install -y wget curl) + (which jq) || (apt-get -y install jq) + (which socat) || (apt-get -y install socat) + + pip install quart httpx + + cd "$(dirname "$0")" + + cd .. + # create sonnet-4x.txt + echo "" > sonnet_4x.txt + for _ in {1..4} + do + cat sonnet.txt >> sonnet_4x.txt + done + cd disagg_benchmarks + + rm -rf results + mkdir results + + default_qps=1 + default_output_len=1 + benchmark $default_qps $default_output_len + +} + + +main "$@" diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh new file mode 100644 index 000000000000..dde9a80b59b3 --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -0,0 +1,172 @@ +#!/bin/bash + +# Requirement: 8x H100 GPUs. + + +# Model: neuralmagic/Meta-Llama-3-70B-Instruct-FP8-KV +# Query: 2048 input tokens, 11 output tokens, QPS 4, 500 requests +# Resource: 8x H100 +# Approaches: +# 1. Chunked prefill: 1 vllm instance with tp=8 +# 2. Chunked prefill: 2 vllm instance with tp=4, equivalent to 1 tp=4 instance with QPS 4 +# 3. 
Disaggregated prefill: 1 prefilling instance and 1 decoding instance +# Prefilling instance: max_output_token=1 +# Decoding instance: force the input tokens be the same across requests to bypass prefilling + +set -ex + +kill_gpu_processes() { + # kill all processes on GPU. + pkill -f pt_main_thread + pkill -f python3 + pkill -f round_robin_proxy.sh + ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 + for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done + sleep 1 +} + +wait_for_server() { + # wait for vllm server to start + # return 1 if vllm server crashes + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + + +launch_chunked_prefill() { + model="meta-llama/Meta-Llama-3.1-70B-Instruct" + # disagg prefill + VLLM_RPC_PORT=5570 CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8100 \ + -tp 4 \ + --max-model-len 30000 \ + --disable-log-stats \ + --disable-log-requests \ + --enable-chunked-prefill \ + --gpu-memory-utilization 0.8 & + VLLM_RPC_PORT=5580 CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8200 \ + -tp 4 \ + --max-model-len 30000 \ + --disable-log-stats \ + --disable-log-requests \ + --enable-chunked-prefill \ + --gpu-memory-utilization 0.8 & + wait_for_server 8100 + wait_for_server 8200 + bash round_robin_proxy.sh & + sleep 1 +} + + +launch_disagg_prefill() { + model="meta-llama/Meta-Llama-3.1-70B-Instruct" + # disagg prefill + VLLM_PORT=12345 VLLM_RPC_PORT=5570 VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8100 \ + -tp 4 \ + --max-model-len 30000 \ + --disable-log-stats \ + --disable-log-requests \ + --gpu-memory-utilization 0.8 & + VLLM_PORT=12345 VLLM_RPC_PORT=5580 VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8200 \ + -tp 4 \ + --max-model-len 30000 \ + --disable-log-stats \ + --disable-log-requests \ + --gpu-memory-utilization 0.8 & + wait_for_server 8100 + wait_for_server 8200 + python3 disagg_prefill_proxy_server.py & + sleep 1 +} + + +benchmark() { + results_folder="./results" + model="meta-llama/Meta-Llama-3.1-70B-Instruct" + dataset_name="sonnet" + dataset_path="../sonnet_4x.txt" + num_prompts=400 + qps=$1 + prefix_len=50 + input_len=2048 + output_len=$2 + tag=$3 + + python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len $output_len \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8000 \ + --save-result \ + --result-dir $results_folder \ + --result-filename $tag-qps-$qps.json \ + --request-rate $qps + + sleep 2 + +} + + +main() { + + (which wget && which curl) || (apt-get update && apt-get install -y wget curl) + (which jq) || (apt-get -y install jq) + (which socat) || (apt-get -y install socat) + + pip install quart httpx + + cd "$(dirname "$0")" + + cd .. 
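+  # we are now in the benchmarks/ directory, where sonnet.txt lives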
+ # create sonnet-4x.txt so that we can sample 2048 tokens for input + echo "" > sonnet_4x.txt + for _ in {1..4} + do + cat sonnet.txt >> sonnet_4x.txt + done + cd disagg_benchmarks + + rm -rf results + mkdir results + + default_qps=10 + default_output_len=150 + + export VLLM_LOGGING_LEVEL=DEBUG + export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + + launch_chunked_prefill + for qps in 2 4 6 8; do + benchmark $qps $default_output_len chunked_prefill + done + kill_gpu_processes + + launch_disagg_prefill + for qps in 2 4 6 8; do + benchmark $qps $default_output_len disagg_prefill + done + kill_gpu_processes + +} + + +main "$@" diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py new file mode 100644 index 000000000000..5750df7735ad --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py @@ -0,0 +1,55 @@ +from quart import Quart, request, Response, jsonify, make_response +import aiohttp +import sys +import traceback +import os + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + +app = Quart(__name__) + +async def forward_request(url, data): + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" + } + async with session.post(url=url, json=data, headers=headers) as response: + if response.status == 200: + # if response.headers.get('Transfer-Encoding') == 'chunked': + if True: + async for chunk_bytes in response.content.iter_chunked(1024): + yield chunk_bytes + else: + content = await response.read() + yield content + +@app.route('/v1/completions', methods=['POST']) +async def handle_request(): + try: + original_request_data = await request.get_json() + + prefill_request = original_request_data.copy() + # change max_tokens = 1 to let it only do prefill + prefill_request['max_tokens'] = 1 + + # finish prefill + async for _ in forward_request('http://localhost:8100/v1/completions', prefill_request): + continue + + print(f"Prefill done. 
proceeding to decode.") + + # return decode + generator = forward_request('http://localhost:8200/v1/completions', original_request_data) + response = await make_response(generator) + response.timeout = None + + return response + + except Exception as e: + pass + # exc_info = sys.exc_info() + # print(e) + # print("".join(traceback.format_exception(*exc_info))) + +if __name__ == '__main__': + app.run(port=8000) diff --git a/benchmarks/disagg_benchmarks/round_robin_proxy.sh b/benchmarks/disagg_benchmarks/round_robin_proxy.sh new file mode 100644 index 000000000000..375bf9e42237 --- /dev/null +++ b/benchmarks/disagg_benchmarks/round_robin_proxy.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# Define the ports to forward to +PORTS=(8100 8200) +NUM_PORTS=${#PORTS[@]} +CURRENT=0 + +# Function to handle the round-robin logic +get_next_port() { + NEXT_PORT=${PORTS[$CURRENT]} + CURRENT=$(( (CURRENT + 1) % NUM_PORTS )) + echo $NEXT_PORT +} + +# Start the proxy +while true; do + NEXT_PORT=$(get_next_port) + socat TCP4-LISTEN:8000,reuseaddr,fork TCP4:localhost:$NEXT_PORT 2>/dev/null +done \ No newline at end of file diff --git a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py new file mode 100644 index 000000000000..192f26a1e3cd --- /dev/null +++ b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py @@ -0,0 +1,47 @@ + +import matplotlib.pyplot as plt +import yaml +import pandas as pd +import json + + + +if __name__ == "__main__": + + data = [] + for name in ['disagg_prefill', 'chunked_prefill']: + for qps in [2,4,6,8]: + with open(f"results/{name}-qps-{qps}.json", "r") as f: + x = json.load(f) + x['name'] = name + x['qps'] = qps + data.append(x) + + df = pd.DataFrame.from_dict(data) + dis_df = df[df['name'] == 'disagg_prefill'] + chu_df = df[df['name'] == 'chunked_prefill'] + + plt.style.use('bmh') + plt.rcParams['font.size'] = 20 + + + for key in ['mean_ttft_ms', + 'median_ttft_ms', + 'p99_ttft_ms', + 'mean_itl_ms', + 'median_itl_ms', + 'p99_itl_ms']: + + fig, ax = plt.subplots(figsize=(11, 7)) + plt.plot(dis_df['qps'], dis_df[key], label='disagg_prefill', marker='o', linewidth=4) + plt.plot(chu_df['qps'], chu_df[key], label='chunked_prefill', marker='o', linewidth=4) + ax.legend() + + ax.set_xlabel('QPS') + ax.set_ylabel(key) + ax.set_ylim(bottom=0) + fig.savefig(f'results/{key}.png') + plt.close(fig) + + + \ No newline at end of file diff --git a/examples/disagg_prefill/disagg_prefill_example.sh b/examples/disagg_prefill/disagg_prefill_example.sh new file mode 100644 index 000000000000..f57f5fd86d89 --- /dev/null +++ b/examples/disagg_prefill/disagg_prefill_example.sh @@ -0,0 +1,53 @@ +#!/bin/bash +# This file demonstrates the example usage of disaggregated prefilling +# We will launch 2 vllm instances (1 for prefill and 1 for decode), +# and then transfer the KV cache between them. 
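+#
+# Both instances share VLLM_HOST_IP and VLLM_PORT (exported below) so that they
+# can rendezvous for the KV cache transfer, while VLLM_RPC_PORT, the HTTP --port
+# and VLLM_DISAGG_PREFILL_ROLE are set per instance further down.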
+ +export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') +export VLLM_PORT=12345 + +# a function that waits vLLM server to start +wait_for_server() { + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# prefilling instance +VLLM_RPC_PORT=5570 VLLM_DISAGG_PREFILL_ROLE=prefill CUDA_VISIBLE_DEVICES=0 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model meta-llama/Meta-Llama-3.1-8B-Instruct \ + --port 8100 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.8 & + +# decoding instance +VLLM_RPC_PORT=5580 VLLM_DISAGG_PREFILL_ROLE=decode CUDA_VISIBLE_DEVICES=1 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model meta-llama/Meta-Llama-3.1-8B-Instruct \ + --port 8200 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.8 & + +# wait until prefill and decode instances are ready +wait_for_server 8100 +wait_for_server 8200 + +# launch a proxy server that opens the service at port 8000 +python3 ../../benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py & +sleep 1 + +# serve an example request +curl http://localhost:8000/v1/completions \ +-H "Content-Type: application/json" \ +-d '{ +"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", +"prompt": "San Francisco is a", +"max_tokens": 10, +"temperature": 0 +}' + +# clean up +ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 \ No newline at end of file diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 26b3159682b3..c886b55bf8b6 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -15,6 +15,9 @@ is_block_tables_empty) from vllm.utils import make_tensor_with_pad +from vllm.distributed import get_disagg_group +import vllm.envs as envs + if TYPE_CHECKING: from vllm.worker.model_runner import ModelInputForGPUBuilder diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 91abaab78dcb..38012e8dd965 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -546,3 +546,4 @@ def forward( sm_scale=self.scale, logits_soft_cap=self.logits_soft_cap) return output.view(num_tokens, hidden_size) + \ No newline at end of file diff --git a/vllm/distributed/distributed_kv.py b/vllm/distributed/distributed_kv.py new file mode 100644 index 000000000000..9005a325d6bf --- /dev/null +++ b/vllm/distributed/distributed_kv.py @@ -0,0 +1,480 @@ +"""vLLM distributed KV cache transfer API. +These APIs are used in `vllm/worker/model_runner.py`. + +Currently supporting TP and PP. 
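+The prefill/decode role of an instance is selected with the
+VLLM_DISAGG_PREFILL_ROLE environment variable ("prefill" or "decode").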
+ +Workflow: +- In prefill instance, KV cache sender *buffers* the KV cache send requests +- In decode instance + - KV cache receiver sends the hash of input tokens to sender + - KV cache sender executes send request + - KV cache receiver receives the KV cache +""" +from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING +from collections import defaultdict, deque +from concurrent.futures import ThreadPoolExecutor +from threading import Lock +from copy import deepcopy +import time +import threading + +import torch +from torch.distributed import Backend, ProcessGroup + +import vllm.envs as envs +from vllm.distributed.group_coordinator import GroupCoordinator +from vllm.logger import init_logger +import vllm.distributed.parallel_state as ps +from vllm import _custom_ops as ops +from vllm.sequence import IntermediateTensors + +assert envs.VLLM_DISAGG_PREFILL_ROLE in [None, "prefill", "decode"], \ + "VLLM_DISAGG_PREFILL_ROLE can only be prefill or decode." + +IS_DISTRIBUTED_KV_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE is not None) +IS_KV_PREFILL_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "prefill") +IS_KV_DECODE_INSTANCE: bool = (envs.VLLM_DISAGG_PREFILL_ROLE == "decode") + +# add a tag when sending/recving input hash +DISTRIBUTED_KV_GLOO_TAG = 24857323 + +logger = init_logger(__name__) + +import logging + + +class RankFilter(logging.Filter): + + def filter(self, record): + # Only log if rank is 4 + rank = 1 + try: + rank = torch.distributed.get_rank() + except Exception: + pass + return rank % 4 == 0 + + +for handler in logger.handlers: + handler.addFilter(RankFilter()) + + +class DistributedKVCoordinator(GroupCoordinator): + """ + A class designated for distributed KV transfer + + Target use cases: + 1. Disaggregated prefill + 2. Remote KV cache storage + + """ + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + # DO NOT use pynccl here + # Pynccl send is non-blocking + # and it's possible that the memory is freed before the data being sent + # which may happen at high qps + use_pynccl: bool = False, + use_custom_allreduce: bool = False, + use_tpu_communicator: bool = True, + use_message_queue_broadcaster: bool = False, + use_cpu_comm_for_sanity_check: bool = False, + ): + + super().__init__( + group_ranks, + local_rank, + torch_distributed_backend, + use_pynccl, + use_custom_allreduce, + use_tpu_communicator, + use_message_queue_broadcaster, + ) + + # if turned on, will use CPU-based communication to perform a series of sanity check. + # but it adds ~5ms delay, so please turn it off in performance-demanding usecases (e.g. disaggregated prefill) + self.use_cpu_comm_for_sanity_check = use_cpu_comm_for_sanity_check + + # use a threadpool to buffer send request in disaggregated prefill + self.input_hash_to_kv_sending_requests = defaultdict(deque) + self.kv_sending_thread = None + self.input_hash_to_kv_sending_requests_lock = Lock() + self.target_rank_for_send = self.ranks[(self.rank_in_group + 1) % + self.world_size] + self.target_rank_for_recv = self.ranks[(self.rank_in_group - 1) % + self.world_size] + + torch.set_default_device(self.device) + + def debug_send(self, + tensor: torch.Tensor, + dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """Will send several metadata. 
Useful for debugging.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + self.send_tensor_dict( + { + "tensor": tensor, + "mean": tensor.float().mean(), + "shape": tensor.shape + }, dst) + + def debug_recv(self, + size: torch.Size, + dtype: torch.dtype, + src: Optional[int] = None) -> torch.Tensor: + """Receives a tensor from the src rank.""" + """NOTE: `src` is the local rank of the destination rank.""" + + result = self.recv_tensor_dict(src) + tensor = result["tensor"] + assert torch.allclose(result["mean"], tensor.float().mean()) + assert result["shape"] == tensor.shape + assert result[ + "shape"] == size, f"The shape sent by sender is {result['shape']} but trying to receive {size}" + return tensor + + def kv_cache_send(self, + input_hash: int, + tensor: Union[torch.Tensor, IntermediateTensors], + is_hidden: bool = False, + dst: Optional[int] = None) -> None: + """Push the KV cache send request into the send buffer""" + """NOTE: `dst` is the local rank of the destination rank.""" + + if self.use_cpu_comm_for_sanity_check: + send_func = self.debug_send + else: + send_func = self.send + + if is_hidden and not ps.get_pp_group().is_last_rank: + + assert isinstance(tensor, IntermediateTensors) + + output = deepcopy(tensor.tensors) + for key in output: + output[key] = output[key].contiguous() + + self.input_hash_to_kv_sending_requests[input_hash].append( + [self.send_tensor_dict, output, dst]) + + else: + + assert isinstance(tensor, torch.Tensor) + + self.input_hash_to_kv_sending_requests[input_hash].append([ + send_func, + # use clone to make sure the tensor is contiguous + tensor.clone(), + dst + ]) + + def kv_cache_recv( + self, + size: torch.Size, + dtype: torch.dtype, + is_hidden: bool = False, + src: Optional[int] = None + ) -> Union[torch.Tensor, IntermediateTensors]: + """Receives a tensor from the src rank (blocking).""" + """This API should be used together with `push`""" + """NOTE: `src` is the local rank of the destination rank.""" + + if self.use_cpu_comm_for_sanity_check: + recv_func = self.debug_recv + else: + recv_func = self.recv + + if is_hidden and not ps.get_pp_group().is_last_rank: + tensor = IntermediateTensors(self.recv_tensor_dict(src)) + else: + tensor = recv_func(size, dtype, src) + + return tensor + + def send_input_hash(self, input_hash: int) -> int: + + logger.debug('[rank%d]: Sending input hash %d to rank %d', + torch.distributed.get_rank(), input_hash, + self.target_rank_for_send) + + # KV cache send go through CPU, and the original `send` only use GPU. + # So create a new group for sending input hash. 
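+        # Handshake (decode -> prefill) over the gloo CPU group: send the input
+        # hash, then block on a 1/0 ack telling us whether the prefill side has
+        # buffered KV cache for this hash.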
+ input_hash_tensor = torch.tensor([input_hash], device="cpu").long() + torch.distributed.send(input_hash_tensor, + self.target_rank_for_send, + self.cpu_group, + tag=DISTRIBUTED_KV_GLOO_TAG) + return_tensor = torch.tensor([0], device="cpu").long() + torch.distributed.recv(return_tensor, + self.target_rank_for_recv, + self.cpu_group, + tag=DISTRIBUTED_KV_GLOO_TAG) + return return_tensor.item() + + def recv_input_hash(self) -> Optional[int]: + ''' + Receive an input hash, and check if it is already cached + ''' + input_hash_tensor = torch.tensor([0], device="cpu").long() + torch.distributed.recv(input_hash_tensor, + self.target_rank_for_recv, + self.cpu_group, + tag=DISTRIBUTED_KV_GLOO_TAG) + input_hash = input_hash_tensor.item() + # a new input hash comes in, see if it is already cached + self.input_hash_to_kv_sending_requests_lock.acquire() + logger.debug('Successfully received input hash %d', input_hash) + if input_hash not in self.input_hash_to_kv_sending_requests: + logger.warning( + f"The KV cache of {input_hash} does not exist. "\ + f"Existing input hash: {list(self.input_hash_to_kv_sending_requests.keys())}") + + # 0 for fail + x = torch.tensor([0], device="cpu").long() + torch.distributed.send(x, + self.target_rank_for_send, + self.cpu_group, + tag=DISTRIBUTED_KV_GLOO_TAG) + return None + else: + logger.debug('Input hash %d exists, start sending', input_hash) + + # 1 for success + x = torch.tensor([1], device="cpu").long() + torch.distributed.send(x, + self.target_rank_for_send, + self.cpu_group, + tag=DISTRIBUTED_KV_GLOO_TAG) + return input_hash + + def kv_cache_send_loop(self): + + while True: + logger.debug( + '[rank%d]: Waiting for input hash from rank %d, my keys are %s', + torch.distributed.get_rank(), + self.target_rank_for_recv, + list(self.input_hash_to_kv_sending_requests.keys()), + ) + # wait for a new input hash + # this function will acquire the lock + input_hash = self.recv_input_hash() + if input_hash is None: + self.input_hash_to_kv_sending_requests_lock.release() + continue + + # execute corresponding kv cache sending jobs in request queue + while True: + request = self.input_hash_to_kv_sending_requests[ + input_hash].popleft() + # An empty request: the KV cahe of one request are all sent + if request == []: + break + + request[0](*request[1:]) + + if len(self.input_hash_to_kv_sending_requests[input_hash]) == 0: + logger.debug('Finish input hash %d, free GPU memory...', + input_hash) + del self.input_hash_to_kv_sending_requests[input_hash] + else: + logger.debug( + 'The buffer for input hash %d is not empty, meaning that '\ + 'there are two jobs with identical input.', + input_hash) + + self.input_hash_to_kv_sending_requests_lock.release() + + + def kv_cache_send_ready(self, input_hash: int): + + if self.kv_sending_thread is None: + self.kv_sending_thread = threading.Thread( + target=self.kv_cache_send_loop) + self.kv_sending_thread.start() + + # append an empty list to separate requests + # as there might be identical requests, that has the same input hash + self.input_hash_to_kv_sending_requests[input_hash].append([]) + logger.debug(f'Buffered input hash {input_hash}') + + def kv_cache_recv_start(self, input_hash: int): + # notify the kv cache sender with the input hash id + return self.send_input_hash(input_hash) + + def block_if_buffer_full(self): + + # block vLLM if the KV cache sending buffer is full + # TODO: allow using other policies to handle buffer full + while True: + self.input_hash_to_kv_sending_requests_lock.acquire() + if 
len(self.input_hash_to_kv_sending_requests.keys()) > 40: + self.input_hash_to_kv_sending_requests_lock.release() + time.sleep(0.1) + else: + self.input_hash_to_kv_sending_requests_lock.release() + break + + +def send_kv_caches_and_hidden_states( + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, IntermediateTensors], +) -> None: + + input_tokens_tuple = tuple(model_input.input_tokens.tolist()) + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + + # Assumption: current batch is all-prefill requests + assert torch.allclose(model_input.attn_metadata.query_start_loc, + model_input.attn_metadata.seq_start_loc) + assert torch.all(model_input.attn_metadata.context_lens_tensor == 0) + + ps.get_disagg_group().input_hash_to_kv_sending_requests_lock.acquire() + + # query_lens contains new KV caches that are added to vLLM. + # so we will send them to decode instance + # FIXME(Kuntai): This assume that all requests are prefill. + for idx, slen in enumerate(seq_lens): + + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + input_hash = hash(input_tokens_tuple[start_pos:end_pos]) + + for i in range(model_executable.model.start_layer, + model_executable.model.end_layer): + kv_cache = kv_caches[i - model_executable.model.start_layer] + + _, _, num_heads, head_size = kv_cache[0].shape + + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + + current_slot_mapping = slot_mapping[start_pos:end_pos] + + ps.get_disagg_group().kv_cache_send( + input_hash, key_cache[current_slot_mapping]) + ps.get_disagg_group().kv_cache_send( + input_hash, value_cache[current_slot_mapping]) + + ps.get_disagg_group().kv_cache_send( + input_hash, + hidden_or_intermediate_states[start_pos:end_pos], + is_hidden=True) + ps.get_disagg_group().kv_cache_send_ready(input_hash) + + ps.get_disagg_group().input_hash_to_kv_sending_requests_lock.release() + + ps.get_disagg_group().block_if_buffer_full() + + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + + +def recv_kv_caches_and_hidden_states( + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] +) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool]: + + bypass_model_exec = True + + # This is disagg decode instance, during prefill state + # Need to receive KV from the prefill instance + input_tokens_tuple = tuple(model_input.input_tokens.tolist()) + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + + # Assumption: current batch is all-prefill requests + assert torch.allclose(model_input.attn_metadata.query_start_loc, + model_input.attn_metadata.seq_start_loc) + assert torch.all(model_input.attn_metadata.context_lens_tensor == 0) + + hidden_or_intermediate_states_for_one_req = [] + + # enumerate different requests + # FIXME(Kuntai): This impl assumes that all requests are prefill. 
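+    # For each request: hash its prompt tokens, ask the prefill instance to
+    # stream the per-layer K/V for that hash, scatter them into the local paged
+    # KV cache via reshape_and_cache_flash, and collect the returned hidden
+    # states.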
+ for idx, slen in enumerate(seq_lens): + + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + input_hash = hash(input_tokens_tuple[start_pos:end_pos]) + num_tokens = slen + + # notify the prefill instance to start sending KVs associated with input_hash + contain = ps.get_disagg_group().kv_cache_recv_start(input_hash) + + # fail to find input_hash in prefill instance + # this can occur but idk why... + if contain == 0: + bypass_model_exec = False + continue + + # receive KV cache from disaggregated prefill instance + for i in range(model_executable.model.start_layer, + model_executable.model.end_layer): + + # get kv cache + kv_cache = kv_caches[i - model_executable.model.start_layer] + # get corresponding layer + layer = model_executable.model.layers[i] + + # get kv cache shape (after sliced by tp) + _, _, num_heads, head_size = kv_cache[0].shape + key = ps.get_disagg_group().kv_cache_recv( + torch.Size([num_tokens, num_heads, head_size]), + kv_cache[0].dtype) + value = ps.get_disagg_group().kv_cache_recv( + torch.Size([num_tokens, num_heads, head_size]), + kv_cache[0].dtype) + + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) + + hidden_or_intermediate_states_for_one_req.append( + ps.get_disagg_group().kv_cache_recv(torch.Size( + [num_tokens, model_executable.config.hidden_size]), + kv_cache[0].dtype, + is_hidden=True)) + + if not bypass_model_exec: + # Some of the KV cache is not retrieved + # so we need to recompute the hidden state + return [], bypass_model_exec + + # concatenate hidden states from different requests + if isinstance(hidden_or_intermediate_states_for_one_req[0], torch.Tensor): + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + else: + # concat the IntermediateTensors + keys = list( + hidden_or_intermediate_states_for_one_req[0].tensors.keys()) + result_its = {} + + for key in keys: + result_its[key] = [] + for its in hidden_or_intermediate_states_for_one_req: + result_its[key].append(its[key]) + result_its[key] = torch.cat(result_its[key], dim=0) + + hidden_or_intermediate_states = IntermediateTensors(result_its) + + logger.debug("[rank%d]: KV recv DONE.", torch.distributed.get_rank()) + return hidden_or_intermediate_states, bypass_model_exec diff --git a/vllm/distributed/group_coordinator.py b/vllm/distributed/group_coordinator.py new file mode 100644 index 000000000000..bfa3c7f3c17c --- /dev/null +++ b/vllm/distributed/group_coordinator.py @@ -0,0 +1,726 @@ +"""vLLM PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and cuda graph mode). 
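+ This code previously lived in vllm/distributed/parallel_state.py; it is
+ split into its own module so that vllm/distributed/distributed_kv.py can
+ subclass GroupCoordinator for the disaggregated-prefill KV transfer group.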
+""" + +from dataclasses import dataclass +from contextlib import contextmanager, nullcontext +from collections import namedtuple +from typing import Any, Dict, List, Optional, Tuple, Union +import pickle + +import torch +from torch.distributed import Backend, ProcessGroup + + +@dataclass +class GraphCaptureContext: + stream: torch.cuda.Stream + + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + + +def _split_tensor_dict( + tensor_dict: Dict[str, Union[torch.Tensor, Any]] +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + """ + metadata_list: List[Tuple[str, Any]] = [] + tensor_list: List[torch.Tensor] = [] + for key, value in tensor_dict.items(): + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "cuda:0"). We only need the device type. + # receiving side will set the device index. + device = value.device.type + metadata_list.append( + (key, TensorMetadata(device, value.dtype, value.size()))) + tensor_list.append(value) + else: + metadata_list.append((key, value)) + return metadata_list, tensor_list + + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and cuda graph mode). + """ + + # available attributes: + rank: int # global rank + ranks: List[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + use_pynccl: bool # a hint of whether to use PyNccl + use_custom_allreduce: bool # a hint of whether to use CustomAllreduce + # communicators are only created for world size > 1 + pynccl_comm: Optional[Any] # PyNccl communicator + ca_comm: Optional[Any] # Custom allreduce communicator + mq_broadcaster: Optional[Any] # shared memory broadcaster + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + use_pynccl: bool, + use_custom_allreduce: bool, + use_tpu_communicator: bool, + use_message_queue_broadcaster: bool = False, + ): + + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. 
+ cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + assert self.cpu_group is not None + assert self.device_group is not None + + if torch.cuda.is_available(): + self.device = torch.device(f"cuda:{local_rank}") + else: + self.device = torch.device("cpu") + + self.use_pynccl = use_pynccl + self.use_custom_allreduce = use_custom_allreduce + self.use_tpu_communicator = use_tpu_communicator + + # lazy import to avoid documentation build error + from vllm.distributed.device_communicators.custom_all_reduce import ( + CustomAllreduce) + from vllm.distributed.device_communicators.pynccl import ( + PyNcclCommunicator) + + self.pynccl_comm: Optional[PyNcclCommunicator] + if use_pynccl and self.world_size > 1: + self.pynccl_comm = PyNcclCommunicator( + group=self.cpu_group, + device=self.device, + ) + else: + self.pynccl_comm = None + + self.ca_comm: Optional[CustomAllreduce] + if use_custom_allreduce and self.world_size > 1: + # Initialize a custom fast all-reduce implementation. + self.ca_comm = CustomAllreduce( + group=self.cpu_group, + device=self.device, + ) + else: + self.ca_comm = None + + from vllm.distributed.device_communicators.tpu_communicator import ( + TpuCommunicator) + self.tpu_communicator: Optional[TpuCommunicator] + if use_tpu_communicator and self.world_size > 1: + self.tpu_communicator = TpuCommunicator(group=self.cpu_group) + + from vllm.distributed.device_communicators.shm_broadcast import ( + MessageQueue) + self.mq_broadcaster: Optional[MessageQueue] = None + if use_message_queue_broadcaster and self.world_size > 1: + self.mq_broadcaster = MessageQueue.create_from_process_group( + self.cpu_group, 1 << 22, 6) + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + + @property + def next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @contextmanager + def graph_capture( + self, graph_capture_context: Optional[GraphCaptureContext] = None): + if graph_capture_context is None: + stream = torch.cuda.Stream() + graph_capture_context = GraphCaptureContext(stream) + else: + stream = graph_capture_context.stream + + ca_comm = self.ca_comm + maybe_ca_context = nullcontext( + ) if ca_comm is None else ca_comm.capture() + + # ensure all initialization operations complete before attempting to + # capture the graph on another stream + curr_stream = torch.cuda.current_stream() + if curr_stream != stream: + stream.wait_stream(curr_stream) + + with torch.cuda.stream(stream), maybe_ca_context: + # In graph mode, we have to be very 
careful about the collective + # operations. The current status is: + # allreduce \ Mode | Eager | Graph | + # -------------------------------------------- + # custom allreduce | enabled | enabled | + # PyNccl | disabled| enabled | + # torch.distributed | enabled | disabled| + # + # Note that custom allreduce will have a runtime check, if the + # tensor size is too large, it will fallback to the next + # available option. + # In summary: When using CUDA graph, we use + # either custom all-reduce kernel or pynccl. When not using + # CUDA graph, we use either custom all-reduce kernel or + # PyTorch NCCL. We always prioritize using custom all-reduce + # kernel but fall back to PyTorch or pynccl if it is + # disabled or not supported. + pynccl_comm = self.pynccl_comm + maybe_pynccl_context: Any + if not pynccl_comm: + maybe_pynccl_context = nullcontext() + else: + maybe_pynccl_context = pynccl_comm.change_state( + enable=True, stream=torch.cuda.current_stream()) + with maybe_pynccl_context: + yield graph_capture_context + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + """ + NOTE: This operation will be applied in-place or out-of-place. + Always assume this function modifies its input, but use the return + value as the output. + """ + ca_comm = self.ca_comm + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + + # For TPUs, use TPU communicator. + tpu_comm = self.tpu_communicator + if tpu_comm is not None and not tpu_comm.disabled: + return tpu_comm.all_reduce(input_) + + if ca_comm is not None: + out = ca_comm.custom_all_reduce(input_) + if out is not None: + return out + pynccl_comm = self.pynccl_comm + if (pynccl_comm is not None and not pynccl_comm.disabled): + pynccl_comm.all_reduce(input_) + elif input_.is_cpu: + import intel_extension_for_pytorch as ipex + ipex.distributed.all_reduce(input_, group=self.device_group) + else: + torch.distributed.all_reduce(input_, group=self.device_group) + return input_ + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + + # For TPUs, use TPU communicator. + tpu_comm = self.tpu_communicator + if tpu_comm is not None and not tpu_comm.disabled: + return tpu_comm.all_gather(input_, dim) + + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # Allocate output tensor. + output_tensor = torch.empty((world_size, ) + input_size, + dtype=input_.dtype, + device=input_.device) + # All-gather. + torch.distributed.all_gather_into_tensor(output_tensor, + input_, + group=self.device_group) + # Reshape + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape(input_size[:dim] + + (world_size * + input_size[dim], ) + + input_size[dim + 1:]) + return output_tensor + + def gather(self, + input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> torch.Tensor: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. 
+ if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather(input_, + gather_list, + dst=self.ranks[dst], + group=self.device_group) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def broadcast(self, input_: torch.Tensor, src: int = 0): + """Broadcast the input tensor. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + # Broadcast. + torch.distributed.broadcast(input_, + src=self.ranks[src], + group=self.device_group) + return input_ + + def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): + """Broadcast the input object. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj + if self.mq_broadcaster is not None: + assert src == 0, "Message queue broadcaster only supports src=0" + return self.mq_broadcaster.broadcast_object(obj) + if self.rank_in_group == src: + torch.distributed.broadcast_object_list([obj], + src=self.ranks[src], + group=self.cpu_group) + return obj + else: + recv = [None] + torch.distributed.broadcast_object_list(recv, + src=self.ranks[src], + group=self.cpu_group) + return recv[0] + + def broadcast_object_list(self, + obj_list: List[Any], + src: int = 0, + group: Optional[ProcessGroup] = None): + """Broadcast the input object list. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Broadcast. + torch.distributed.broadcast_object_list(obj_list, + src=self.ranks[src], + group=self.device_group) + return obj_list + + def send_object(self, obj: Any, dst: int) -> None: + """Send the input object list to the destination rank.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + assert dst != self.rank_in_group, ( + "Invalid destination rank. Destination rank is the same " + "as the current rank.") + + # Serialize object to tensor and get the size as well + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) + + size_tensor = torch.tensor([object_tensor.numel()], + dtype=torch.long, + device="cpu") + + # Send object size + + torch.distributed.send(size_tensor, + dst=self.ranks[dst], + group=self.cpu_group) + + # Send object + torch.distributed.send(object_tensor, + dst=self.ranks[dst], + group=self.cpu_group) + + return None + + def recv_object(self, src: int) -> Any: + """Receive the input object list from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + + assert src < self.world_size, f"Invalid src rank ({src})" + + assert src != self.rank_in_group, ( + "Invalid source rank. Source rank is the same as the current rank." 
+ ) + + size_tensor = torch.empty(1, dtype=torch.long, device="cpu") + + # Receive object size + rank_size = torch.distributed.recv(size_tensor, + src=self.ranks[src], + group=self.cpu_group) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + size_tensor.item(), # type: ignore[arg-type] + dtype=torch.uint8, + device="cpu") + + rank_object = torch.distributed.recv(object_tensor, + src=self.ranks[src], + group=self.cpu_group) + + assert rank_object == rank_size, ( + "Received object sender rank does not match the size sender rank.") + + obj = pickle.loads(object_tensor.numpy().tobytes()) + + return obj + + def broadcast_tensor_dict( + self, + tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if (not torch.distributed.is_initialized() or self.world_size == 1): + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + assert src < self.world_size, f"Invalid src rank ({src})" + + rank_in_group = self.rank_in_group + if rank_in_group == src: + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, + dict), (f"Expecting a dictionary, got {type(tensor_dict)}") + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=group, + async_op=True) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + tensor_dict[key] = tensor + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, + src=self.ranks[src], + group=group, + async_op=True) + async_handles.append(handle) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + for async_handle in async_handles: + async_handle.wait() + return tensor_dict + + def send_tensor_dict( + self, + tensor_dict: Dict[str, Union[torch.Tensor, Any]], + dst: Optional[int] = None, + all_gather_group: Optional["GroupCoordinator"] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Send the input tensor dictionary. + NOTE: `dst` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. 
+ if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + all_gather_size = (1 if all_gather_group is None else + all_gather_group.world_size) + all_gather_rank = (0 if all_gather_group is None else + all_gather_group.rank_in_group) + + group = self.device_group + metadata_group = self.cpu_group + + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, + dict), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `send_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.send_object(metadata_list, dst=dst) + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip sending empty tensors. + continue + + # send-allgather: send only a slice, then do allgather. + if (all_gather_group is not None + and tensor.numel() % all_gather_size == 0): + tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] + + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.send(tensor, + dst=self.ranks[dst], + group=metadata_group) + else: + # use group for GPU tensors + torch.distributed.send(tensor, + dst=self.ranks[dst], + group=group) + return None + + def recv_tensor_dict( + self, + src: Optional[int] = None, + all_gather_group: Optional["GroupCoordinator"] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Recv the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + all_gather_size = (1 if all_gather_group is None else + all_gather_group.world_size) + all_gather_rank = (0 if all_gather_group is None else + all_gather_group.rank_in_group) + + group = self.device_group + metadata_group = self.cpu_group + + if src is None: + src = (self.rank_in_group - 1) % self.world_size + assert src < self.world_size, f"Invalid src rank ({src})" + + recv_metadata_list = self.recv_object(src=src) + tensor_dict: Dict[str, Any] = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + target_device = value.device + if 'cuda' in target_device: + target_device = self.device + tensor = torch.empty(value.size, + dtype=value.dtype, + device=target_device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + tensor_dict[key] = tensor + continue + + # send-allgather: send only a slice, then do allgather. + use_all_gather = (all_gather_group is not None + and tensor.numel() % all_gather_size == 0) + + if use_all_gather: + orig_shape = tensor.shape + tensor = tensor.reshape(all_gather_size, + -1)[all_gather_rank] + + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.recv(tensor, + src=self.ranks[src], + group=metadata_group) + + else: + # use group for GPU tensors + torch.distributed.recv(tensor, + src=self.ranks[src], + group=group) + if use_all_gather: + # do the allgather + tensor = all_gather_group.all_gather( # type: ignore + tensor, dim=0) + tensor = tensor.reshape(orig_shape) + + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + return tensor_dict + + def barrier(self): + """Barrier synchronization among the group. + NOTE: don't use `device_group` here! 
`barrier` in NCCL is + terrible because it is internally a broadcast operation with + secretly created GPU tensors. It is easy to mess up the current + device. Use the CPU group instead. + """ + torch.distributed.barrier(group=self.cpu_group) + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the local rank of the destination rank.""" + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.send(tensor, dst) + else: + torch.distributed.send(tensor, self.ranks[dst], self.device_group) + + def recv(self, + size: torch.Size, + dtype: torch.dtype, + src: Optional[int] = None) -> torch.Tensor: + """Receives a tensor from the src rank.""" + """NOTE: `src` is the local rank of the destination rank.""" + if src is None: + src = (self.rank_in_group - 1) % self.world_size + + tensor = torch.empty(size, dtype=dtype, device=self.device) + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.recv(tensor, src) + else: + torch.distributed.recv(tensor, self.ranks[src], self.device_group) + return tensor + + def destroy(self): + if self.device_group is not None: + torch.distributed.destroy_process_group(self.device_group) + self.device_group = None + if self.cpu_group is not None: + torch.distributed.destroy_process_group(self.cpu_group) + self.cpu_group = None + if self.pynccl_comm is not None: + self.pynccl_comm = None + if self.ca_comm is not None: + self.ca_comm = None + if self.mq_broadcaster is not None: + self.mq_broadcaster = None diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index d7ca8fd82e1a..86e26b46c9e9 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -8,7 +8,8 @@ - call `init_distributed_environment` to initialize the distributed environment. - call `initialize_model_parallel` or `ensure_model_parallel_initialized` to - initialize the model parallel groups. + initialize the model parallel groups and disaggregated prefill parallel + groups. - any code dealing with the distributed stuff @@ -19,14 +20,17 @@ parallelism, you can skip the model parallel initialization and destruction steps. """ +import time import contextlib import pickle +import logging from collections import namedtuple from contextlib import contextmanager, nullcontext from dataclasses import dataclass from multiprocessing import shared_memory from typing import Any, Dict, List, Optional, Tuple, Union from unittest.mock import patch +import queue import torch import torch.distributed @@ -34,711 +38,10 @@ import vllm.envs as envs from vllm.logger import init_logger +from vllm.distributed.group_coordinator import GroupCoordinator +import vllm.distributed.distributed_kv as dist_kv -@dataclass -class GraphCaptureContext: - stream: torch.cuda.Stream - - -TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) - - -def _split_tensor_dict( - tensor_dict: Dict[str, Union[torch.Tensor, Any]] -) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: - """Split the tensor dictionary into two parts: - 1. A list of (key, value) pairs. If the value is a tensor, it is replaced - by its metadata. - 2. A list of tensors. 
- """ - metadata_list: List[Tuple[str, Any]] = [] - tensor_list: List[torch.Tensor] = [] - for key, value in tensor_dict.items(): - if isinstance(value, torch.Tensor): - # Note: we cannot use `value.device` here, - # because it contains not only the device type but also the device - # index (e.g. "cuda:0"). We only need the device type. - # receiving side will set the device index. - device = value.device.type - metadata_list.append( - (key, TensorMetadata(device, value.dtype, value.size()))) - tensor_list.append(value) - else: - metadata_list.append((key, value)) - return metadata_list, tensor_list - - -class GroupCoordinator: - """ - PyTorch ProcessGroup wrapper for a group of processes. - PyTorch ProcessGroup is bound to one specific communication backend, - e.g. NCCL, Gloo, MPI, etc. - GroupCoordinator takes charge of all the communication operations among - the processes in the group. It can route the communication to - a specific implementation (e.g. switch allreduce implementation - based on the tensor size and cuda graph mode). - """ - - # available attributes: - rank: int # global rank - ranks: List[int] # global ranks in the group - world_size: int # size of the group - # difference between `local_rank` and `rank_in_group`: - # if we have a group of size 4 across two nodes: - # Process | Node | Rank | Local Rank | Rank in Group - # 0 | 0 | 0 | 0 | 0 - # 1 | 0 | 1 | 1 | 1 - # 2 | 1 | 2 | 0 | 2 - # 3 | 1 | 3 | 1 | 3 - local_rank: int # local rank used to assign devices - rank_in_group: int # rank inside the group - cpu_group: ProcessGroup # group for CPU communication - device_group: ProcessGroup # group for device communication - use_pynccl: bool # a hint of whether to use PyNccl - use_custom_allreduce: bool # a hint of whether to use CustomAllreduce - # communicators are only created for world size > 1 - pynccl_comm: Optional[Any] # PyNccl communicator - ca_comm: Optional[Any] # Custom allreduce communicator - mq_broadcaster: Optional[Any] # shared memory broadcaster - - def __init__( - self, - group_ranks: List[List[int]], - local_rank: int, - torch_distributed_backend: Union[str, Backend], - use_pynccl: bool, - use_custom_allreduce: bool, - use_tpu_communicator: bool, - use_message_queue_broadcaster: bool = False, - ): - - self.rank = torch.distributed.get_rank() - self.local_rank = local_rank - self.device_group = None - self.cpu_group = None - - for ranks in group_ranks: - device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend) - # a group with `gloo` backend, to allow direct coordination between - # processes through the CPU. 
- cpu_group = torch.distributed.new_group(ranks, backend="gloo") - if self.rank in ranks: - self.ranks = ranks - self.world_size = len(ranks) - self.rank_in_group = ranks.index(self.rank) - self.device_group = device_group - self.cpu_group = cpu_group - - assert self.cpu_group is not None - assert self.device_group is not None - - if torch.cuda.is_available(): - self.device = torch.device(f"cuda:{local_rank}") - else: - self.device = torch.device("cpu") - - self.use_pynccl = use_pynccl - self.use_custom_allreduce = use_custom_allreduce - self.use_tpu_communicator = use_tpu_communicator - - # lazy import to avoid documentation build error - from vllm.distributed.device_communicators.custom_all_reduce import ( - CustomAllreduce) - from vllm.distributed.device_communicators.pynccl import ( - PyNcclCommunicator) - - self.pynccl_comm: Optional[PyNcclCommunicator] - if use_pynccl and self.world_size > 1: - self.pynccl_comm = PyNcclCommunicator( - group=self.cpu_group, - device=self.device, - ) - else: - self.pynccl_comm = None - - self.ca_comm: Optional[CustomAllreduce] - if use_custom_allreduce and self.world_size > 1: - # Initialize a custom fast all-reduce implementation. - self.ca_comm = CustomAllreduce( - group=self.cpu_group, - device=self.device, - ) - else: - self.ca_comm = None - - from vllm.distributed.device_communicators.tpu_communicator import ( - TpuCommunicator) - self.tpu_communicator: Optional[TpuCommunicator] - if use_tpu_communicator and self.world_size > 1: - self.tpu_communicator = TpuCommunicator(group=self.cpu_group) - - from vllm.distributed.device_communicators.shm_broadcast import ( - MessageQueue) - self.mq_broadcaster: Optional[MessageQueue] = None - if use_message_queue_broadcaster and self.world_size > 1: - self.mq_broadcaster = MessageQueue.create_from_process_group( - self.cpu_group, 1 << 22, 6) - - @property - def first_rank(self): - """Return the global rank of the first process in the group""" - return self.ranks[0] - - @property - def last_rank(self): - """Return the global rank of the last process in the group""" - return self.ranks[-1] - - @property - def is_first_rank(self): - """Return whether the caller is the first process in the group""" - return self.rank == self.first_rank - - @property - def is_last_rank(self): - """Return whether the caller is the last process in the group""" - return self.rank == self.last_rank - - @property - def next_rank(self): - """Return the global rank of the process that follows the caller""" - rank_in_group = self.rank_in_group - world_size = self.world_size - return self.ranks[(rank_in_group + 1) % world_size] - - @property - def prev_rank(self): - """Return the global rank of the process that precedes the caller""" - rank_in_group = self.rank_in_group - world_size = self.world_size - return self.ranks[(rank_in_group - 1) % world_size] - - @contextmanager - def graph_capture( - self, graph_capture_context: Optional[GraphCaptureContext] = None): - if graph_capture_context is None: - stream = torch.cuda.Stream() - graph_capture_context = GraphCaptureContext(stream) - else: - stream = graph_capture_context.stream - - ca_comm = self.ca_comm - maybe_ca_context = nullcontext( - ) if ca_comm is None else ca_comm.capture() - - # ensure all initialization operations complete before attempting to - # capture the graph on another stream - curr_stream = torch.cuda.current_stream() - if curr_stream != stream: - stream.wait_stream(curr_stream) - - with torch.cuda.stream(stream), maybe_ca_context: - # In graph mode, we have to be very 
careful about the collective - # operations. The current status is: - # allreduce \ Mode | Eager | Graph | - # -------------------------------------------- - # custom allreduce | enabled | enabled | - # PyNccl | disabled| enabled | - # torch.distributed | enabled | disabled| - # - # Note that custom allreduce will have a runtime check, if the - # tensor size is too large, it will fallback to the next - # available option. - # In summary: When using CUDA graph, we use - # either custom all-reduce kernel or pynccl. When not using - # CUDA graph, we use either custom all-reduce kernel or - # PyTorch NCCL. We always prioritize using custom all-reduce - # kernel but fall back to PyTorch or pynccl if it is - # disabled or not supported. - pynccl_comm = self.pynccl_comm - maybe_pynccl_context: Any - if not pynccl_comm: - maybe_pynccl_context = nullcontext() - else: - maybe_pynccl_context = pynccl_comm.change_state( - enable=True, stream=torch.cuda.current_stream()) - with maybe_pynccl_context: - yield graph_capture_context - - def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: - """ - NOTE: This operation will be applied in-place or out-of-place. - Always assume this function modifies its input, but use the return - value as the output. - """ - ca_comm = self.ca_comm - - # Bypass the function if we are using only 1 GPU. - if self.world_size == 1: - return input_ - - # For TPUs, use TPU communicator. - tpu_comm = self.tpu_communicator - if tpu_comm is not None and not tpu_comm.disabled: - return tpu_comm.all_reduce(input_) - - if ca_comm is not None: - out = ca_comm.custom_all_reduce(input_) - if out is not None: - return out - pynccl_comm = self.pynccl_comm - if (pynccl_comm is not None and not pynccl_comm.disabled): - pynccl_comm.all_reduce(input_) - elif input_.is_cpu: - import intel_extension_for_pytorch as ipex - ipex.distributed.all_reduce(input_, group=self.device_group) - else: - torch.distributed.all_reduce(input_, group=self.device_group) - return input_ - - def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: - world_size = self.world_size - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") - - # For TPUs, use TPU communicator. - tpu_comm = self.tpu_communicator - if tpu_comm is not None and not tpu_comm.disabled: - return tpu_comm.all_gather(input_, dim) - - if dim < 0: - # Convert negative dim to positive. - dim += input_.dim() - input_size = input_.size() - # Allocate output tensor. - output_tensor = torch.empty((world_size, ) + input_size, - dtype=input_.dtype, - device=input_.device) - # All-gather. - torch.distributed.all_gather_into_tensor(output_tensor, - input_, - group=self.device_group) - # Reshape - output_tensor = output_tensor.movedim(0, dim) - output_tensor = output_tensor.reshape(input_size[:dim] + - (world_size * - input_size[dim], ) + - input_size[dim + 1:]) - return output_tensor - - def gather(self, - input_: torch.Tensor, - dst: int = 0, - dim: int = -1) -> torch.Tensor: - """ - NOTE: We assume that the input tensor is on the same device across - all the ranks. - NOTE: `dst` is the local rank of the destination rank. - """ - world_size = self.world_size - # Bypass the function if we are using only 1 GPU. 
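Note: the `all_gather` implementation shown a few lines above gathers every rank's chunk into a leading `world_size` dimension and then folds that dimension back into the requested `dim`. The standalone sketch below re-enacts only that reshape on a single process, with `torch.stack` standing in for the collective; it is illustrative and touches no process group.

```python
import torch


def fold_gathered(output_tensor: torch.Tensor, dim: int) -> torch.Tensor:
    """Fold a (world_size, *input_shape) buffer back along `dim`."""
    world_size = output_tensor.shape[0]
    input_size = output_tensor.shape[1:]
    # Move the gathered world_size axis next to `dim`, then merge the two.
    folded = output_tensor.movedim(0, dim)
    return folded.reshape(input_size[:dim] +
                          (world_size * input_size[dim], ) +
                          input_size[dim + 1:])


if __name__ == "__main__":
    world_size = 4
    dim = 1  # gather along the last dimension of a (2, 3) input
    per_rank = [torch.full((2, 3), float(r)) for r in range(world_size)]

    # torch.stack stands in for torch.distributed.all_gather_into_tensor.
    gathered = torch.stack(per_rank)          # shape: (4, 2, 3)
    result = fold_gathered(gathered, dim)     # shape: (2, 12)

    assert torch.equal(result, torch.cat(per_rank, dim=dim))
    print(result.shape)
```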
- if world_size == 1: - return input_ - assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") - if dim < 0: - # Convert negative dim to positive. - dim += input_.dim() - # Allocate output tensor. - if self.rank_in_group == dst: - gather_list = [torch.empty_like(input_) for _ in range(world_size)] - else: - gather_list = None - # Gather. - torch.distributed.gather(input_, - gather_list, - dst=self.ranks[dst], - group=self.device_group) - if self.rank_in_group == dst: - output_tensor = torch.cat(gather_list, dim=dim) - else: - output_tensor = None - return output_tensor - - def broadcast(self, input_: torch.Tensor, src: int = 0): - """Broadcast the input tensor. - NOTE: `src` is the local rank of the source rank. - """ - assert src < self.world_size, f"Invalid src rank ({src})" - - # Bypass the function if we are using only 1 GPU. - if self.world_size == 1: - return input_ - # Broadcast. - torch.distributed.broadcast(input_, - src=self.ranks[src], - group=self.device_group) - return input_ - - def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): - """Broadcast the input object. - NOTE: `src` is the local rank of the source rank. - """ - assert src < self.world_size, f"Invalid src rank ({src})" - - # Bypass the function if we are using only 1 GPU. - if self.world_size == 1: - return obj - if self.mq_broadcaster is not None: - assert src == 0, "Message queue broadcaster only supports src=0" - return self.mq_broadcaster.broadcast_object(obj) - if self.rank_in_group == src: - torch.distributed.broadcast_object_list([obj], - src=self.ranks[src], - group=self.cpu_group) - return obj - else: - recv = [None] - torch.distributed.broadcast_object_list(recv, - src=self.ranks[src], - group=self.cpu_group) - return recv[0] - - def broadcast_object_list(self, - obj_list: List[Any], - src: int = 0, - group: Optional[ProcessGroup] = None): - """Broadcast the input object list. - NOTE: `src` is the local rank of the source rank. - """ - assert src < self.world_size, f"Invalid src rank ({src})" - - # Bypass the function if we are using only 1 GPU. - if self.world_size == 1: - return obj_list - # Broadcast. - torch.distributed.broadcast_object_list(obj_list, - src=self.ranks[src], - group=self.device_group) - return obj_list - - def send_object(self, obj: Any, dst: int) -> None: - """Send the input object list to the destination rank.""" - """NOTE: `dst` is the local rank of the destination rank.""" - - assert dst < self.world_size, f"Invalid dst rank ({dst})" - - assert dst != self.rank_in_group, ( - "Invalid destination rank. Destination rank is the same " - "as the current rank.") - - # Serialize object to tensor and get the size as well - object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) - - size_tensor = torch.tensor([object_tensor.numel()], - dtype=torch.long, - device="cpu") - - # Send object size - - torch.distributed.send(size_tensor, - dst=self.ranks[dst], - group=self.cpu_group) - - # Send object - torch.distributed.send(object_tensor, - dst=self.ranks[dst], - group=self.cpu_group) - - return None - - def recv_object(self, src: int) -> Any: - """Receive the input object list from the source rank.""" - """NOTE: `src` is the local rank of the source rank.""" - - assert src < self.world_size, f"Invalid src rank ({src})" - - assert src != self.rank_in_group, ( - "Invalid source rank. Source rank is the same as the current rank." 
- ) - - size_tensor = torch.empty(1, dtype=torch.long, device="cpu") - - # Receive object size - rank_size = torch.distributed.recv(size_tensor, - src=self.ranks[src], - group=self.cpu_group) - - # Tensor to receive serialized objects into. - object_tensor = torch.empty( # type: ignore[call-overload] - size_tensor.item(), # type: ignore[arg-type] - dtype=torch.uint8, - device="cpu") - - rank_object = torch.distributed.recv(object_tensor, - src=self.ranks[src], - group=self.cpu_group) - - assert rank_object == rank_size, ( - "Received object sender rank does not match the size sender rank.") - - obj = pickle.loads(object_tensor.numpy().tobytes()) - - return obj - - def broadcast_tensor_dict( - self, - tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, - src: int = 0, - group: Optional[ProcessGroup] = None, - metadata_group: Optional[ProcessGroup] = None - ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: - """Broadcast the input tensor dictionary. - NOTE: `src` is the local rank of the source rank. - """ - # Bypass the function if we are using only 1 GPU. - if (not torch.distributed.is_initialized() or self.world_size == 1): - return tensor_dict - - group = self.device_group - metadata_group = self.cpu_group - assert src < self.world_size, f"Invalid src rank ({src})" - - rank_in_group = self.rank_in_group - if rank_in_group == src: - metadata_list: List[Tuple[Any, Any]] = [] - assert isinstance( - tensor_dict, - dict), (f"Expecting a dictionary, got {type(tensor_dict)}") - metadata_list, tensor_list = _split_tensor_dict(tensor_dict) - # `metadata_list` lives in CPU memory. - # `broadcast_object_list` has serialization & deserialization, - # all happening on CPU. Therefore, we can use the CPU group. - self.broadcast_object(metadata_list, src=src) - async_handles = [] - for tensor in tensor_list: - if tensor.numel() == 0: - # Skip broadcasting empty tensors. - continue - if tensor.is_cpu: - # use metadata_group for CPU tensors - handle = torch.distributed.broadcast(tensor, - src=self.ranks[src], - group=metadata_group, - async_op=True) - else: - # use group for GPU tensors - handle = torch.distributed.broadcast(tensor, - src=self.ranks[src], - group=group, - async_op=True) - async_handles.append(handle) - for async_handle in async_handles: - async_handle.wait() - - else: - metadata_list = self.broadcast_object(None, src=src) - tensor_dict = {} - async_handles = [] - for key, value in metadata_list: - if isinstance(value, TensorMetadata): - tensor = torch.empty(value.size, - dtype=value.dtype, - device=value.device) - if tensor.numel() == 0: - # Skip broadcasting empty tensors. - tensor_dict[key] = tensor - continue - if tensor.is_cpu: - # use metadata_group for CPU tensors - handle = torch.distributed.broadcast( - tensor, - src=self.ranks[src], - group=metadata_group, - async_op=True) - else: - # use group for GPU tensors - handle = torch.distributed.broadcast( - tensor, - src=self.ranks[src], - group=group, - async_op=True) - async_handles.append(handle) - tensor_dict[key] = tensor - else: - tensor_dict[key] = value - for async_handle in async_handles: - async_handle.wait() - return tensor_dict - - def send_tensor_dict( - self, - tensor_dict: Dict[str, Union[torch.Tensor, Any]], - dst: Optional[int] = None, - all_gather_group: Optional["GroupCoordinator"] = None, - ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: - """Send the input tensor dictionary. - NOTE: `dst` is the local rank of the source rank. - """ - # Bypass the function if we are using only 1 GPU. 
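Note: the `send_object`/`recv_object` pair above frames every CPU-side transfer as a one-element int64 size tensor followed by a pickled uint8 payload. The snippet below replays that framing locally (the "wire" is just two in-memory tensors), so it runs without initializing torch.distributed; the helper names are illustrative.

```python
import pickle
from typing import Any, Tuple

import torch


def encode_object(obj: Any) -> Tuple[torch.Tensor, torch.Tensor]:
    """Produce the (size, payload) pair that send_object puts on the wire."""
    payload = torch.frombuffer(bytearray(pickle.dumps(obj)), dtype=torch.uint8)
    size = torch.tensor([payload.numel()], dtype=torch.long)
    return size, payload


def decode_object(size: torch.Tensor, payload: torch.Tensor) -> Any:
    """Mirror of recv_object: check the size prefix, then unpickle."""
    assert payload.numel() == int(size.item()), "size prefix mismatch"
    return pickle.loads(payload.numpy().tobytes())


if __name__ == "__main__":
    metadata = [("input_tokens", ("cuda", "int64", (8, ))), ("seq_len", 8)]
    size, payload = encode_object(metadata)
    # In the real code these two tensors travel through torch.distributed.send
    # and torch.distributed.recv on the CPU (gloo) group, in that order.
    assert decode_object(size, payload) == metadata
    print(size.item(), "bytes of pickled metadata round-tripped")
```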
- if not torch.distributed.is_initialized() or self.world_size == 1: - return tensor_dict - - all_gather_size = (1 if all_gather_group is None else - all_gather_group.world_size) - all_gather_rank = (0 if all_gather_group is None else - all_gather_group.rank_in_group) - - group = self.device_group - metadata_group = self.cpu_group - - if dst is None: - dst = (self.rank_in_group + 1) % self.world_size - assert dst < self.world_size, f"Invalid dst rank ({dst})" - - metadata_list: List[Tuple[Any, Any]] = [] - assert isinstance( - tensor_dict, - dict), f"Expecting a dictionary, got {type(tensor_dict)}" - metadata_list, tensor_list = _split_tensor_dict(tensor_dict) - # `metadata_list` lives in CPU memory. - # `send_object_list` has serialization & deserialization, - # all happening on CPU. Therefore, we can use the CPU group. - self.send_object(metadata_list, dst=dst) - for tensor in tensor_list: - if tensor.numel() == 0: - # Skip sending empty tensors. - continue - - # send-allgather: send only a slice, then do allgather. - if (all_gather_group is not None - and tensor.numel() % all_gather_size == 0): - tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] - - if tensor.is_cpu: - # use metadata_group for CPU tensors - torch.distributed.send(tensor, - dst=self.ranks[dst], - group=metadata_group) - else: - # use group for GPU tensors - torch.distributed.send(tensor, - dst=self.ranks[dst], - group=group) - return None - - def recv_tensor_dict( - self, - src: Optional[int] = None, - all_gather_group: Optional["GroupCoordinator"] = None, - ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: - """Recv the input tensor dictionary. - NOTE: `src` is the local rank of the source rank. - """ - # Bypass the function if we are using only 1 GPU. - if not torch.distributed.is_initialized() or self.world_size == 1: - return None - - all_gather_size = (1 if all_gather_group is None else - all_gather_group.world_size) - all_gather_rank = (0 if all_gather_group is None else - all_gather_group.rank_in_group) - - group = self.device_group - metadata_group = self.cpu_group - - if src is None: - src = (self.rank_in_group - 1) % self.world_size - assert src < self.world_size, f"Invalid src rank ({src})" - - recv_metadata_list = self.recv_object(src=src) - tensor_dict: Dict[str, Any] = {} - for key, value in recv_metadata_list: - if isinstance(value, TensorMetadata): - tensor = torch.empty(value.size, - dtype=value.dtype, - device=value.device) - if tensor.numel() == 0: - # Skip broadcasting empty tensors. - tensor_dict[key] = tensor - continue - - # send-allgather: send only a slice, then do allgather. - use_all_gather = (all_gather_group is not None - and tensor.numel() % all_gather_size == 0) - - if use_all_gather: - orig_shape = tensor.shape - tensor = tensor.reshape(all_gather_size, - -1)[all_gather_rank] - - if tensor.is_cpu: - # use metadata_group for CPU tensors - torch.distributed.recv(tensor, - src=self.ranks[src], - group=metadata_group) - else: - # use group for GPU tensors - torch.distributed.recv(tensor, - src=self.ranks[src], - group=group) - if use_all_gather: - # do the allgather - tensor = all_gather_group.all_gather( # type: ignore - tensor, dim=0) - tensor = tensor.reshape(orig_shape) - - tensor_dict[key] = tensor - else: - tensor_dict[key] = value - return tensor_dict - - def barrier(self): - """Barrier synchronization among the group. - NOTE: don't use `device_group` here! `barrier` in NCCL is - terrible because it is internally a broadcast operation with - secretly created GPU tensors. 
It is easy to mess up the current - device. Use the CPU group instead. - """ - torch.distributed.barrier(group=self.cpu_group) - - def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: - """Sends a tensor to the destination rank in a non-blocking way""" - """NOTE: `dst` is the local rank of the destination rank.""" - if dst is None: - dst = (self.rank_in_group + 1) % self.world_size - - pynccl_comm = self.pynccl_comm - if pynccl_comm is not None and not pynccl_comm.disabled: - pynccl_comm.send(tensor, dst) - else: - torch.distributed.send(tensor, self.ranks[dst], self.device_group) - - def recv(self, - size: torch.Size, - dtype: torch.dtype, - src: Optional[int] = None) -> torch.Tensor: - """Receives a tensor from the src rank.""" - """NOTE: `src` is the local rank of the destination rank.""" - if src is None: - src = (self.rank_in_group - 1) % self.world_size - - tensor = torch.empty(size, dtype=dtype, device=self.device) - pynccl_comm = self.pynccl_comm - if pynccl_comm is not None and not pynccl_comm.disabled: - pynccl_comm.recv(tensor, src) - else: - torch.distributed.recv(tensor, self.ranks[src], self.device_group) - return tensor - - def destroy(self): - if self.device_group is not None: - torch.distributed.destroy_process_group(self.device_group) - self.device_group = None - if self.cpu_group is not None: - torch.distributed.destroy_process_group(self.cpu_group) - self.cpu_group = None - if self.pynccl_comm is not None: - self.pynccl_comm = None - if self.ca_comm is not None: - self.ca_comm = None - if self.mq_broadcaster is not None: - self.mq_broadcaster = None _WORLD: Optional[GroupCoordinator] = None @@ -749,10 +52,10 @@ def get_world_group() -> GroupCoordinator: return _WORLD -def init_world_group(ranks: List[int], local_rank: int, +def init_world_group(ranks: List[List[int]], local_rank: int, backend: str) -> GroupCoordinator: return GroupCoordinator( - group_ranks=[ranks], + group_ranks=ranks, local_rank=local_rank, torch_distributed_backend=backend, use_pynccl=False, @@ -804,6 +107,14 @@ def get_pp_group() -> GroupCoordinator: # kept for backward compatibility get_pipeline_model_parallel_group = get_pp_group +_DISAGG: Optional[dist_kv.DistributedKVCoordinator] = None + + +def get_disagg_group() -> dist_kv.DistributedKVCoordinator: + assert _DISAGG is not None, ( + "disaggregated prefill parallel group is not initialized") + return _DISAGG + @contextmanager def graph_capture(): @@ -835,6 +146,33 @@ def set_custom_all_reduce(enable: bool): _ENABLE_CUSTOM_ALL_REDUCE = enable +def include_decoding_groups_if_disagg_enabled( + groups: List[List[int]], + world_size: int, +) -> List[List[int]]: + """ + Include the distributed group for decode + Only for disaggregated prefill + + Example: + Original group: [ [0,1], [2,3] ], world_size = 4 + Extended: [ [0,1], [2,3], [4,5], [6,7] ] + Arguments: + groups: original distributed group + world_size: the vLLM world size, which is half of torch.distributed.get_world_size() + """ + + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: + new_groups = [] + for group in groups: + new_groups.append([rank for rank in group]) + for group in groups: + new_groups.append([rank + world_size for rank in group]) + return new_groups + else: + return groups + + def init_distributed_environment( world_size: int = -1, rank: int = -1, @@ -851,11 +189,29 @@ def init_distributed_environment( "distributed_init_method must be provided when initializing " "distributed environment") # this backend is used for WORLD + maybe_disagg_world_size = world_size + 
maybe_disagg_rank = rank + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: + maybe_disagg_world_size = world_size * 2 + logger.debug("Disaggregated prefill enabled.") + if dist_kv.IS_KV_PREFILL_INSTANCE: + # for prefill, the ranks are [0, world_size) + maybe_disagg_rank = rank + else: + # this is decode instance. + # offset global rank by tp * pp (which is world_size) + maybe_disagg_rank = rank + world_size + + logger.debug( + f"Before: world size {maybe_disagg_world_size}, rank {maybe_disagg_rank}" + ) + torch.distributed.init_process_group( backend=backend, init_method=distributed_init_method, - world_size=world_size, - rank=rank) + world_size=maybe_disagg_world_size, + rank=maybe_disagg_rank) + logger.debug("torch.distributed initialized") # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 @@ -866,10 +222,19 @@ def init_distributed_environment( local_rank = envs.LOCAL_RANK else: local_rank = rank + global _WORLD if _WORLD is None: - ranks = list(range(torch.distributed.get_world_size())) + ranks = [[i for i in range(world_size)]] + # offset the distributed group + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: + ranks = include_decoding_groups_if_disagg_enabled( + ranks, world_size) + _WORLD = init_world_group(ranks, local_rank, backend) + logger.debug("_WORLD initialized for rank %d", + torch.distributed.get_rank()) + time.sleep(5) else: assert _WORLD.world_size == torch.distributed.get_world_size(), ( "world group already initialized with a different world size") @@ -901,15 +266,37 @@ def initialize_model_parallel( are on the same DGX box. For example if we are using 2 DGX-1 boxes with a total of 16 GPUs, rank 0 to 7 belong to the first box and ranks 8 to 15 belong to the second box. + + + Disaggregated prefill will also initialize its process group using this function. + Changes: + - vLLM world size: unchanged (tp * pp) + - torch.distributed.get_world_size(): + - 2 * tp * pp + - Why: torch.distributed package sees 2 vLLM instances (prefill and decode) + - Global rank: + - [0, tp * pp) for prefill + - [tp * pp, 2 * tp * pp) for decode + - Parallel groups + - Extend _WORLD, _TP and _PP using `include_decoding_groups_if_disagg_enabled` + - Add a new parallel group `_DISAGG` for disaggregated prefill + - [ [0, tp * pp], [1, tp * pp + 1], .. ] + - Local rank: unchanged """ + # Get world size and rank. Ensure some consistencies. assert torch.distributed.is_initialized() world_size: int = torch.distributed.get_world_size() backend = backend or torch.distributed.get_backend( get_world_group().device_group) - - if (world_size != - tensor_model_parallel_size * pipeline_model_parallel_size): + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: + # Disaggregated prefill enabled + # The world_size for this vLLM instance is tp * pp, but torch.distributed contains 2 vLLM instances, its world size is 2 * tp * pp + # Adjust the world_size to match. 
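Note: as a concrete illustration of the rank bookkeeping described in the docstring above (the halved `world_size` is computed just below), the following standalone sketch reproduces the group layout for tp=2, pp=2: every parallel group gains a decode twin offset by the vLLM world size, and the disaggregated-prefill group pairs each prefill rank with its decode counterpart. This is pure bookkeeping with no torch.distributed calls; the function names are illustrative.

```python
from typing import List


def extend_groups_for_decode(groups: List[List[int]],
                             world_size: int) -> List[List[int]]:
    """Append a decode copy of every prefill group, offset by world_size."""
    return [list(g) for g in groups] + [[r + world_size for r in g]
                                        for g in groups]


def disagg_pairs(world_size: int) -> List[List[int]]:
    """One two-rank group per (prefill rank, matching decode rank) pair."""
    return [[i, i + world_size] for i in range(world_size)]


if __name__ == "__main__":
    tp, pp = 2, 2
    world_size = tp * pp          # vLLM world size of one instance
    tp_groups = [[0, 1], [2, 3]]  # tensor-parallel groups on the prefill side

    print(extend_groups_for_decode(tp_groups, world_size))
    # [[0, 1], [2, 3], [4, 5], [6, 7]]
    print(disagg_pairs(world_size))
    # [[0, 4], [1, 5], [2, 6], [3, 7]]
```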
+ world_size = world_size // 2 + + if (world_size + != tensor_model_parallel_size * pipeline_model_parallel_size): raise RuntimeError( f"world_size ({world_size}) is not equal to " f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " @@ -926,12 +313,14 @@ def initialize_model_parallel( range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) group_ranks.append(ranks) - + group_ranks = include_decoding_groups_if_disagg_enabled( + group_ranks, world_size) # message queue broadcaster is only used in tensor model parallel group _TP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_message_queue_broadcaster=True) + logger.debug("_TP initialized for rank %d", torch.distributed.get_rank()) # Build the pipeline model-parallel groups. num_pipeline_model_parallel_groups: int = (world_size // @@ -943,11 +332,39 @@ def initialize_model_parallel( for i in range(num_pipeline_model_parallel_groups): ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) group_ranks.append(ranks) + group_ranks = include_decoding_groups_if_disagg_enabled( + group_ranks, world_size) # pipeline parallel does not need custom allreduce _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) + logger.debug("_PP initialized for rank %d", torch.distributed.get_rank()) + + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: + global _DISAGG + logger.debug("Disaggregated prefill enabled, create _DISAGG group") + group_ranks = [] + for i in range(world_size): + # prefill local rank: i + # decode global rank: i + world_size + group_ranks.append([i, i + world_size]) + logger.debug("Distributed group is %s", str(group_ranks)) + _DISAGG = dist_kv.DistributedKVCoordinator( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + torch_distributed_backend=backend, + ) + # follow by a warmup, to warmup nccl + # necessary, as NCCL may not be warmed up when tp and pp are both 1. + temp_tensor = torch.tensor([1.]).to(_DISAGG.device) + if dist_kv.IS_KV_PREFILL_INSTANCE: + _DISAGG.send(temp_tensor) + else: + recv_tensor = _DISAGG.recv(temp_tensor.shape, temp_tensor.dtype) + assert torch.allclose(temp_tensor, recv_tensor) + logger.debug("_DISAGG initialized for rank %d", + torch.distributed.get_rank()) def ensure_model_parallel_initialized( @@ -990,7 +407,7 @@ def model_parallel_is_initialized(): def patch_tensor_parallel_group(tp_group: GroupCoordinator): """Patch the tp group temporarily until this function ends. - This method is for draft workers of speculative decoding to run draft model + This method is for draft workers of speculative decode to run draft model with different tp degree from that of target model workers. Args: @@ -1033,6 +450,11 @@ def destroy_model_parallel(): _PP.destroy() _PP = None + global _DISAGG + if _DISAGG: + _DISAGG.destroy() + _DISAGG = None + def destroy_distributed_environment(): global _WORLD diff --git a/vllm/envs.py b/vllm/envs.py index 089a39d8e029..07a7b647f6bc 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -145,7 +145,7 @@ def get_default_config_root(): # used when the frontend api server is running in multi-processing mode, # to communicate with the backend engine process over ZMQ. 'VLLM_RPC_PORT': - lambda: int(os.getenv('VLLM_PORT', '5570')), + lambda: int(os.getenv('VLLM_RPC_PORT', '5570')), # If true, will load models from ModelScope instead of Hugging Face Hub. 
# note that the value is true or false, not numbers @@ -329,6 +329,11 @@ def get_default_config_root(): "VLLM_FUSED_MOE_CHUNK_SIZE": lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "65536")), + # Specify the role of current vllm instance + # Value can be "prefill", "decode". + "VLLM_DISAGG_PREFILL_ROLE": + lambda: os.getenv("VLLM_DISAGG_PREFILL_ROLE", None), + # If set, vllm will skip the deprecation warnings. "VLLM_NO_DEPRECATION_WARNING": lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))), diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 3e77af0e2032..300e9a33eba5 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union +import vllm.distributed.distributed_kv as dist_kv from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -43,7 +44,7 @@ def _get_worker_kwargs( """Return worker init args for a given rank.""" if distributed_init_method is None: distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) + get_ip(), get_open_port(force=dist_kv.IS_DISTRIBUTED_KV_INSTANCE)) return dict( model_config=self.model_config, parallel_config=self.parallel_config, diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 08a35a074b37..ba222f8b5e40 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -8,6 +8,7 @@ import torch +import vllm.distributed.distributed_kv as dist_kv from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.gpu_executor import create_worker @@ -82,7 +83,7 @@ def _init_executor(self) -> None: # Since it only works for single node, we can use the loopback address # 127.0.0.1 for communication. distributed_init_method = get_distributed_init_method( - "127.0.0.1", get_open_port()) + "127.0.0.1", get_open_port(force=dist_kv.IS_DISTRIBUTED_KV_INSTANCE)) self.workers: List[ProcessWorkerWrapper] = [] # This is the list of workers that are rank 0 of each TP group EXCEPT diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 4a6825c01fcf..17f4d3633886 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import vllm.envs as envs +import vllm.distributed.distributed_kv as dist_kv from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.ray_utils import RayWorkerWrapper, ray @@ -226,8 +227,11 @@ def sort_by_driver_then_worker_ip(worker): # solves this issue, as it always works for communication inside # the node. driver_ip = "127.0.0.1" + # force vLLM to use the port specified by envs.VLLM_PORT + # this port will be binded by prefill instance + # but the decode instance must use that port to init torch.distributed distributed_init_method = get_distributed_init_method( - driver_ip, get_open_port()) + driver_ip, get_open_port(force=dist_kv.IS_DISTRIBUTED_KV_INSTANCE)) # Initialize the actual workers inside worker wrapper. 
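Note: the executor changes above all route through `get_open_port(force=...)`, whose new behaviour appears in the `vllm/utils.py` hunk below: when an instance is part of a disaggregated prefill pair, prefill and decode must assemble the same `tcp://ip:port` init method, so the configured port is returned even though the prefill instance has already bound it. A minimal standalone sketch of that behaviour follows, using a hypothetical `DISAGG_PORT` variable in place of the real environment variable.

```python
import os
import socket
from typing import Optional


def get_open_port(port: Optional[int] = None, force: bool = False) -> int:
    if port is None:
        env = os.getenv("DISAGG_PORT")  # hypothetical stand-in for the real env var
        port = int(env) if env is not None else None
    if force and port is not None:
        # Do not probe: the peer instance may already hold the socket, and
        # torch.distributed only needs both sides to agree on the address.
        return port
    candidate = port if port is not None else 29500
    for _ in range(100):
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind(("", candidate))
                return candidate
        except OSError:
            candidate += 1
    raise RuntimeError("could not find a free port")


if __name__ == "__main__":
    os.environ["DISAGG_PORT"] = "12355"
    print(get_open_port(force=True))  # 12355, returned without binding
    print(get_open_port())            # first free port at or after 12355
```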
init_worker_all_kwargs = [ diff --git a/vllm/utils.py b/vllm/utils.py index 51bd72977a22..fa5452335264 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -388,13 +388,20 @@ def get_distributed_init_method(ip: str, port: int) -> str: return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}" -def get_open_port(port: Optional[int] = None) -> int: +def get_open_port(port: Optional[int] = None, force: bool = False) -> int: if port is None: # Default behavior here is to return a port for multi-gpu communication port = envs.VLLM_PORT if port is not None: + if force and port is not None: + # force vLLM to use envs.VLLM_PORT for torch.distributed init + # This is because this port will binded by prefill instance + # But both prefill and decode instance need to use this port to + # initialize torch.distributed + return port while True: try: + logger.error('Trying port %d', port) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("", port)) return port diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f9c26e0c318b..4d8105bde2c4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -12,6 +12,8 @@ import torch.distributed import torch.nn as nn +import vllm.distributed.distributed_kv as dist_kv + try: from flashinfer import BatchDecodeWithPagedKVCacheWrapper from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper @@ -28,7 +30,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) -from vllm.distributed import get_pp_group +from vllm.distributed import get_tp_group, get_pp_group, get_disagg_group from vllm.distributed.parallel_state import graph_capture from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger @@ -60,6 +62,9 @@ _init_attn_metadata_from_tensor_dict, _init_sampling_metadata_from_tensor_dict) +import vllm.envs as envs +from vllm import _custom_ops as ops + if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -1360,20 +1365,60 @@ def execute_model( "finished_requests_ids": model_input.finished_requests_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, } if self.has_seqlen_agnostic else {} - hidden_or_intermediate_states = model_executable( + + # check if the current run is profiling + is_profile_run = (kv_caches is None) or (kv_caches[0] is None) + # check if the current run is prefill + is_prefill_run = prefill_meta is not None + + # for disaggregated prefilling: allow bypassing model execution + bypass_model_exec = False + + # Recv kv cache for disaggregated prefill + # Skip model execution if all required KV cache are received + if all([ + is_prefill_run, + dist_kv.IS_KV_DECODE_INSTANCE, + not is_profile_run]): + + hidden_or_intermediate_states, bypass = \ + dist_kv.recv_kv_caches_and_hidden_states( + model_executable, + model_input, + kv_caches, + ) + if bypass: + bypass_model_exec = True + + if not bypass_model_exec: + hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), + device=self.device), **seqlen_agnostic_kwargs) + + # Send KV cache for disaggregated prefill + if all([ + is_prefill_run, + dist_kv.IS_KV_PREFILL_INSTANCE, + not is_profile_run]): + + 
dist_kv.send_kv_caches_and_hidden_states( + model_executable, + model_input, + kv_caches, + hidden_or_intermediate_states, + ) + # Compute the logits in the last pipeline stage. if not get_pp_group().is_last_rank: return hidden_or_intermediate_states - + logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) @@ -1386,6 +1431,7 @@ def execute_model( sampling_metadata=model_input.sampling_metadata, ) + if self.return_hidden_states: # we only need to pass hidden states of most recent token assert model_input.sampling_metadata is not None
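
Note: to summarize the `execute_model` changes above: on a decode instance, a non-profiling prefill step first tries to pull KV caches and hidden states from the prefill instance and may skip the forward pass entirely; on a prefill instance, the forward pass always runs and its results are pushed to the decode side. The sketch below restates that branch structure as plain Python; roles are literal strings here, whereas the real code reads them from the environment.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Step:
    is_prefill_run: bool   # batch contains prefill tokens
    is_profile_run: bool   # memory-profiling run (kv_caches are dummies)
    role: Optional[str]    # "prefill", "decode", or None for a normal instance


def plan(step: Step) -> List[str]:
    """Ordered actions execute_model takes for one scheduling step."""
    actions: List[str] = []
    disagg_step = step.is_prefill_run and not step.is_profile_run

    if disagg_step and step.role == "decode":
        # Pull KV caches + hidden states; may allow skipping the forward pass.
        actions.append("recv_kv_caches_and_hidden_states")
    actions.append("forward_pass_unless_fully_received")
    if disagg_step and step.role == "prefill":
        # Push KV caches + hidden states to the paired decode instance.
        actions.append("send_kv_caches_and_hidden_states")
    actions.append("compute_logits_and_sample")
    return actions


if __name__ == "__main__":
    for role in (None, "prefill", "decode"):
        print(role, plan(Step(is_prefill_run=True, is_profile_run=False,
                              role=role)))
```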