
Conversation

@sarckk sarckk (Collaborator) commented May 15, 2025

Motivation

Some models like Tencent-Hunyuan-Large (#10043) and Hymba-1.5B-Base (#10783) use cross-layer KV sharing (e.g. Cross-Layer Attention). This PR adds the ability for KV caches to be shared between attention layers.

Design

This PR adds a new argument kv_sharing_target_layer_name: Optional[str] = None to the Attention layer class. This is only supported in V1. To have an Attention layer not allocate its own KV cache and instead share the KV cache with another layer (referred to as the target layer), pass in the fully-qualified name of the target layer's Attention layer, e.g. model.layers.0.attn. The arg kv_sharing_target_layer_name is only valid if a) it refers to an Attention layer, b) that layer has the same attention type (e.g. decoder) as the current layer, and c) that layer comes before the current layer. It is referred to as the target layer because during attention, the current layer uses its own queries but performs the attention op with the keys and values from the KV cache of the target layer.
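For illustration, here is a minimal sketch of how a model definition could opt into KV sharing. The head counts, scale, and layer names are made up for the example, and all other constructor arguments are elided; the exact signature lives in vllm/attention/layer.py.

```python
from vllm.attention import Attention

# Layer 0 allocates its own KV cache as usual.
attn_0 = Attention(
    num_heads=32,
    head_size=128,
    scale=128**-0.5,
    prefix="model.layers.0.attn",
)

# Layer 1 does not allocate a KV cache; during attention it reads keys and
# values from layer 0's cache.
attn_1 = Attention(
    num_heads=32,
    head_size=128,
    scale=128**-0.5,
    prefix="model.layers.1.attn",
    kv_sharing_target_layer_name="model.layers.0.attn",
)
```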

If an Attention layer has a valid kv_sharing_target_layer_name defined, then we skip creating a KVCacheSpec for it, while recording the mapping in self.shared_kv_cache_layers:

https://github.com/vllm-project/vllm/blob/89450fc323e9eee05cbba76fb5b9a0d29f7038d8/vllm/v1/worker/gpu_model_runner.py#L2142-L2152
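Roughly, the skip-and-record step boils down to the following sketch. It assumes attn_layers maps each fully-qualified layer name to its Attention module, and make_spec is a hypothetical stand-in for the spec construction done for normal layers; see the permalink above for the actual code.

```python
def build_kv_cache_spec(attn_layers, shared_kv_cache_layers):
    kv_cache_spec = {}
    for layer_name, attn_module in attn_layers.items():
        target = attn_module.kv_sharing_target_layer_name
        if target is not None:
            # This layer shares another layer's KV cache: record the mapping
            # and skip creating a KVCacheSpec for it.
            shared_kv_cache_layers[layer_name] = target
            continue
        kv_cache_spec[layer_name] = make_spec(attn_module)  # normal layer
    return kv_cache_spec
```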

During KV cache initialization, the KV cache management logic will continue as if this layer did not exist and will not allocate a KV cache for it. The KV cache for these layers will instead be a reference to the allocated KV caches of their target layers, which is what enables the memory savings of cross-layer KV sharing.

https://github.com/vllm-project/vllm/blob/89450fc323e9eee05cbba76fb5b9a0d29f7038d8/vllm/v1/worker/utils.py#L107-L110
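Conceptually, the re-pointing step mirrors the linked utils.py snippet and can be sketched as follows (simplified):

```python
def share_kv_caches(kv_caches, shared_kv_cache_layers):
    # After KV caches are allocated for the layers that own one, point each
    # sharing layer at its target layer's cache; no new memory is allocated.
    for layer_name, target_layer_name in shared_kv_cache_layers.items():
        kv_caches[layer_name] = kv_caches[target_layer_name]
```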

We also add these layers to the list of layer names kept by each KV cache group, which ensures that each layer is assigned its own attention metadata. From the perspective of the Attention layer, it does not know where its key and value caches come from.
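A hypothetical sketch of that bookkeeping is below. The group objects are assumed to expose a layer_names list; the real data structures live in the KV cache group/config code.

```python
def add_shared_layers_to_groups(kv_cache_groups, shared_kv_cache_layers):
    # Append each sharing layer to the layer-name list of the KV cache group
    # that contains its target layer, so attention metadata is still built
    # for every layer.
    for layer_name, target_layer_name in shared_kv_cache_layers.items():
        for group in kv_cache_groups:
            if target_layer_name in group.layer_names:
                group.layer_names.append(layer_name)
                break
```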

The memory savings of cross-layer KV sharing allow a given amount of memory to accommodate longer context lengths or to process more requests in parallel.


Testing

Sanity Check

As a sanity check that the implementation is working, I made all layers after the 18th layer in Qwen/Qwen3-8B (36 layers total) share the KV cache of the 18th layer, and printed out the id() of the KV cache used in each attention forward:

model.layers.0.self_attn.attn => 139678446053136
model.layers.1.self_attn.attn => 139678446059136
…
model.layers.15.self_attn.attn => 139678446045456
model.layers.16.self_attn.attn => 139678446055056
model.layers.17.self_attn.attn => 139678446050736
model.layers.18.self_attn.attn => 139678446050736
model.layers.19.self_attn.attn => 139678446050736
…
model.layers.32.self_attn.attn => 139678446050736
model.layers.33.self_attn.attn => 139678446050736
model.layers.34.self_attn.attn => 139678446050736
model.layers.35.self_attn.attn => 139678446050736 

As expected, layers 19 to 36 (indices 18 to 35) are re-using the KV cache allocated by layer 18 (index 17).
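For reference, a hypothetical way to reproduce this check, assuming model is the loaded torch module (attribute names may differ between vLLM versions):

```python
from vllm.attention import Attention

# Walk the model's modules and print the identity of each attention layer's
# KV cache tensor; identical ids indicate a shared cache.
for name, module in model.named_modules():
    if isinstance(module, Attention):
        print(f"{name} => {id(module.kv_cache[0])}")
```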

Unit Tests

All newly added unit tests pass:

pytest tests/v1/worker/test_gpu_model_runner.py -k "test_init_kv_cache"

Evals

Checked the gsm8k score on Qwen/Qwen3-8B before and after this PR:

lm_eval --model vllm --tasks gsm8k --model_args pretrained=Qwen/Qwen3-8B,tensor_parallel_size=1 --batch_size auto

Before PR:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8795|±  |0.0090|
|     |       |strict-match    |     5|exact_match|↑  |0.8734|±  |0.0092|

After PR:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8802|±  |0.0089|
|     |       |strict-match    |     5|exact_match|↑  |0.8734|±  |0.0092|


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 and tpu (Related to Google TPUs) labels May 15, 2025
@sarckk sarckk marked this pull request as ready for review May 15, 2025 17:31
@heheda12345 heheda12345 self-requested a review May 16, 2025 01:51

mergify bot commented May 18, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sarckk.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@sarckk sarckk (Collaborator Author) commented May 20, 2025

The entrypoints test failure is unrelated and is also failing on trunk (see https://buildkite.com/vllm/fastcheck/builds/24385).

@sarckk sarckk (Collaborator Author) commented May 20, 2025

@heheda12345 could you take a look?

@heheda12345 heheda12345 (Collaborator) left a comment

Sorry for my late review. Some points that I want to discuss:

  1. The name "KV sharing". Do you think "reuse" is a better name? I want to discuss this more because in #17996, I need to let multiple layers share the same memory pool but with different block_ids, and I think we need to distinguish between the "sharing" in this PR and in that PR. From my understanding, "reuse" is more accurate here because the layers are not equal: the first layer updates the KV cache and the following layers just reuse it. But I'm open to discussion. We need to agree on a name and keep it consistent in this PR.
  2. A model with KV sharing should use less memory per block than a model with the same model config but without KV sharing. Where do you implement this logic now?
  3. Is KV sharing compatible with KV connectors now?
  4. I think we can make KV sharing more implicit. Basically, I think it is possible to avoid changing code inside v1/core & kv_cache_interface.py. kv_cache_manager & kv_cache_utils don't need to know about KV sharing; they can run as if the layers with KV sharing did not exist. To mimic this, we can return only the layers whose kv_sharing_target_layer_idx is None in GPUModelRunner.get_kv_cache_spec.
  5. I prefer to use kv_sharing_target_layer_name rather than kv_sharing_target_layer_idx as it has no ambiguity. For example, in BART, we have both decoder.layers.1.self_attn and decoder.layers.1.encoder_attn; both have layer index 1.
  6. Add a check that KV sharing is only supported in V1.


mergify bot commented May 21, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sarckk.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 21, 2025
@sarckk sarckk (Collaborator Author) commented May 22, 2025

@heheda12345 thanks for taking a look. To answer your questions:

  1. I prefer "KV sharing" simply because it seems to be the academic term for this kind of technique (e.g. see https://arxiv.org/abs/2410.14442), whereas "KV reuse" seems to be used to refer to something else (e.g. prefix caching, https://developer.nvidia.com/blog/introducing-new-kv-cache-reuse-optimizations-in-nvidia-tensorrt-llm/).

  2. "One model with kv sharing should use less memory per block than another model with the same model config but without kv sharing."

     I didn't quite understand why it would be "less memory per block". I think we'll just have fewer physical KV blocks being used? The core memory savings come from not allocating a KV cache for a layer if it has a target layer for KV sharing. I might be missing some other implementation details here, let's chat offline?

  3. "Is KV sharing compatible with kv connectors now?"

     Not at the moment, I believe.

  4. "To mimic it, we can only return layers with kv_sharing_target_layer_idx is None"

     I explored this design, but I remember the complexity was just offloaded to a later stage, as we needed to handle KV allocation for layers without a KV cache spec anyway. But I think the APIs around KV cache groups have changed considerably since then, let me take a look again.

  5. Yes, this is a good point. Some models explicitly keep track of the FQN for each layer, so it shouldn't be difficult. I'll make this change.

  6. Yes, I will add this check.

@heheda12345 heheda12345 (Collaborator) left a comment

  1. Sure, let's use "sharing". Please unify the concept in this PR.
  2. We have less physical memory per KV block, thus we can increase num_gpu_blocks. Where is this logic?
  3. What is the blocker for making it compatible with the KV connector?
  4. Re "as we needed to handle KV allocation for layers without a KV cache spec": I think it may be possible to add a function in initialize_kv_cache to handle all of this logic. Basically, that function needs to:
    1. point the Attention.kv_cache to the target layer like https://github.com/vllm-project/vllm/blob/b0d8b5968d6c2646ca9b43cd1a175adf87d39651/vllm/v1/worker/gpu_model_runner.py#L2004
    2. add the shared layer to the KV cache group of its target layer to help this loop https://github.com/vllm-project/vllm/blob/b0d8b5968d6c2646ca9b43cd1a175adf87d39651/vllm/v1/worker/gpu_model_runner.py#L620

But I'm not sure whether I'm missing any complexity.
5 & 6: SG!
BTW, most of the merge conflicts come from a temporary revert (#18459). I think we can just work on the current branch now without rebasing.

@sarckk sarckk (Collaborator Author) commented May 30, 2025

Updated to address comments.

  "For KV connector, can you at least try https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/disaggregated-prefill-v1 with a local model?"

Tried it, and it still works.

@heheda12345 heheda12345 (Collaborator) left a comment

Great job! Really appreciate the detailed tests and input verification. I think this PR is good except for some very small items.


mergify bot commented Jun 2, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sarckk.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@sarckk sarckk (Collaborator Author) commented Jun 2, 2025

@heheda12345 addressed comments. Could you take another look?

@heheda12345 heheda12345 (Collaborator) left a comment

Thanks! Only some small comments.

@heheda12345 heheda12345 (Collaborator) left a comment

LGTM! Thanks for your contribution.

@heheda12345 heheda12345 enabled auto-merge (squash) June 3, 2025 08:27
@github-actions github-actions bot added the ready (ONLY add when PR is ready to merge/full CI is needed) label Jun 3, 2025
sarckk added 6 commits June 3, 2025 07:09
auto-merge was automatically disabled June 3, 2025 15:22

Head branch was pushed to by a user without write access

@heheda12345 heheda12345 enabled auto-merge (squash) June 3, 2025 17:09
@heheda12345 heheda12345 merged commit bdf1396 into vllm-project:main Jun 3, 2025
67 checks passed
@sarckk sarckk deleted the kv-sharing branch June 12, 2025 22:08