Conversation


@mengzhu28 mengzhu28 commented Feb 17, 2025

TL;DR

In V1, swap GPU KV cache blocks to CPU upon eviction and swap them back if there's a cache hit.

Swap Strategy

CPU → GPU swap-in happens naturally when requests hit the cache (unless we do prefetching).
GPU → CPU swap-out can be handled in two ways:

  1. Eagerly: Immediately after a request completes and its blocks are freed.
  2. Lazily: When evicting a GPU block while scheduling new requests.

This PR adopts (2) to minimize unnecessary swaps. The downside is that the swap-out latency may be exposed on the critical path.

Ideally, we would asynchronously offload X cache blocks at a regular cadence (e.g., hidden behind the main CUDA graph) while maintaining free GPU block headroom. This would add complexity and is left for future work.
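A minimal sketch of the lazy swap-out path adopted here; only the step_d2h_swap_map and cached_block_hash_to_cpu_block names come from this PR, and the helper function and its parameters are hypothetical:

```python
from typing import Dict, List

def lazy_swap_out(
    gpu_block_id: int,
    block_hash: int,
    free_cpu_block_ids: List[int],        # free slots in the CPU pool
    step_d2h_swap_map: Dict[int, int],    # accumulated GPU -> CPU copies
    cached_block_hash_to_cpu_block: Dict[int, int],
) -> None:
    """Option (2): offload a GPU block only when it is being evicted."""
    # Take a free CPU slot; a full pool would first trigger CPU-side
    # eviction (round-robin in this PR).
    cpu_block_id = free_cpu_block_ids.pop()
    # Record the pending device-to-host copy; the model runner later
    # issues one aggregated swap call before model execution.
    step_d2h_swap_map[gpu_block_id] = cpu_block_id
    # Make the block discoverable for future CPU cache hits.
    cached_block_hash_to_cpu_block[block_hash] = cpu_block_id
```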

Implementation

This PR builds on the excellent V1 KV cache manager and blends in with the existing interface.
Newly introduced metadata states:

  • cpu_block_pool and cached_block_hash_to_cpu_block mirror their GPU counterparts.

High-Level Flow:

  • The KV cache manager accumulates swap-in/out decisions during each scheduling cycle.
  • These swap decisions are then "flushed" to the scheduler output, allowing model runners to issue aggregated swap calls before model execution, minimizing dispatch overhead.

For simplicity, we avoid threading the scheduler output through multiple KV cache manager calls. Instead, swap-related data is accumulated in step_* fields (e.g., step_h2d_swap_map).
A new end_schedule_step callback resets them at the end of each scheduling iteration. (Open to alternative designs.)
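A minimal sketch of this accumulate-then-flush flow; the step_* field and end_schedule_step names come from this PR, while the SchedulerOutput fields and the flush method are hypothetical:

```python
from typing import Dict

class KVCacheManager:
    def __init__(self) -> None:
        # Accumulated over one scheduling step, then flushed.
        self.step_d2h_swap_map: Dict[int, int] = {}  # GPU block -> CPU block
        self.step_h2d_swap_map: Dict[int, int] = {}  # CPU block -> GPU block

    def flush_swaps(self, scheduler_output) -> None:
        # Hand the aggregated maps to the scheduler output so model
        # runners can issue batched swap calls before model execution.
        scheduler_output.d2h_swap_map = dict(self.step_d2h_swap_map)
        scheduler_output.h2d_swap_map = dict(self.step_h2d_swap_map)

    def end_schedule_step(self) -> None:
        # Reset per-step swap state at the end of each scheduling iteration.
        self.step_d2h_swap_map.clear()
        self.step_h2d_swap_map.clear()
```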

CPU Cache Eviction Policy

We currently adopt a simple round-robin strategy for CPU cache eviction; LRU will be added in a follow-up PR.
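A minimal sketch of round-robin eviction over a fixed-size CPU pool; the cpu_block_pool concept comes from this PR, but the class, cursor, and (omitted) in-use handling are simplified assumptions:

```python
from typing import Optional, Tuple

class CpuBlockPool:
    def __init__(self, num_cpu_blocks: int) -> None:
        self.blocks: list = [None] * num_cpu_blocks  # cached block metadata
        self.cursor = 0  # round-robin eviction pointer

    def evict(self) -> Tuple[int, Optional[object]]:
        # Evict whichever slot the cursor points at, regardless of
        # recency; an LRU policy is left to a follow-up PR.
        victim_id = self.cursor
        self.cursor = (self.cursor + 1) % len(self.blocks)
        evicted = self.blocks[victim_id]
        self.blocks[victim_id] = None
        return victim_id, evicted
```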

User Configuration:

We reuse the existing --swap-space flag (previously unused in V1) to control the number of CPU blocks.
Whether to change the default (currently 4GB) remains up for discussion.
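For reference, the block count derives from the swap space and the per-block KV footprint. Illustrative arithmetic only; the model-config values below (32 layers, block size 16, 8 KV heads, head size 128, fp16) are example assumptions:

```python
swap_space_bytes = 4 * 1024**3  # --swap-space 4 (GiB), the current default
# 2 (K and V) * layers * block_size * kv_heads * head_size * 2 bytes (fp16)
bytes_per_block = 2 * 32 * 16 * 8 * 128 * 2
num_cpu_blocks = swap_space_bytes // bytes_per_block  # -> 2048 here
```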

Benchmark

TBA

TODO

  • write tests
  • benchmarks and profiling
  • docs


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀


mergify bot commented Feb 17, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mengzhu28.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@BigCousin-z

Is V0 supported?

@WoosukKwon WoosukKwon self-assigned this Feb 18, 2025
@mengzhu28 mengzhu28 force-pushed the mzhu/cpu_offload branch 2 times, most recently from 8d7835f to cc4a3e2 on February 19, 2025 at 01:16
@mergify mergify bot removed the needs-rebase label Feb 19, 2025
@ywang96 ywang96 added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 19, 2025
@ywang96 ywang96 marked this pull request as ready for review February 19, 2025 20:30
@WoosukKwon
Collaborator

Hi @mengzhu28, thanks for submitting the great PR! I will reach out to you offline.


mergify bot commented Mar 15, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mengzhu28.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 15, 2025
Contributor

@maobaolong maobaolong left a comment


@mengzhu28 Thanks for this great work on V1! Offloading the KV cache to CPU can improve TTFT and throughput. Thinking about the next step based on this PR: maybe vLLM could support offloading the KV cache to disk as follow-up work?

I left a comment inline about abstraction, please take a look, thanks.

# The following swap maps are accumulated over a scheduling step.
# Then they are "flushed" as part of the scheduler output.
# GPU block ID -> CPU block ID
self.step_d2h_swap_map: Dict[int, int] = {}
Contributor


Could you please add an abstraction to support offloading to disk in the future? With that abstraction, the data structure could be [ [src_device, dst_device] -> swap_map[src_block_id -> dst_block_id] ]. Any thoughts?
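A minimal sketch of the proposed device-pair abstraction; the Device enum and map layout are illustrative assumptions, not code from this PR:

```python
from enum import Enum
from typing import Dict, Tuple

class Device(Enum):
    GPU = "gpu"
    CPU = "cpu"
    DISK = "disk"

# (src_device, dst_device) -> {src_block_id -> dst_block_id}
SwapMaps = Dict[Tuple[Device, Device], Dict[int, int]]

swap_maps: SwapMaps = {
    (Device.GPU, Device.CPU): {},   # today's d2h swap-out path
    (Device.CPU, Device.GPU): {},   # today's h2d swap-in path
    (Device.CPU, Device.DISK): {},  # possible future offload tier
}
```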

kv_caches: Dict[str, torch.Tensor],
forward_context: Dict[str, "Attention"],
runner_kv_caches: List[torch.Tensor],
forward_context: Dict[str, "Attention"],
Contributor


Changing the order of these parameters would make more sense, but on the other hand it introduces more code changes.

@WoosukKwon
Collaborator

@mengzhu28 Could you please rebase the PR?

@mengzhu28
Author

@WoosukKwon as discussed offline, created RFC #16144

@chunxiaozheng
Contributor

Would it be better to abstract the CPU-offloading-related functions into a new class and add a parameter to enable it?

num_computed_tokens -= self.block_size
num_new_tokens = self.block_size
computed_blocks.pop()
if computed_blocks:
Contributor


The GPU hit must come before the CPU hit, so here we should first try to pop() from computed_cpu_blocks.
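To illustrate the suggestion (names follow the quoted diff; the exact control flow is an assumption, not the PR's code):

```python
# Back off one block: GPU-cached blocks are matched before CPU-cached
# ones, so the trailing computed block lives in computed_cpu_blocks
# whenever that list is non-empty.
num_computed_tokens -= self.block_size
num_new_tokens = self.block_size
if computed_cpu_blocks:
    computed_cpu_blocks.pop()
else:
    computed_blocks.pop()
```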

@josephrocca josephrocca mentioned this pull request Jun 15, 2025
66 tasks
@orozery orozery mentioned this pull request Jun 19, 2025
1 task

github-actions bot commented Jul 9, 2025

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions github-actions bot added the stale Over 90 days of inactivity label Jul 9, 2025
@jiawei-liang

Hello, can this support a single-host setup with 1 CPU and N GPUs?

@github-actions github-actions bot added unstale Received activity after being labelled stale and removed stale Over 90 days of inactivity labels Jul 11, 2025

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions github-actions bot added stale Over 90 days of inactivity and removed unstale Received activity after being labelled stale labels Oct 10, 2025