From 90366fc11dffffde053c60c28f2169252f32ec21 Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Fri, 7 Mar 2025 22:21:04 +0000
Subject: [PATCH] handle mla in kv_cache_interface

Signed-off-by: Tyler Michael Smith
---
 vllm/v1/kv_cache_interface.py      | 13 ++++++++-----
 vllm/v1/worker/gpu_model_runner.py |  5 +++--
 vllm/v1/worker/tpu_model_runner.py |  7 ++++---
 3 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py
index dfef1039fce2..1f885c10c8c3 100644
--- a/vllm/v1/kv_cache_interface.py
+++ b/vllm/v1/kv_cache_interface.py
@@ -23,9 +23,9 @@ class KVCacheSpecBase:
     def type_id(self) -> str:
         """
         The type identifier of this KV cache.
-        Return different strings for layers with different KV cache type (e.g., 
-        different number of tokens like full attention vs sliding window 
-        attention, different KV cache size per token like layers with different 
+        Return different strings for layers with different KV cache type (e.g.,
+        different number of tokens like full attention vs sliding window
+        attention, different KV cache size per token like layers with different
         number of heads)
 
         Returns:
@@ -59,6 +59,7 @@ class FullAttentionSpec(KVCacheSpecBase):
     num_kv_heads: int
     head_size: int
     dtype: torch.dtype
+    use_mla: bool
 
     @property
     def type_id(self) -> str:
@@ -66,7 +67,9 @@ def type_id(self) -> str:
 
     @property
     def page_size_bytes(self) -> int:
-        return 2 * self.block_size * self.num_kv_heads * self.head_size \
+        # For MLA we only store a single latent vector
+        coef = 1 if self.use_mla else 2
+        return coef * self.block_size * self.num_kv_heads * self.head_size \
                 * get_dtype_size(self.dtype)
 
     def bytes_for_tokens(self, num_tokens: int) -> int:
@@ -104,7 +107,7 @@ class KVCacheConfig:
         2. (not implemented yet) A model with the same number of full attention
            layers and sliding window attention layers: two groups, one for full
            attention layers and one for sliding window attention layers.
-        3. (not implemented yet) A model with 2 full attention layers and 4 sliding 
+        3. (not implemented yet) A model with 2 full attention layers and 4 sliding
           window attention layers: three groups, (full * 2), (sw * 2), (sw * 2).
     """
     groups: list[list[str]]
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 5cd7e25edcaa..0cdf8f1ab8cc 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1460,13 +1460,14 @@ def get_kv_cache_spec(self) -> KVCacheSpec:
 
         forward_ctx = self.vllm_config.compilation_config.static_forward_context
         block_size = self.vllm_config.cache_config.block_size
+        use_mla = self.vllm_config.model_config.use_mla
         kv_cache_spec: KVCacheSpec = {}
         for layer_name, attn_module in forward_ctx.items():
             if isinstance(attn_module, FusedMoE):
                 continue
 
             # TODO: Support other attention modules, e.g., sliding window,
-            # cross-attention, MLA.
+            # cross-attention
             assert isinstance(attn_module, Attention)
             if attn_module.attn_type == AttentionType.DECODER:
                 kv_cache_spec[layer_name] = FullAttentionSpec(
@@ -1474,7 +1475,7 @@ def get_kv_cache_spec(self) -> KVCacheSpec:
                     num_kv_heads=attn_module.num_kv_heads,
                     head_size=attn_module.head_size,
                     dtype=attn_module.dtype,
-                )
+                    use_mla=use_mla)
             elif attn_module.attn_type in (AttentionType.ENCODER,
                                            AttentionType.ENCODER_ONLY):
                 # encoder-only attention does not need KV cache.
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index d4ebb3adcf8d..7ac07906fabb 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -303,10 +303,10 @@ def get_model(self) -> nn.Module:
 
     def get_kv_cache_spec(self) -> KVCacheSpec:
         """
-        Generates the KVCacheSpec by parsing the kv cache format from each 
+        Generates the KVCacheSpec by parsing the kv cache format from each
         Attention module in the static forward context.
         Returns:
-            KVCacheSpec: A dictionary mapping layer names to their KV cache 
+            KVCacheSpec: A dictionary mapping layer names to their KV cache
             format. Layers that do not need KV cache are not included.
         """
 
@@ -323,6 +323,7 @@ def get_kv_cache_spec(self) -> KVCacheSpec:
                     num_kv_heads=attn_module.num_kv_heads,
                     head_size=attn_module.head_size,
                     dtype=attn_module.dtype,
+                    use_mla=False,
                 )
             elif attn_module.attn_type in (AttentionType.ENCODER,
                                            AttentionType.ENCODER_ONLY):
@@ -764,7 +765,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
         Args:
-            kv_cache_config: Configuration for the KV cache, including the KV 
+            kv_cache_config: Configuration for the KV cache, including the KV
             cache size of each layer
         """
         if len(kv_cache_config.groups) > 1:
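
Note for reviewers: the sketch below (not part of the patch) mirrors the patched
`page_size_bytes` logic to show how the new `use_mla` flag halves the per-block
KV cache footprint, since MLA stores one compressed latent per token instead of
separate K and V. The `MiniFullAttentionSpec` class and the example sizes
(block_size=16, num_kv_heads=1, head_size=576, bfloat16) are illustrative
assumptions, not vLLM APIs.

```python
# Minimal standalone sketch of the coef = 1 (MLA) vs coef = 2 (standard
# attention) page-size calculation introduced by this patch.
from dataclasses import dataclass

import torch


@dataclass
class MiniFullAttentionSpec:
    block_size: int    # tokens per KV cache block
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype
    use_mla: bool

    @property
    def page_size_bytes(self) -> int:
        # Standard attention stores K and V (coef = 2); MLA stores only a
        # single latent vector per token (coef = 1).
        coef = 1 if self.use_mla else 2
        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
        return (coef * self.block_size * self.num_kv_heads * self.head_size
                * dtype_size)


if __name__ == "__main__":
    common = dict(block_size=16, num_kv_heads=1, head_size=576,
                  dtype=torch.bfloat16)
    mha = MiniFullAttentionSpec(use_mla=False, **common)
    mla = MiniFullAttentionSpec(use_mla=True, **common)
    print(mha.page_size_bytes)  # 36864 bytes: K and V both stored
    print(mla.page_size_bytes)  # 18432 bytes: only the latent vector
```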