From 1c31724c7d271c97bfee0a5a2ad4028f07f8c218 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Tue, 4 Mar 2025 16:29:03 +0000
Subject: [PATCH 1/3] Standardize quantized kv cache rejection for attention backends

Signed-off-by: mgoin
---
 vllm/attention/backends/abstract.py          | 4 ++++
 vllm/attention/backends/flash_attn.py        | 3 +++
 vllm/attention/backends/flashmla.py          | 9 ++++++---
 vllm/attention/backends/hpu_attn.py          | 7 ++++++-
 vllm/attention/backends/ipex_attn.py         | 5 +++--
 vllm/attention/backends/pallas.py            | 5 +++--
 vllm/attention/backends/torch_sdpa.py        | 8 ++++++--
 vllm/attention/backends/triton_mla.py        | 6 ++++--
 vllm/v1/attention/backends/flash_attn.py     | 6 +++++-
 vllm/v1/attention/backends/mla/flashmla.py   | 9 ++++++---
 vllm/v1/attention/backends/mla/triton_mla.py | 7 ++++++-
 11 files changed, 52 insertions(+), 17 deletions(-)

diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 5f0a54013540..0cd95e0749d1 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -294,3 +294,7 @@ def forward(
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
+
+
+def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
+    return kv_cache_dtype != "auto"
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 5aca10079f9b..d620d27b3d08 100755
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -626,6 +626,9 @@ def __init__(
         self.sliding_window = ((sliding_window - 1, 0) if sliding_window
                                is not None else (-1, -1))
         self.kv_cache_dtype = kv_cache_dtype
+        if self.kv_cache_dtype.startswith("fp8"):
+            raise NotImplementedError(
+                "FlashAttention with FP8 KV cache not yet supported")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py
index 273c69b63ec6..5d0c23093310 100644
--- a/vllm/attention/backends/flashmla.py
+++ b/vllm/attention/backends/flashmla.py
@@ -6,7 +6,8 @@
 import torch
 
-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.mla.common import (MLACommonBackend,
                                                 MLACommonImpl,
                                                 MLACommonMetadata,
@@ -207,6 +208,10 @@ def __init__(
                                       "are not implemented for "
                                       "FlashMLAImpl")
 
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashMLA with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
@@ -215,8 +220,6 @@ def _forward_decode(
         attn_metadata: FlashMLAMetadata,
     ) -> torch.Tensor:
         assert kv_c_and_k_pe_cache.numel() > 0
-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 FlashMLA not yet supported")
 
         decode_meta = attn_metadata.decode_metadata
         assert decode_meta is not None
diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index 9eb533685dbd..f948fbc0a109 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -15,7 +15,8 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
                                                HPUPagedAttentionMetadata)
@@ -158,6 +159,10 @@ def __init__(
                                       "are not implemented for "
                                       "HPUAttentionImpl")
 
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "HPUAttention with FP8 KV cache not yet supported")
+
     def forward(
         self,
         layer: AttentionLayer,
diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py
index b4879af4cf20..d3c61ea26a02 100644
--- a/vllm/attention/backends/ipex_attn.py
+++ b/vllm/attention/backends/ipex_attn.py
@@ -9,7 +9,8 @@
 from vllm._ipex_ops import ipex_ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.paged_attn import (PagedAttention,
                                            PagedAttentionMetadata)
@@ -145,7 +146,7 @@ def __init__(
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {supported_head_sizes}.")
-        if kv_cache_dtype != "auto":
+        if is_quantized_kv_cache(kv_cache_dtype):
             raise NotImplementedError(
                 "IPEX backend does not support FP8 KV cache. "
                 "Please use xFormers backend instead.")
diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py
index b61dfe63ddca..2ee66ab9e966 100644
--- a/vllm/attention/backends/pallas.py
+++ b/vllm/attention/backends/pallas.py
@@ -8,7 +8,8 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import CommonAttentionState
@@ -119,7 +120,7 @@ def __init__(
             raise NotImplementedError("Alibi slopes is not supported.")
         if sliding_window is not None:
             raise NotImplementedError("Sliding window is not supported.")
-        if kv_cache_dtype != "auto":
+        if is_quantized_kv_cache(kv_cache_dtype):
             raise NotImplementedError("FP8 KV cache dtype is not supported.")
         if blocksparse_params is not None:
             raise NotImplementedError("Blocksparse is not supported.")
diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py
index 25fe6ed95c5d..37dd75da2759 100644
--- a/vllm/attention/backends/torch_sdpa.py
+++ b/vllm/attention/backends/torch_sdpa.py
@@ -7,11 +7,15 @@
 import torch
 from torch.nn.functional import scaled_dot_product_attention
 
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
-                                              AttentionType)
+                                              AttentionType,
+                                              is_quantized_kv_cache)
+# yapf: enable
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.ipex_attn import PagedAttention
 from vllm.attention.ops.paged_attn import PagedAttentionMetadata
@@ -427,7 +431,7 @@ def __init__(
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {supported_head_sizes}.")
-        if kv_cache_dtype != "auto":
+        if is_quantized_kv_cache(kv_cache_dtype):
            raise NotImplementedError(
                 "Torch SDPA backend does not support FP8 KV cache. "
                 "Please use xFormers backend instead.")
diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py
index 08e8226ab04c..aeed2266a96d 100644
--- a/vllm/attention/backends/triton_mla.py
+++ b/vllm/attention/backends/triton_mla.py
@@ -58,6 +58,10 @@ def __init__(
                                       "are not implemented for "
                                       "TritonMLAImpl")
 
+        if self.kv_cache_dtype.startswith("fp8"):
+            raise NotImplementedError(
+                "TritonMLA with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
@@ -66,8 +70,6 @@ def _forward_decode(
         attn_metadata: MLACommonMetadata,
     ) -> torch.Tensor:
         assert kv_c_and_k_pe_cache.numel() > 0
-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 Triton MLA not yet supported")
 
         decode_meta = attn_metadata.decode_metadata
         assert decode_meta is not None
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 8bf7f3587bc0..dcdd3db6dc14 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -7,7 +7,8 @@
 import torch
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata, AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.utils import get_flash_attn_version
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
 from vllm.logger import init_logger
@@ -180,6 +181,9 @@ def __init__(
         else:
             self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashAttention V1 with FP8 KV cache not yet supported")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index b357d7142410..bf5cd36c76a1 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -5,7 +5,8 @@
 import torch
 
-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
@@ -107,6 +108,10 @@ def __init__(
                                       "are not implemented for "
                                       "FlashMLAImpl")
 
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "FlashMLA V1 with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
@@ -115,8 +120,6 @@ def _forward_decode(
         attn_metadata: FlashMLAMetadata,
     ) -> torch.Tensor:
         assert kv_c_and_k_pe_cache.numel() > 0
-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 FlashMLA not yet supported")
 
         q = torch.cat([q_nope, q_pe], dim=-1)\
             .unsqueeze(1)  # Add seqlen dim of 1 (decode)
diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py
index 3f9b349a5f04..d48a97ab0296 100644
--- a/vllm/v1/attention/backends/mla/triton_mla.py
+++ b/vllm/v1/attention/backends/mla/triton_mla.py
@@ -4,7 +4,8 @@
 import torch
 
-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
@@ -61,6 +62,10 @@ def __init__(
                                       "are not implemented for "
                                       "TritonMLAImpl")
 
+        if is_quantized_kv_cache(self.kv_cache_dtype):
+            raise NotImplementedError(
+                "TritonMLA V1 with FP8 KV cache not yet supported")
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,

From 57f83abdb5544c8f6c4a38a58861577b1bb3ebf0 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Tue, 4 Mar 2025 16:32:53 +0000
Subject: [PATCH 2/3] Fix flashattn

Signed-off-by: mgoin
---
 vllm/attention/backends/flash_attn.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index d620d27b3d08..0e331efa6a39 100755
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -8,11 +8,15 @@
 import torch
 
 from vllm import _custom_ops as ops
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
-                                              AttentionType)
+                                              AttentionType,
+                                              is_quantized_kv_cache)
+# yapf: enable
 from vllm.attention.backends.utils import (
     PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
     compute_slot_mapping_start_idx, get_flash_attn_version,
@@ -626,7 +630,7 @@ def __init__(
         self.sliding_window = ((sliding_window - 1, 0) if sliding_window
                                is not None else (-1, -1))
         self.kv_cache_dtype = kv_cache_dtype
-        if self.kv_cache_dtype.startswith("fp8"):
+        if is_quantized_kv_cache(self.kv_cache_dtype):
             raise NotImplementedError(
                 "FlashAttention with FP8 KV cache not yet supported")
         if logits_soft_cap is None:

From 78ca742d8a16d68ce0a0d6d5f74bd8e1f45e24e0 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Tue, 4 Mar 2025 18:11:25 +0000
Subject: [PATCH 3/3] Fix triton_mla

Signed-off-by: mgoin
---
 vllm/attention/backends/triton_mla.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py
index aeed2266a96d..61e5c76d9fda 100644
--- a/vllm/attention/backends/triton_mla.py
+++ b/vllm/attention/backends/triton_mla.py
@@ -4,7 +4,8 @@
 import torch
 
-from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.backends.abstract import (AttentionType,
+                                              is_quantized_kv_cache)
 from vllm.attention.backends.mla.common import (MLACommonBackend,
                                                 MLACommonImpl,
                                                 MLACommonMetadata)
@@ -58,7 +59,7 @@ def __init__(
                                       "are not implemented for "
                                       "TritonMLAImpl")
 
-        if self.kv_cache_dtype.startswith("fp8"):
+        if is_quantized_kv_cache(self.kv_cache_dtype):
             raise NotImplementedError(
                 "TritonMLA with FP8 KV cache not yet supported")
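
Usage sketch: the pattern this series standardizes boils down to the check below. The backend class name and constructor signature here are hypothetical, for illustration only; is_quantized_kv_cache() is the helper added to vllm/attention/backends/abstract.py in patch 1/3.

    from vllm.attention.backends.abstract import is_quantized_kv_cache


    class MyAttentionImpl:  # hypothetical backend impl, illustration only
        def __init__(self, kv_cache_dtype: str = "auto") -> None:
            self.kv_cache_dtype = kv_cache_dtype
            # Reject quantized KV cache dtypes (anything other than "auto")
            # up front in __init__ rather than deep inside the decode path,
            # mirroring the checks added in the patches above.
            if is_quantized_kv_cache(self.kv_cache_dtype):
                raise NotImplementedError(
                    "MyAttention with FP8 KV cache not yet supported")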