Skip to content

Commit 8ec62d7

Browse files
committed
address comment
Signed-off-by: elvischenv <[email protected]>
1 parent 7b8fb4a commit 8ec62d7

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

vllm/envs.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,9 +1170,12 @@ def get_vllm_port() -> Optional[int]:
11701170
"VLLM_USE_CUDNN_PREFILL":
11711171
lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))),
11721172

1173-
# If set to 1, use the TRTLLM attention backend in flashinfer.
1173+
# If set to 1/True, use the TRTLLM attention backend in flashinfer.
1174+
# If set to 0/False, use the default attention backend in flashinfer.
1175+
# If not set, auto-detect the attention backend in flashinfer.
11741176
"VLLM_USE_TRTLLM_ATTENTION":
1175-
lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
1177+
lambda: (None if "VLLM_USE_TRTLLM_ATTENTION" not in os.environ else
1178+
os.environ["VLLM_USE_TRTLLM_ATTENTION"].lower() in ("1", "true")),
11761179

11771180
# If set to 1, when we use fp8 kv, we do not quantize Q to fp8
11781181
"VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION":

vllm/utils/flashinfer.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,16 +165,20 @@ def supports_trtllm_attention() -> bool:
165165

166166

167167
@functools.cache
168+
def _force_use_trtllm_attention(env_value: Optional[bool]) -> Optional[bool]:
169+
"""Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
170+
if env_value is not None:
171+
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
172+
return env_value
173+
174+
168175
def force_use_trtllm_attention() -> Optional[bool]:
169176
"""
170177
Return ``None`` if VLLM_USE_TRTLLM_ATTENTION is not set,
171178
return ``True`` if TRTLLM attention is forced to be used,
172179
return ``False`` if TRTLLM attention is forced to be not used.
173180
"""
174-
env_value = envs.VLLM_USE_TRTLLM_ATTENTION
175-
if env_value is not None:
176-
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
177-
return None if env_value is None else env_value == "1"
181+
return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
178182

179183

180184
def use_trtllm_attention(

0 commit comments

Comments (0)