Commit 5a9da1d

refactor trtllm kernel selection logic

Signed-off-by: elvischenv <[email protected]>

1 parent 0faf3cc

File tree

2 files changed: +56 -27 lines

vllm/utils/flashinfer.py

Lines changed: 44 additions & 22 deletions
@@ -154,28 +154,27 @@ def has_nvidia_artifactory() -> bool:
 
 
 @functools.cache
-def supports_trtllm_attention() -> tuple[bool, Optional[str]]:
-    """Cache result which only depends on the environment"""
-    # This is a lambda, call it once
-    env_value = envs.VLLM_USE_TRTLLM_ATTENTION
-
+def supports_trtllm_attention() -> bool:
+    """
+    TRTLLM attention is supported if the platform is SM100 and
+    NVIDIA artifactory is accessible
+    """
     # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
-    if not (current_platform.is_device_capability(100)
-            and has_nvidia_artifactory()):
-        return False, env_value
+    return current_platform.is_device_capability(
+        100) and has_nvidia_artifactory()
 
+
+@functools.cache
+def force_use_trtllm_attention() -> Optional[bool]:
+    """
+    Return ``None`` if VLLM_USE_TRTLLM_ATTENTION is not set,
+    return ``True`` if TRTLLM attention is forced to be used,
+    return ``False`` if TRTLLM attention is forced to be not used.
+    """
+    env_value = envs.VLLM_USE_TRTLLM_ATTENTION
     if env_value is not None:
         logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
-        # Environment variable is set - respect it
-        # Making the conditional check for zero because
-        # the path is automatically enabled if the batch size condition
-        # is satisfied.
-        use_trtllm = (env_value == "1")
-        if use_trtllm:
-            logger.info_once("Using TRTLLM attention.")
-        return use_trtllm, env_value
-
-    return True, None
+    return None if env_value is None else env_value == "1"
 
 
 def use_trtllm_attention(
@@ -185,18 +184,38 @@ def use_trtllm_attention(
     max_seq_len: int,
     kv_cache_dtype: str,
     q_dtype: torch.dtype,
-    is_prefill: bool,
     has_sinks: bool = False,
 ) -> bool:
-    use_trtllm, env_value = supports_trtllm_attention()
-    if not use_trtllm:
+    """Return ``True`` if TRTLLM attention is used."""
+    force_use_trtllm = force_use_trtllm_attention()
+
+    # Environment variable is set to 0 - respect it
+    if force_use_trtllm is not None and not force_use_trtllm:
         return False
 
+    # The platform is not supported
+    if not supports_trtllm_attention():
+        if force_use_trtllm:
+            logger.warning_once(
+                "TRTLLM attention is not supported on this platform, "
+                "but VLLM_USE_TRTLLM_ATTENTION is set to 1")
+        return False
+
+    # The combination of query and key heads is not supported
     if num_qo_heads % num_kv_heads != 0:
+        if force_use_trtllm:
+            logger.warning_once(
+                "TRTLLM attention is not supported for this combination of "
+                "query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
+            )
         return False
 
     # Must use TRTLLM attention if query is FP8 quantized
     if q_dtype == current_platform.fp8_dtype():
+        if has_sinks:
+            raise RuntimeError(
+                "TRTLLM FP8-qkv kernel is not supported for attention sinks. "
+                "Use kv_cache_dtype=auto for now.")
         logger.info_once("Using TRTLLM attention (query is quantized).")
         return True
 
@@ -207,7 +226,7 @@ def use_trtllm_attention(
             "Using TRTLLM attention (required for attention sinks).")
         return True
 
-    if env_value is None:
+    if force_use_trtllm is None:
         # Environment variable not set - use auto-detection
         use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
                       and kv_cache_dtype == "auto")
@@ -216,6 +235,8 @@ def use_trtllm_attention(
         return use_trtllm
 
     # Environment variable is set to 1 - respect it
+    logger.info_once(
+        "Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
     return True
 
 
@@ -367,6 +388,7 @@ def flashinfer_disable_q_quantization() -> bool:
     "has_nvidia_artifactory",
     "supports_trtllm_attention",
     "use_trtllm_attention",
+    "flashinfer_disable_q_quantization",
     "flashinfer_scaled_fp4_mm",
     "flashinfer_scaled_fp8_mm",
 ]
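
As a quick orientation on the refactored selection flow, here is a minimal usage sketch. The module path comes from the file above; the head counts, token count, and dtypes are illustrative placeholders, not values taken from this commit.

import torch

from vllm.utils.flashinfer import (force_use_trtllm_attention,
                                   supports_trtllm_attention,
                                   use_trtllm_attention)

# None -> VLLM_USE_TRTLLM_ATTENTION unset (auto-detect); True/False -> forced
forced = force_use_trtllm_attention()

# True only on SM100 with the NVIDIA artifactory reachable
platform_ok = supports_trtllm_attention()

# Per-batch decision, mirroring how the FlashInfer metadata builder calls it
# (argument values below are illustrative only)
use_trtllm = use_trtllm_attention(
    32,              # num_qo_heads
    8,               # num_kv_heads
    128,             # num_tokens in the batch
    4096,            # max_seq_len
    "auto",          # kv_cache_dtype
    torch.bfloat16,  # q_dtype
    has_sinks=False,
)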

vllm/v1/attention/backends/flashinfer.py

Lines changed: 12 additions & 5 deletions
@@ -282,7 +282,11 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         assert self.kv_cache_spec.dtype == self.model_config.dtype
         self.kv_cache_dtype = self.kv_cache_spec.dtype
 
-        if supports_trtllm_attention()[0] and \
+        # Use model dtype as q dtype when TRTLLM attn is not supported, or
+        # VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to
+        # use fp8 q if kv cache is fp8, and will fall back to model dtype
+        # if TRTLLM attention kernel is not used when building attn metadata
+        if supports_trtllm_attention() and \
                 not flashinfer_disable_q_quantization():
             self.q_data_type = self.kv_cache_dtype
         else:
@@ -298,7 +302,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         self.window_left = self.global_hyperparameters.window_left
         self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
         self.has_sinks = self.global_hyperparameters.has_sinks
-        if self.has_sinks and not supports_trtllm_attention()[0]:
+        if self.has_sinks and not supports_trtllm_attention():
             raise NotImplementedError(
                 "FlashInfer backend currently does not support attention "
                 "sinks, please use trtllm on blackwell or flash attention on "
@@ -477,28 +481,31 @@ def build(self,
             paged_kv_last_page_len_np,
         )
 
-        # Check if any layer uses sinks (requires TRTLLM attention)
         prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                   self.num_kv_heads,
                                                   num_prefill_tokens,
                                                   max_seq_len,
                                                   self.cache_dtype,
                                                   self.q_data_type,
-                                                  is_prefill=True,
                                                   has_sinks=self.has_sinks)
         decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                  self.num_kv_heads,
                                                  num_decode_tokens,
                                                  max_seq_len,
                                                  self.cache_dtype,
                                                  self.q_data_type,
-                                                 is_prefill=False,
                                                  has_sinks=self.has_sinks)
         if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm):
             raise NotImplementedError(
                 "FlashInfer backend currently does not support attention "
                 "sinks, please use trtllm on blackwell or flash attention on "
                 "earlier GPUs.")
+
+        # If TRTLLM attention is not used, the q quantization is not supported.
+        # Fall back to use model dtype.
+        if not (prefill_use_trtllm and decode_use_trtllm):
+            self.q_data_type = self.model_config.dtype
+
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             q_data_type=self.q_data_type,
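
Condensing the backend-side change: the query dtype is picked optimistically at init time and revisited while building the attention metadata. A small sketch of that two-step decision follows; pick_q_dtype is a hypothetical helper name, not part of the diff, and its arguments stand in for the builder's self.kv_cache_dtype and self.model_config.dtype.

from vllm.utils.flashinfer import (flashinfer_disable_q_quantization,
                                   supports_trtllm_attention)


def pick_q_dtype(kv_cache_dtype, model_dtype, prefill_use_trtllm: bool,
                 decode_use_trtllm: bool):
    # Init time: try FP8 q only when TRTLLM attention is supported and
    # q quantization has not been disabled via the environment.
    if supports_trtllm_attention() and not flashinfer_disable_q_quantization():
        q_data_type = kv_cache_dtype
    else:
        q_data_type = model_dtype
    # Build time: if either phase ends up not using the TRTLLM kernel,
    # q quantization is unsupported, so fall back to the model dtype.
    if not (prefill_use_trtllm and decode_use_trtllm):
        q_data_type = model_dtype
    return q_data_type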
