diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index bb2f36271103..697d134f2018 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -6,6 +6,7 @@ import torch +from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey @@ -184,6 +185,31 @@ def fused_output_quant_supported(self, quant_key: QuantKey): class MLAAttentionImpl(AttentionImpl[T], Generic[T]): + @abstractmethod + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + q_lora_rank: Optional[int], + kv_lora_rank: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + qk_head_dim: int, + v_head_dim: int, + kv_b_proj: ColumnParallelLinear, + indexer: Optional[object] = None, + ) -> None: + raise NotImplementedError + @abstractmethod def forward( self, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index b429c74aa559..9f43cb31218f 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer.""" -from typing import Callable, Optional +from typing import Callable, Optional, cast import torch import torch.nn as nn @@ -10,7 +10,7 @@ import vllm.envs as envs from vllm.attention import AttentionType -from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl from vllm.attention.backends.registry import _Backend, backend_name_to_enum from vllm.attention.selector import get_attn_backend from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target @@ -23,7 +23,10 @@ from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.model_executor.layers.linear import UnquantizedLinearMethod +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod @@ -131,8 +134,6 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, logits_soft_cap: Optional[float] = None, per_layer_sliding_window: Optional[int] = None, - use_mla: bool = False, - use_sparse: bool = False, prefix: str = "", attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, @@ -192,8 +193,6 @@ def __init__( # the quant op after this attention layer. 
self._o_scale_float: Optional[float] = None - self.use_mla = use_mla - self.use_sparse = use_sparse self.num_heads = num_heads self.head_size = head_size self.num_kv_heads = num_kv_heads @@ -229,9 +228,8 @@ def __init__( dtype, kv_cache_dtype, block_size, - use_mla=use_mla, + use_mla=False, has_sink=self.has_sink, - use_sparse=use_sparse, ) else: self.attn_backend = attn_backend @@ -349,19 +347,15 @@ def forward( output_shape = output_shape if output_shape is not None else query.shape output = torch.zeros(output_shape, dtype=output_dtype, device=query.device) hidden_size = output_shape[-1] - # We skip reshaping query, key and value tensors for the MLA - # backend since these tensors have different semantics and are - # processed differently. - if not self.use_mla: - # Reshape the query, key, and value tensors. - # NOTE(woosuk): We do this outside the custom op to minimize the - # CPU overheads from the non-CUDA-graph regions. - query = query.view(-1, self.num_heads, self.head_size) - output = output.view(-1, self.num_heads, self.head_size) - if key is not None: - key = key.view(-1, self.num_kv_heads, self.head_size) - if value is not None: - value = value.view(-1, self.num_kv_heads, self.head_size) + # Reshape the query, key, and value tensors. + # NOTE(woosuk): We do this outside the custom op to minimize the + # CPU overheads from the non-CUDA-graph regions. + query = query.view(-1, self.num_heads, self.head_size) + output = output.view(-1, self.num_heads, self.head_size) + if key is not None: + key = key.view(-1, self.num_kv_heads, self.head_size) + if value is not None: + value = value.view(-1, self.num_kv_heads, self.head_size) if self.use_direct_call: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -570,6 +564,218 @@ def forward( return out.reshape(bsz, q_len, -1) +class MLAAttention(nn.Module, AttentionLayerBase): + """Multi-Head Latent Attention layer. + + This class takes query, and compressed key/value tensors as input. + The class does the following: + + 1. Store the input key and value tensors in the KV cache. + 2. Perform (multi-head/multi-query/grouped-query) attention. + 3. Return the output tensor. 
+ """ + + def __init__( + self, + num_heads: int, + scale: float, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + kv_b_proj: ColumnParallelLinear, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_sparse: bool = False, + indexer: Optional[object] = None, + ): + super().__init__() + self.num_heads = num_heads + self.scale = scale + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.head_size = kv_lora_rank + qk_rope_head_dim + self.layer_name = prefix + + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + calculate_kv_scales = cache_config.calculate_kv_scales + else: + kv_cache_dtype = "auto" + block_size = 16 + calculate_kv_scales = False + self.kv_cache_dtype = kv_cache_dtype + + dtype = torch.get_default_dtype() + self.attn_backend = get_attn_backend( + self.head_size, + dtype, + kv_cache_dtype, + block_size, + use_mla=True, + use_sparse=use_sparse, + ) + impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls()) + self.impl = impl_cls( + num_heads=self.num_heads, + head_size=self.head_size, + scale=self.scale, + num_kv_heads=1, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype=self.kv_cache_dtype, + logits_soft_cap=None, + attn_type=AttentionType.DECODER, + kv_sharing_target_layer_name=None, + # MLA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim, + v_head_dim=self.v_head_dim, + kv_b_proj=kv_b_proj, + indexer=indexer, + ) + + self.use_direct_call = not current_platform.opaque_attention_op() + + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + self.kv_cache = [ + torch.tensor([]) + for _ in range( + get_current_vllm_config().parallel_config.pipeline_parallel_size + ) + ] + + # Align with Attention's scale attributes for MLA backends. + + self.calculate_kv_scales = calculate_kv_scales + self._k_scale = torch.tensor(1.0, dtype=torch.float32) + self._v_scale = torch.tensor(1.0, dtype=torch.float32) + self._q_scale = torch.tensor(1.0, dtype=torch.float32) + self._prob_scale = torch.tensor(1.0, dtype=torch.float32) + + # Host-side mirrors used by some attention backends + self._q_scale_float = 1.0 + self._k_scale_float = 1.0 + self._v_scale_float = 1.0 + self._o_scale_float: Optional[float] = None + + self.use_sparse = use_sparse + + # Initialize q/k/v range constants. + try: + self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) + self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) + self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) + except torch.cuda.OutOfMemoryError: + # Keep defaults if allocation fails; not critical for init. 
+ pass + + def forward( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + output_shape: Optional[torch.Size] = None, + ) -> torch.Tensor: + if self.use_direct_call: + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + + # Mirror Attention.forward scale calculation path + if self.calculate_kv_scales and getattr( + attn_metadata, "enable_kv_scales_calculation", False + ): + self.calc_kv_scales(q, kv_c_normed, k_pe) + + if self.attn_backend.accept_output_buffer: + output = torch.zeros(output_shape, dtype=q.dtype, device=q.device) + self.impl.forward( + self, + q, + kv_c_normed, + k_pe, + self_kv_cache, + attn_metadata, + output=output, + ) + return output + else: + return self.impl.forward( + self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata + ) + else: + if self.attn_backend.accept_output_buffer: + output = torch.zeros(output_shape, dtype=q.dtype, device=q.device) + torch.ops.vllm.unified_mla_attention_with_output( + q, + kv_c_normed, + k_pe, + output, + self.layer_name, + ) + return output + else: + # We can still access forward context to check calculation flag + if self.calculate_kv_scales: + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] + if getattr(attn_metadata, "enable_kv_scales_calculation", False): + self.calc_kv_scales(q, kv_c_normed, k_pe) + return torch.ops.vllm.unified_mla_attention( + q, + kv_c_normed, + k_pe, + self.layer_name, + ) + + def process_weights_after_loading(self, act_dtype: torch.dtype): + if hasattr(self.impl, "process_weights_after_loading"): + self.impl.process_weights_after_loading(act_dtype) + + def calc_kv_scales( + self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor + ) -> None: + """Optional scale calculation for MLA inputs. + + Mirrors Attention.calc_kv_scales. 
Not all MLA backends require this + """ + # Use safe defaults if ranges are not present + q_range = getattr(self, "q_range", torch.tensor(1.0)) + k_range = getattr(self, "k_range", torch.tensor(1.0)) + v_range = getattr(self, "v_range", torch.tensor(1.0)) + + self._q_scale.copy_(torch.abs(q).max() / q_range) + # kv_c_normed is the compressed KV representation; use it for k/v + kv_abs_max = torch.abs(kv_c_normed).max() + self._k_scale.copy_(kv_abs_max / k_range) + self._v_scale.copy_(kv_abs_max / v_range) + self._q_scale_float = self._q_scale.item() + self._k_scale_float = self._k_scale.item() + self._v_scale_float = self._v_scale.item() + self.calculate_kv_scales = False + + def get_attn_backend(self) -> type[AttentionBackend]: + return self.attn_backend + + def wait_for_kv_layer_from_connector(layer_name: str): if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): return @@ -726,3 +932,93 @@ def unified_attention_with_output_fake( fake_impl=unified_attention_with_output_fake, tags=tag_cudagraph_unsafe, ) + + +def unified_mla_attention( + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + wait_for_kv_layer_from_connector(layer_name) + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + self: MLAAttention = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] + output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata) + + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + return output + + +def unified_mla_attention_fake( + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + return torch.empty_like(q).contiguous() + + +direct_register_custom_op( + op_name="unified_mla_attention", + op_func=unified_mla_attention, + mutates_args=[], + fake_impl=unified_mla_attention_fake, + dispatch_key=current_platform.dispatch_key, +) + + +def unified_mla_attention_with_output( + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + output: torch.Tensor, + layer_name: str, + output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, +) -> None: + wait_for_kv_layer_from_connector(layer_name) + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + self: MLAAttention = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] + self.impl.forward( + self, + q, + kv_c_normed, + k_pe, + kv_cache, + attn_metadata, + output=output, + output_scale=output_scale, + output_block_scale=output_block_scale, + ) + + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + + +def unified_mla_attention_with_output_fake( + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + output: torch.Tensor, + layer_name: str, + output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, +) -> None: + return + + +direct_register_custom_op( + op_name="unified_mla_attention_with_output", + op_func=unified_mla_attention_with_output, + mutates_args=["output", "output_block_scale"], + fake_impl=unified_mla_attention_with_output_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git 
a/vllm/config/compilation.py b/vllm/config/compilation.py index 9346bfa6307a..f47fec12d7f9 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -380,6 +380,8 @@ class CompilationConfig: _attention_ops: ClassVar[list[str]] = [ "vllm.unified_attention", "vllm.unified_attention_with_output", + "vllm.unified_mla_attention", + "vllm.unified_mla_attention_with_output", "vllm.mamba_mixer2", "vllm.mamba_mixer", "vllm.short_conv", diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index b8e99226d13e..4b397a058dcd 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -5,7 +5,7 @@ import torch -from vllm.attention import Attention +from vllm.attention.layer import MLAAttention from vllm.config import CacheConfig from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.quantization import QuantizationConfig @@ -30,8 +30,9 @@ class MLAModules: @CustomOp.register("multi_head_latent_attention") -class MultiHeadLatentAttention(CustomOp): - """MLA layer registered as CustomOp. +class MultiHeadLatentAttentionWrapper(CustomOp): + """MLA layer registered as CustomOp to allow OOT backends to add + custom implementations of the outer MLA layer (including rope & o_proj). Note that currently MLA ignores the enable/disable mechanism of CustomOp because there is only one in-tree implementation in forward_native. TODO: implement this with a new PluggableLayer mechanism. @@ -87,30 +88,19 @@ def __init__( self.topk_tokens = self.indexer.topk_tokens self.topk_indices_buffer = mla_modules.topk_indices_buffer - # In the MLA backend, kv_cache includes both k_c and - # pe (i.e. decoupled position embeddings). In particular, - # the concat_and_cache_mla op requires - # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) - # i.e. - # kv_lora_rank + qk_rope_head_dim == head_size - self.mla_attn = Attention( + self.mla_attn = MLAAttention( num_heads=self.num_heads, - head_size=self.kv_lora_rank + self.qk_rope_head_dim, scale=scale, - num_kv_heads=1, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_mla=True, - use_sparse=mla_modules.is_sparse, - # MLA Args - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, qk_nope_head_dim=self.qk_nope_head_dim, qk_rope_head_dim=self.qk_rope_head_dim, - qk_head_dim=self.qk_head_dim, v_head_dim=self.v_head_dim, + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", kv_b_proj=self.kv_b_proj, + use_sparse=self.is_sparse, indexer=self.indexer, ) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index ba8d53c0ba14..364b73d6b68d 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -14,6 +14,7 @@ from typing_extensions import assert_never from vllm.attention import Attention +from vllm.attention.layer import MLAAttention from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.linear import QKVCrossParallelLinear @@ -122,11 +123,10 @@ def process_weights_after_loading( with device_loading_context(module, target_device): quant_method.process_weights_after_loading(module) - # Currently only used by MLA. - # NOTE: This intentionally happens after other modules so we can easily - # decompress the weights for MLA. 
+ # Initialize post-load attention weights for both Attention and MLA. + # NOTE: Happens after other modules so we can easily decompress weights. for _, module in model.named_modules(): - if isinstance(module, Attention) and hasattr( + if isinstance(module, (Attention, MLAAttention)) and hasattr( module, "process_weights_after_loading" ): # TODO(lucas): see if there is a way to unify the signatures diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index a2fb0cfe6000..d93136701014 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -58,7 +58,7 @@ RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttention +from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, @@ -1038,7 +1038,7 @@ def __init__( topk_indices_buffer=topk_indices_buffer, ) - self.mla_attn = MultiHeadLatentAttention( + self.mla_attn = MultiHeadLatentAttentionWrapper( self.hidden_size, self.num_local_heads, self.scaling, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index add2c3cb8d59..d2e1aabb7e88 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -32,11 +32,11 @@ import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata -from vllm.attention.layer import Attention from vllm.distributed.kv_transfer.kv_connector.utils import ( get_kv_connector_cache_layout, ) from vllm.logger import init_logger +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.ubatch_utils import UBatchSlice @@ -405,7 +405,7 @@ def get_per_layer_parameters( to use during `plan`. 
""" - layers = get_layers_from_vllm_config(vllm_config, Attention, layer_names) + layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase, layer_names) per_layer_params: dict[str, PerLayerParameters] = {} for key, layer in layers.items(): diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 1e1161727be1..d597ce68ffe1 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -9,11 +9,11 @@ import torch import torch.nn as nn -from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig, get_layers_from_vllm_config from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import set_forward_context from vllm.logger import init_logger +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache @@ -880,7 +880,7 @@ def get_model_name(self, model: nn.Module) -> str: def load_model(self, target_model: nn.Module) -> None: draft_model_config = self.vllm_config.speculative_config.draft_model_config target_attn_layer_names = set( - get_layers_from_vllm_config(self.vllm_config, Attention).keys() + get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys() ) # FIXME: support hybrid kv for draft model target_indexer_layer_names = set( @@ -897,7 +897,7 @@ def load_model(self, target_model: nn.Module) -> None: ) draft_attn_layer_names = ( - get_layers_from_vllm_config(self.vllm_config, Attention).keys() + get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys() - target_attn_layer_names ) indexer_layers = get_layers_from_vllm_config( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b7dc2287b79f..cbac67d9e24e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -20,6 +20,7 @@ import vllm.envs as envs from vllm.attention import Attention, AttentionType from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.layer import MLAAttention from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.compilation.counter import compilation_counter from vllm.compilation.cuda_graph import CUDAGraphWrapper @@ -4388,98 +4389,100 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: use_mla = self.vllm_config.model_config.use_mla cache_dtype_str = self.vllm_config.cache_config.cache_dtype kv_cache_spec: dict[str, KVCacheSpec] = {} - attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase) for layer_name, attn_module in attn_layers.items(): - if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None: - # The layer doesn't need its own KV cache and will use that of - # the target layer. We skip creating a KVCacheSpec for it, so - # that KV cache management logic will act as this layer does - # not exist, and doesn't allocate KV cache for the layer. This - # enables the memory saving of cross-layer kv sharing, allowing - # a given amount of memory to accommodate longer context lengths - # or enable more requests to be processed simultaneously. 
- self.shared_kv_cache_layers[layer_name] = kv_tgt_layer - continue + if isinstance(attn_module, Attention): + if ( + kv_tgt_layer := attn_module.kv_sharing_target_layer_name + ) is not None: + # The layer doesn't need its own KV cache and will use that of + # the target layer. We skip creating a KVCacheSpec for it, so + # that KV cache management logic will act as this layer does + # not exist, and doesn't allocate KV cache for the layer. This + # enables the memory saving of cross-layer kv sharing, allowing + # a given amount of memory to accommodate longer context lengths + # or enable more requests to be processed simultaneously. + self.shared_kv_cache_layers[layer_name] = kv_tgt_layer + continue - # TODO(lucas): move the attention specs into the model layers like - # the attention backends - if attn_module.attn_type == AttentionType.DECODER: - if attn_module.sliding_window is not None: - assert not use_mla, "MLA is not supported for slidingwindow" - kv_cache_spec[layer_name] = SlidingWindowSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - sliding_window=attn_module.sliding_window, - ) - elif use_mla: - kv_cache_spec[layer_name] = MLAAttentionSpec( + # TODO(lucas): move the attention specs into the model layers like + # the attention backends + if attn_module.attn_type == AttentionType.DECODER: + if attn_module.sliding_window is not None: + assert not use_mla, "MLA is not supported for slidingwindow" + kv_cache_spec[layer_name] = SlidingWindowSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + sliding_window=attn_module.sliding_window, + ) + elif self.attention_chunk_size is not None and isinstance( + attn_module, ChunkedLocalAttention + ): + kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + attention_chunk_size=self.attention_chunk_size, + ) + else: + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + ) + elif attn_module.attn_type == AttentionType.ENCODER_DECODER: + kv_cache_spec[layer_name] = CrossAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - cache_dtype_str=cache_dtype_str, ) - elif self.attention_chunk_size is not None and isinstance( - attn_module, ChunkedLocalAttention + elif attn_module.attn_type in ( + AttentionType.ENCODER, + AttentionType.ENCODER_ONLY, ): - kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - attention_chunk_size=self.attention_chunk_size, - ) + # encoder-only attention does not need KV cache. 
+ continue else: - kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - ) - elif attn_module.attn_type == AttentionType.ENCODER_DECODER: - kv_cache_spec[layer_name] = CrossAttentionSpec( + raise ValueError(f"Unknown attention type: {attn_module.attn_type}") + + elif isinstance(attn_module, MLAAttention): + kv_cache_spec[layer_name] = MLAAttentionSpec( block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, + num_kv_heads=1, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, + cache_dtype_str=cache_dtype_str, ) - elif attn_module.attn_type in ( - AttentionType.ENCODER, - AttentionType.ENCODER_ONLY, - ): - # encoder-only attention does not need KV cache. - continue - else: - raise ValueError(f"Unknown attention type: {attn_module.attn_type}") - - mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase) - if len(mamba_layers) > 0: - if ( - self.vllm_config.speculative_config is not None - and self.vllm_config.model_config.hf_config.model_type - not in ["qwen3_next"] - ): - raise NotImplementedError( - "Mamba with speculative decoding is not supported yet." - ) - mamba_block_size = self.vllm_config.cache_config.mamba_block_size - page_size_padded = self.vllm_config.cache_config.mamba_page_size_padded - for layer_name, mamba_module in mamba_layers.items(): + elif isinstance(attn_module, MambaBase): + if ( + self.vllm_config.speculative_config is not None + and self.vllm_config.model_config.hf_config.model_type + not in ["qwen3_next"] + ): + raise NotImplementedError( + "Mamba with speculative decoding is not supported yet." + ) + mamba_block_size = self.vllm_config.cache_config.mamba_block_size + page_size_padded = self.vllm_config.cache_config.mamba_page_size_padded kv_cache_spec[layer_name] = MambaSpec( - shapes=mamba_module.get_state_shape(), - dtypes=mamba_module.get_state_dtype(), + shapes=attn_module.get_state_shape(), + dtypes=attn_module.get_state_dtype(), block_size=mamba_block_size, page_size_padded=page_size_padded, - mamba_type=mamba_module.mamba_type, + mamba_type=attn_module.mamba_type, num_speculative_blocks=( self.speculative_config.num_speculative_tokens if self.speculative_config else 0 ), ) + ds_indexer_layers = get_layers_from_vllm_config( self.vllm_config, DeepseekV32IndexerCache ) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 1d53fa954a7f..7877f288c2ec 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -19,6 +19,7 @@ import vllm.envs as envs from vllm.attention import Attention from vllm.attention.backends.abstract import AttentionType +from vllm.attention.layer import MLAAttention from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import ( @@ -32,6 +33,7 @@ from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.lora.layers import BaseLayerWithLoRA +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader.tpu import TPUModelLoader from vllm.model_executor.models.interfaces import ( @@ -63,6 +65,7 @@ FullAttentionSpec, KVCacheConfig, KVCacheSpec, + MLAAttentionSpec, SlidingWindowSpec, ) from vllm.v1.outputs import ( @@ -561,52 +564,71 @@ def 
get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: format. Layers that do not need KV cache are not included. """ - layers = get_layers_from_vllm_config(self.vllm_config, Attention) + layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase) block_size = self.vllm_config.cache_config.block_size + cache_dtype_str = self.vllm_config.cache_config.cache_dtype + kv_cache_spec: dict[str, KVCacheSpec] = {} for layer_name, attn_module in layers.items(): - if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None: - # The layer doesn't need its own KV cache and will use that of - # the target layer. We skip creating a KVCacheSpec for it, so - # that KV cache management logic will act as this layer does - # not exist, and doesn't allocate KV cache for the layer. This - # enables the memory saving of cross-layer kv sharing, allowing - # a given amount of memory to accommodate longer context lengths - # or enable more requests to be processed simultaneously. - self.shared_kv_cache_layers[layer_name] = kv_tgt_layer - continue + # Classic Attention path + if isinstance(attn_module, Attention): + if ( + kv_tgt_layer := attn_module.kv_sharing_target_layer_name + ) is not None: + # The layer doesn't need its own KV cache and will use that of + # the target layer. We skip creating a KVCacheSpec for it, so + # that KV cache management logic will act as this layer does + # not exist, and doesn't allocate KV cache for the layer. This + # enables the memory saving of cross-layer kv sharing, allowing + # a given amount of memory to accommodate longer context lengths + # or enable more requests to be processed simultaneously. + self.shared_kv_cache_layers[layer_name] = kv_tgt_layer + continue - if attn_module.attn_type == AttentionType.DECODER: - if isinstance(attn_module, ChunkedLocalAttention): - logger.warning_once( - "Using irope in Pallas is not supported yet, it " - "will fall back to global attention for long context." - ) - if attn_module.sliding_window is not None: - kv_cache_spec[layer_name] = SlidingWindowSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - sliding_window=attn_module.sliding_window, - ) + if attn_module.attn_type == AttentionType.DECODER: + if isinstance(attn_module, ChunkedLocalAttention): + logger.warning_once( + "Using irope in Pallas is not supported yet, it " + "will fall back to global attention for long context." + ) + if attn_module.sliding_window is not None: + kv_cache_spec[layer_name] = SlidingWindowSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + sliding_window=attn_module.sliding_window, + ) + else: + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + ) + elif attn_module.attn_type in ( + AttentionType.ENCODER, + AttentionType.ENCODER_ONLY, + ): + # encoder-only attention does not need KV cache. + continue + elif attn_module.attn_type == AttentionType.ENCODER_DECODER: + raise NotImplementedError else: - kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - ) - elif attn_module.attn_type in ( - AttentionType.ENCODER, - AttentionType.ENCODER_ONLY, - ): - # encoder-only attention does not need KV cache. 
- continue - elif attn_module.attn_type == AttentionType.ENCODER_DECODER: - raise NotImplementedError + raise ValueError(f"Unknown attention type: {attn_module.attn_type}") + # MLAAttention path + elif isinstance(attn_module, MLAAttention): + if layer_name in kv_cache_spec: + continue + kv_cache_spec[layer_name] = MLAAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + cache_dtype_str=cache_dtype_str, + ) else: - raise ValueError(f"Unknown attention type: {attn_module.attn_type}") + continue return kv_cache_spec