Motivation.
Nowadays, many new applications, including multi-turn conversation, multi-modality, and multi-agent workloads, require a significant amount of KV cache. Such applications generally share a common prompt across multiple requests, and recomputing it each time can take significant prefill time. Suppose the length of the shared prompt is L tokens: every request that shares it repeats a prefill whose cost grows quadratically in L.
However, current vLLM rarely uses the secondary storage tier (CPU DRAM). It only swaps out running sequences when they fail to allocate blocks for newly generated tokens, which rarely happens in the workloads I tested (ShareGPT, UltraChat, LooGLE, ToolBench). When vLLM allocates blocks for prefill sequences, it simply discards the contents of the GPU blocks held in the evictor. Instead, it could evict those blocks to CPU DRAM so that a new request sharing the same prompt can load them back without recomputation.
I have some motivation numbers to support this RFC. The first is the cost of recomputation versus that of swap-in plus compute. Here I use tp=1, pp=1 on one A100; the model is togethercomputer/Llama-2-7B-32K-Instruct. The x-axis is the token length and the y-axis is the TTFT for a single sequence (no queueing). In the figure, "prefill w/o cache" means recomputation; "prefill w cache" means the KV cache is in HBM and we compute only the very next token; "prefill w cache + swapin" means the KV cache is in DRAM, so we first swap it from DRAM to HBM and then compute the very next token; "swap-in" is just the swap-in time of the previous case. I did not enable any overlapping between swap-in and execution.
As the figure shows, recomputation time increases quadratically, and its latency exceeds that of "prefill w cache + swapin" from 1024 tokens onwards.
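For intuition on why swap-in wins for long prompts: the KV-cache volume, and therefore the swap-in time, grows only linearly with prompt length, while recomputation grows quadratically as shown above. Below is a rough back-of-the-envelope estimate under my own assumptions (fp16 KV cache, ~20 GB/s effective host-to-device bandwidth); these are not figures from the experiment above.

```python
# Back-of-the-envelope KV-cache transfer estimate (illustrative assumptions only).
# Llama-2-7B: 32 layers, 32 KV heads, head_dim 128, fp16 K and V.
BYTES_PER_TOKEN = 32 * 32 * 128 * 2 * 2      # ~0.5 MiB of KV cache per token
H2D_BANDWIDTH = 20e9                          # assumed effective PCIe bandwidth, bytes/s

for num_tokens in (1024, 4096, 16384):
    kv_bytes = num_tokens * BYTES_PER_TOKEN
    print(f"{num_tokens:6d} tokens: {kv_bytes / 2**30:5.2f} GiB KV cache, "
          f"swap-in ~ {kv_bytes / H2D_BANDWIDTH * 1e3:6.1f} ms")
```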
I have also implemented a prototype of this RFC that evicts blocks from HBM to DRAM when prefix caching is enabled. In the next figure, I tested longdep_qa.json from the LooGLE benchmark [3]. The dataset's average prompt length is about 15k tokens. I used the same setup as the previous experiment and set the request rate to 1000. Here, the x-axis is the DRAM size, and the y-axis is either the mean TTFT or the mean TPOT.
There is still some variance in the benchmark, which I am looking into. Another issue is that performance appears to degrade once the DRAM size exceeds a certain threshold; I am still investigating the cause. The performance dip at small DRAM sizes is expected, since it reflects the pure cost of data movement.
References
[1] Gao, Bin, et al. "Cost-Efficient Large Language Model Serving for Multi-turn Conversations with CachedAttention." 2024 USENIX Annual Technical Conference (USENIX ATC 24), 2024.
[2] Sheng, Ying, et al. "FlexGen: High-Throughput Generative Inference of Large Language Models with a Single GPU." International Conference on Machine Learning (ICML 2023), 2023.
[3] Li, Jiaqi, et al. "LooGLE: Can Long-Context Language Models Understand Long Contexts?" arXiv preprint arXiv:2311.04939, 2023.
Proposed Change.
In my mind, this change takes three stages:
- Implement the basic functionality to enable eviction and promotion to and from the DRAM
- Enable the overlapping of KV cache transfer with the model execution (instead of finishing the transmission and then starting execution, transmission and execution are done layer-by-layer)
- Enable selective eviction to DRAM based on prompt length.
I have already built a prototype for stage 1, which does not require too many changes. The necessary changes include (see the sketch after this list):
- Change `block_manager_v[1,2].py` to allocate CPU blocks on eviction and to query the `cpu_allocator` for a block hash
- Change `scheduler.py` to add `blocks_to_swap_in` and `blocks_to_swap_out` for `_scheduler_prefill`
- Change the order of `swap_out` and `swap_in` in `worker.py`
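Here is a minimal sketch of the stage-1 flow. The class and method names below (`HierarchicalBlockPool`, `on_gpu_eviction`, `lookup_for_prefill`) are hypothetical and only illustrate the intent; the real changes live in the block manager and scheduler as listed above.

```python
# Sketch only: keep evicted GPU blocks alive in DRAM, keyed by content hash,
# so a later prefill that shares the prompt can swap them back in.
from typing import Dict, List, Tuple

class HierarchicalBlockPool:
    """Hypothetical helper tracking prefix-cache blocks that were spilled to DRAM."""

    def __init__(self, num_cpu_blocks: int) -> None:
        self.cpu_blocks: Dict[int, int] = {}          # content hash -> CPU block id
        self.free_cpu_ids: List[int] = list(range(num_cpu_blocks))

    def on_gpu_eviction(self, block_hash: int, gpu_block_id: int,
                        blocks_to_swap_out: List[Tuple[int, int]]) -> None:
        """Instead of discarding an evicted GPU block, schedule a GPU->CPU copy."""
        if block_hash in self.cpu_blocks or not self.free_cpu_ids:
            return                                     # already cached or DRAM full: drop as today
        cpu_block_id = self.free_cpu_ids.pop()
        self.cpu_blocks[block_hash] = cpu_block_id
        blocks_to_swap_out.append((gpu_block_id, cpu_block_id))

    def lookup_for_prefill(self, block_hash: int, gpu_block_id: int,
                           blocks_to_swap_in: List[Tuple[int, int]]) -> bool:
        """If a prompt block is cached in DRAM, schedule a CPU->GPU copy for it."""
        cpu_block_id = self.cpu_blocks.get(block_hash)
        if cpu_block_id is None:
            return False                               # miss: this block is recomputed in prefill
        blocks_to_swap_in.append((cpu_block_id, gpu_block_id))
        return True
```

The scheduler would then attach the `blocks_to_swap_in` / `blocks_to_swap_out` lists to the prefill scheduling output, and the worker must issue the swap-out copies before the swap-in copies, which is why the ordering in `worker.py` changes.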
The following are the things left for stage 1:
- Extensive testing. Right now I have only written tests for a very small case and checked that it works on a real benchmark (matching input and output).
- Add a free timestamp for CPU blocks, because `access_blocks_for_seqs` is never invoked for CPU blocks; this leads to some undesired behavior in the eviction policy.
- Decide what to do when a block is dropped from both CPU and GPU memory in the middle of a prompt that has cached tokens both before and after it (the oooooxxooooxx pattern). We probably need to change the kernel to support this case; I have seen it happen in testing. A possible fallback is sketched after this list.
- Support encoder-decoder models
- Support `block_manager_v2`
- Check what happens when chunked prefill is enabled
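For the oooooxxooooxx case above, one possible fallback (my suggestion, not part of the prototype) is to reuse only the longest contiguous cached prefix and recompute everything after the first missing block, which avoids kernel changes at the cost of some recomputation:

```python
def usable_cached_prefix(block_cached: list[bool]) -> int:
    """Return how many leading prompt blocks can be reused as-is.

    block_cached[i] is True if prompt block i is present in HBM or DRAM.
    For a pattern like o o o o o x x o o o o x x we stop at the first hole,
    so only the first 5 blocks are reused and the rest are recomputed.
    """
    n = 0
    for cached in block_cached:
        if not cached:
            break
        n += 1
    return n

# Example: the "oooooxxooooxx" pattern from the list above.
pattern = [c == "o" for c in "oooooxxooooxx"]
print(usable_cached_prefix(pattern))  # -> 5
```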
Some changes I deemed necessary for stage 2:
- Change the model to support layer-by-layer transmission (see the sketch below)
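A minimal sketch of the layer-by-layer overlap, assuming plain PyTorch streams and events; tensor names and sizes are placeholders, not vLLM's actual KV-cache layout:

```python
import torch

# Illustrative sizes only; in vLLM these would be the per-layer KV-cache blocks
# being swapped in for a sequence.
num_layers = 4
block_numel = 1 << 20
cpu_kv = [torch.zeros(block_numel, dtype=torch.float16, pin_memory=True)
          for _ in range(num_layers)]
gpu_kv = [torch.empty(block_numel, dtype=torch.float16, device="cuda")
          for _ in range(num_layers)]
x = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

copy_stream = torch.cuda.Stream()
copied = [torch.cuda.Event() for _ in range(num_layers)]

# Enqueue all per-layer H2D copies on a side stream.
with torch.cuda.stream(copy_stream):
    for i in range(num_layers):
        gpu_kv[i].copy_(cpu_kv[i], non_blocking=True)  # pinned memory -> async copy
        copied[i].record(copy_stream)

# On the compute stream, layer i waits only for its own copy, so the copy for
# layer i overlaps with the execution of layers 0..i-1.
for i in range(num_layers):
    torch.cuda.current_stream().wait_event(copied[i])
    x = x @ x  # stand-in for the model's per-layer forward pass using gpu_kv[i]
torch.cuda.synchronize()
```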
Some changes I deemed necessary for stage 3:
- Add a dry run to find the tipping point (see the sketch below).
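One possible shape for this policy (a hypothetical helper, not in the prototype): measure prefill latency at a few prompt lengths during a dry run at startup, then evict a sequence's KV cache to DRAM only if reloading it is estimated to be cheaper than recomputing it.

```python
# Hypothetical stage-3 policy sketch: compare estimated swap-in time against
# estimated recompute time for a given prompt length.
from typing import Dict

def should_evict_to_dram(num_tokens: int,
                         bytes_per_token: int,
                         h2d_bandwidth: float,
                         prefill_latency: Dict[int, float]) -> bool:
    swap_in = num_tokens * bytes_per_token / h2d_bandwidth
    # Scale the nearest dry-run measurement; quadratic scaling matches the
    # trend observed in the motivation figure.
    nearest = min(prefill_latency, key=lambda n: abs(n - num_tokens))
    recompute = prefill_latency[nearest] * (num_tokens / nearest) ** 2
    return swap_in < recompute

# Example with made-up dry-run measurements (seconds).
profile = {1024: 0.09, 4096: 0.9, 16384: 12.0}
print(should_evict_to_dram(8192, 512 * 1024, 20e9, profile))  # True if swap-in wins
```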
Feedback Period.
2-3 weeks
CC List.
@zhuohan123 @robertgshaw2-neuralmagic @zcnrex @sh1ng @SageMoore @comaniac @youkaichao @andoorve
Any Other Things.
No response