@@ -154,28 +154,27 @@ def has_nvidia_artifactory() -> bool:
 
 
 @functools.cache
-def supports_trtllm_attention() -> tuple[bool, Optional[str]]:
-    """Cache result which only depends on the environment"""
-    # This is a lambda, call it once
-    env_value = envs.VLLM_USE_TRTLLM_ATTENTION
-
+def supports_trtllm_attention() -> bool:
+    """
+    TRTLLM attention is supported if the platform is SM100 and
+    NVIDIA artifactory is accessible
+    """
     # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
-    if not (current_platform.is_device_capability(100)
-            and has_nvidia_artifactory()):
-        return False, env_value
+    return current_platform.is_device_capability(
+        100) and has_nvidia_artifactory()
 
+
+@functools.cache
+def force_use_trtllm_attention() -> Optional[bool]:
+    """
+    Return ``None`` if VLLM_USE_TRTLLM_ATTENTION is not set,
+    return ``True`` if TRTLLM attention is forced to be used,
+    return ``False`` if TRTLLM attention is forced to be not used.
+    """
+    env_value = envs.VLLM_USE_TRTLLM_ATTENTION
     if env_value is not None:
         logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
-        # Environment variable is set - respect it
-        # Making the conditional check for zero because
-        # the path is automatically enabled if the batch size condition
-        # is satisfied.
-        use_trtllm = (env_value == "1")
-        if use_trtllm:
-            logger.info_once("Using TRTLLM attention.")
-        return use_trtllm, env_value
-
-    return True, None
+    return None if env_value is None else env_value == "1"
 
 
 def use_trtllm_attention(
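The hunk above splits the old helper, which returned a `(supported, env_value)` tuple, into two cached helpers: `supports_trtllm_attention()` answers the platform question (SM100 plus artifactory access), while `force_use_trtllm_attention()` exposes the user's env override as a tri-state value. Below is a minimal sketch of how that tri-state is meant to be combined with the platform check; the `decide` helper and its arguments are illustrative only (not part of the patch), and the FP8-query, attention-sink, and head-ratio special cases are omitted.

```python
from typing import Optional


def decide(force: Optional[bool], supported: bool, auto_ok: bool) -> bool:
    """Illustrative decision flow mirroring the patched use_trtllm_attention."""
    if force is False:   # VLLM_USE_TRTLLM_ATTENTION=0: never use TRTLLM
        return False
    if not supported:    # SM100/artifactory check failed: fall back regardless
        return False
    if force is None:    # env var unset: defer to the auto-detection heuristic
        return auto_ok
    return True          # VLLM_USE_TRTLLM_ATTENTION=1: use TRTLLM when supported
```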
@@ -185,18 +184,38 @@ def use_trtllm_attention(
     max_seq_len: int,
     kv_cache_dtype: str,
     q_dtype: torch.dtype,
-    is_prefill: bool,
     has_sinks: bool = False,
 ) -> bool:
-    use_trtllm, env_value = supports_trtllm_attention()
-    if not use_trtllm:
+    """Return ``True`` if TRTLLM attention is used."""
+    force_use_trtllm = force_use_trtllm_attention()
+
+    # Environment variable is set to 0 - respect it
+    if force_use_trtllm is not None and not force_use_trtllm:
         return False
 
+    # The platform is not supported
+    if not supports_trtllm_attention():
+        if force_use_trtllm:
+            logger.warning_once(
+                "TRTLLM attention is not supported on this platform, "
+                "but VLLM_USE_TRTLLM_ATTENTION is set to 1")
+        return False
+
+    # The combination of query and key heads is not supported
     if num_qo_heads % num_kv_heads != 0:
+        if force_use_trtllm:
+            logger.warning_once(
+                "TRTLLM attention is not supported for this combination of "
+                "query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
+            )
         return False
 
     # Must use TRTLLM attention if query is FP8 quantized
     if q_dtype == current_platform.fp8_dtype():
+        if has_sinks:
+            raise RuntimeError(
+                "TRTLLM FP8-qkv kernel is not supported for attention sinks. "
+                "Use kv_cache_dtype=auto for now.")
         logger.info_once("Using TRTLLM attention (query is quantized).")
         return True
 
@@ -207,7 +226,7 @@ def use_trtllm_attention(
207226 "Using TRTLLM attention (required for attention sinks)." )
208227 return True
209228
210- if env_value is None :
229+ if force_use_trtllm is None :
211230 # Environment variable not set - use auto-detection
212231 use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
213232 and kv_cache_dtype == "auto" )
@@ -216,6 +235,8 @@ def use_trtllm_attention(
         return use_trtllm
 
     # Environment variable is set to 1 - respect it
+    logger.info_once(
+        "Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
     return True
 
 
@@ -367,6 +388,7 @@ def flashinfer_disable_q_quantization() -> bool:
367388 "has_nvidia_artifactory" ,
368389 "supports_trtllm_attention" ,
369390 "use_trtllm_attention" ,
391+ "flashinfer_disable_q_quantization" ,
370392 "flashinfer_scaled_fp4_mm" ,
371393 "flashinfer_scaled_fp8_mm" ,
372394]
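For reference, a hedged sketch of a call site: the attention backend would consult `use_trtllm_attention()` per batch, but that wiring is outside this diff. The argument values below are made up for illustration, the keyword names follow the variables used in the function body, and the import path assumes the helpers live in `vllm.utils.flashinfer`.

```python
import torch

from vllm.utils.flashinfer import use_trtllm_attention  # import path assumed

# Hypothetical per-batch values; a real backend derives these from the model
# config and the current batch. Note that is_prefill is no longer passed.
prefer_trtllm = use_trtllm_attention(
    num_qo_heads=32,
    num_kv_heads=8,
    num_tokens=128,
    max_seq_len=8192,
    kv_cache_dtype="auto",
    q_dtype=torch.bfloat16,
    has_sinks=False,
)
print("Use TRTLLM attention:", prefer_trtllm)
```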