
Commit f8b61bb

gau-nernst authored and xuebwang-amd committed
[Bugfix] Enable FP8 KV cache for FlashInfer and Triton backend on non-sm100 GPUs (vllm-project#24577)
Signed-off-by: Thien Tran <[email protected]>
Signed-off-by: xuebwang-amd <[email protected]>
1 parent e0e56c1 commit f8b61bb
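
As context for the change, here is a minimal sketch of how the fixed path might be exercised from vLLM's offline API on a non-sm100 GPU. The model name is an illustrative assumption, and the `VLLM_ATTENTION_BACKEND` value is one of the two backends this commit touches; `kv_cache_dtype="fp8"` is the option that presumably routes through `is_kv_cache_dtype_supported` during engine setup.

```python
# Hedged sketch: FP8 KV cache with the FlashInfer backend on a non-sm100 GPU.
# The model name is illustrative; the Triton backend would instead use
# VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1, per the diff below.
import os

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"  # set before importing vllm

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # assumed example model
    kv_cache_dtype="fp8",                      # the FP8 KV cache this fix enables here
)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```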

File tree

2 files changed: +9 -1 lines changed

vllm/platforms/cuda.py
vllm/v1/attention/backends/flashinfer.py


vllm/platforms/cuda.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -530,6 +530,10 @@ def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                 supported = flash_attn_supports_fp8()
             else:
                 supported = True
+        elif attention_backend == "FLASHINFER":
+            supported = True
+        elif attention_backend == "TRITON_ATTN_VLLM_V1":
+            supported = cls.supports_fp8()
         return supported
 
     @classmethod
```
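
The new branches make the per-backend support explicit: FlashInfer accepts an FP8 KV cache unconditionally, while the Triton backend defers to the platform's FP8 capability check. Below is a standalone sketch of that shape of check; the helper names and the 8.9 compute-capability threshold are illustrative assumptions, not vLLM's actual implementation of `supports_fp8()`.

```python
# Hedged sketch of a per-backend FP8 KV-cache support check.
# Helper names and the 8.9 threshold are illustrative assumptions.
import torch


def gpu_supports_fp8() -> bool:
    """Assume hardware FP8 needs compute capability >= 8.9 (Ada/Hopper)."""
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() >= (8, 9)


def kv_cache_fp8_supported(attention_backend: str) -> bool:
    # Mirrors the dispatch shape in the diff above.
    if attention_backend == "FLASHINFER":
        return True                 # FlashInfer handles FP8 KV caches itself
    if attention_backend == "TRITON_ATTN_VLLM_V1":
        return gpu_supports_fp8()   # Triton path defers to the platform check
    return False


if __name__ == "__main__":
    for backend in ("FLASHINFER", "TRITON_ATTN_VLLM_V1"):
        print(backend, kv_cache_fp8_supported(backend))
```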

vllm/v1/attention/backends/flashinfer.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -202,7 +202,11 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         else:
             assert self.kv_cache_spec.dtype == self.model_config.dtype
         self.kv_cache_dtype = self.kv_cache_spec.dtype
-        self.q_data_type = self.kv_cache_dtype
+
+        if supports_trtllm_attention()[0]:
+            self.q_data_type = self.kv_cache_dtype
+        else:
+            self.q_data_type = self.model_config.dtype
 
         self._cascade_wrapper = None  # Wrapper for cascade attention
```
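
This hunk ties query quantization to kernel support: the query only adopts the (possibly FP8) KV-cache dtype when `supports_trtllm_attention()` reports that the TRT-LLM kernels are usable; otherwise it stays in the model's compute dtype, which the regular FlashInfer kernels expect. A minimal sketch of that selection follows; the function and parameter names are hypothetical stand-ins, not the FlashInfer builder's real interface.

```python
# Hedged sketch of the query-dtype selection; names are hypothetical.
import torch


def choose_q_dtype(kv_cache_dtype: torch.dtype,
                   model_dtype: torch.dtype,
                   trtllm_attention_available: bool) -> torch.dtype:
    """Quantize the query only when TRT-LLM attention kernels can consume it."""
    if trtllm_attention_available:
        return kv_cache_dtype  # e.g. torch.float8_e4m3fn
    return model_dtype         # fall back to the model's compute dtype


# Example: FP8 KV cache, bf16 model, TRT-LLM kernels unavailable.
print(choose_q_dtype(torch.float8_e4m3fn, torch.bfloat16, False))  # torch.bfloat16
```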
