
Commit 6360c9c

comaniac authored and rshaw@neuralmagic.com committed
[Misc] Take user preference in attention selector (vllm-project#4960)
1 parent 2cffbda commit 6360c9c

3 files changed: +169, -61 lines
Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
import os
from unittest.mock import patch

import pytest
import torch

from vllm.attention.selector import which_attn_to_use


@pytest.mark.parametrize(
    "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_env(name: str, device: str):
    """Test that the attention selector can be set via environment variable.
    Note that we do not test FlashAttn because it is the default backend.
    """
    name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
    os.environ["VLLM_ATTENTION_BACKEND"] = name

    if device == "cpu":
        with patch("vllm.attention.selector.is_cpu", return_value=True):
            backend = which_attn_to_use(8, 16, 8, None, torch.float16,
                                        torch.float16, 16)
        assert backend.name == "TORCH_SDPA"
    elif device == "hip":
        with patch("vllm.attention.selector.is_hip", return_value=True):
            backend = which_attn_to_use(8, 16, 8, None, torch.float16,
                                        torch.float16, 16)
        assert backend.name == "ROCM_FLASH"
    else:
        backend = which_attn_to_use(8, 16, 8, None, torch.float16,
                                    torch.float16, 16)
        assert backend.name == name

    if name_backup is not None:
        os.environ["VLLM_ATTENTION_BACKEND"] = name_backup


def test_flash_attn():
    """Test FlashAttn validation."""
    name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
    os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"

    # Unsupported CUDA arch
    with patch("torch.cuda.get_device_capability", return_value=[7, 5]):
        backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
    assert backend.name != "FLASH_ATTN"

    # Unsupported data type
    backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16)
    assert backend.name != "FLASH_ATTN"

    # Unsupported kv cache data type
    backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16)
    assert backend.name != "FLASH_ATTN"

    # Unsupported block size
    backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8)
    assert backend.name != "FLASH_ATTN"

    # Unsupported sliding window
    backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16)
    assert backend.name != "FLASH_ATTN"

    # flash-attn is not installed
    with patch.dict('sys.modules', {'vllm_flash_attn': None}):
        backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
    assert backend.name != "FLASH_ATTN"

    # Unsupported head size
    backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
    assert backend.name != "FLASH_ATTN"

    if name_backup is not None:
        os.environ["VLLM_ATTENTION_BACKEND"] = name_backup


def test_invalid_env():
    """Throw an exception if the backend name is invalid."""
    name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
    os.environ["VLLM_ATTENTION_BACKEND"] = "INVALID"
    with pytest.raises(ValueError):
        which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
    os.environ["VLLM_ATTENTION_BACKEND"] = name_backup
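
A brief usage note on what these tests exercise: a user states a backend preference through the VLLM_ATTENTION_BACKEND environment variable before the selector runs, and the value must match a _Backend member name exactly. A minimal sketch, assuming a CUDA machine with xFormers installed and the same which_attn_to_use signature used in the tests above (num_heads, head_size, num_kv_heads, sliding_window, dtype, kv_cache_dtype, block_size):

import os

# The preference is read from the environment, so it must be set before the
# selector is invoked; the name is case-sensitive.
os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"

import torch

from vllm.attention.selector import which_attn_to_use

# Same positional arguments as in the tests above.
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
print(backend.name)  # Expected: "XFORMERS" on a CUDA GPU.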

vllm/attention/backends/flashinfer.py

Lines changed: 1 addition & 0 deletions
@@ -218,6 +218,7 @@ def forward(
             )
 
         if prefill_meta := attn_metadata.prefill_metadata:
+            # Prompt run.
             assert prefill_meta.block_tables is not None
             if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                 output = flash_attn_varlen_func(

vllm/attention/selector.py

Lines changed: 84 additions & 61 deletions
@@ -30,24 +30,16 @@ def get_attn_backend(
     kv_cache_dtype: Optional[str],
     block_size: int,
 ) -> Type[AttentionBackend]:
-    backend = _which_attn_to_use(num_heads, head_size, num_kv_heads,
-                                 sliding_window, dtype, kv_cache_dtype,
-                                 block_size)
+    """Determine which attention backend to use and only import
+    the selected backend module.
+    """
+    backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
+                                sliding_window, dtype, kv_cache_dtype,
+                                block_size)
     if backend == _Backend.FLASH_ATTN:
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
-
-        # We check it here not in _which_attn_to_use because we cannot know
-        # the head size until we import FlashAttentionBackend.
-        supported_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
-        if head_size in supported_head_sizes:
-            logger.info("Using FlashAttention-2 backend.")
-            return FlashAttentionBackend
-        logger.info(
-            "Cannot use FlashAttention-2 backend for head size %d. "
-            "Using XFormers backend instead.", head_size)
-        backend = _Backend.XFORMERS
-
+        return FlashAttentionBackend
     if backend == _Backend.XFORMERS:
         logger.info("Using XFormers backend.")
         from vllm.attention.backends.xformers import (  # noqa: F401
@@ -64,14 +56,15 @@ def get_attn_backend(
         return TorchSDPABackend
     elif backend == _Backend.FLASHINFER:
         logger.info("Using Flashinfer backend.")
-        logger.warning("Eager mode is enforced for the Flashinfer backend.")
+        logger.warning("Eager mode is required for the Flashinfer backend. "
+                       "Please make sure --enforce-eager is set.")
         from vllm.attention.backends.flashinfer import FlashInferBackend
         return FlashInferBackend
     else:
         raise ValueError("Invalid attention backend.")
 
 
-def _which_attn_to_use(
+def which_attn_to_use(
     num_heads: int,
     head_size: int,
     num_kv_heads: int,
@@ -81,54 +74,84 @@ def _which_attn_to_use(
     block_size: int,
 ) -> _Backend:
     """Returns which flash attention backend to use."""
+
+    # Default case.
+    selected_backend = _Backend.FLASH_ATTN
+
+    # Check the environment variable and override if specified
+    backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
+    if backend_by_env_var is not None:
+        backend_members = _Backend.__members__
+        if backend_by_env_var not in backend_members:
+            raise ValueError(
+                f"Invalid attention backend '{backend_by_env_var}'. "
+                f"Available backends: {', '.join(backend_members)} "
+                "(case-sensitive).")
+        selected_backend = _Backend[backend_by_env_var]
+
     if is_cpu():
+        if selected_backend != _Backend.TORCH_SDPA:
+            logger.info("Cannot use %s backend on CPU.", selected_backend)
         return _Backend.TORCH_SDPA
 
     if is_hip():
         # AMD GPUs.
-        if torch.cuda.get_device_capability()[0] != 9:
-            # not Instinct series GPUs.
-            logger.info("flash_atten is not supported on NAVI GPUs.")
+        selected_backend = (_Backend.ROCM_FLASH if selected_backend
+                            == _Backend.FLASH_ATTN else selected_backend)
+        if selected_backend == _Backend.ROCM_FLASH:
+            if torch.cuda.get_device_capability()[0] != 9:
+                # not Instinct series GPUs.
+                logger.info("flash_attn is not supported on NAVI GPUs.")
+        else:
+            logger.info("%s is not supported in AMD GPUs.", selected_backend)
         return _Backend.ROCM_FLASH
 
-    # NVIDIA GPUs.
-    if torch.cuda.get_device_capability()[0] < 8:
-        # Volta and Turing NVIDIA GPUs.
-        logger.info("Cannot use FlashAttention-2 backend for Volta and Turing "
-                    "GPUs.")
-        return _Backend.XFORMERS
-
-    if dtype not in (torch.float16, torch.bfloat16):
-        logger.info("Cannot use FlashAttention-2 backend for dtype other than "
-                    "torch.float16 or torch.bfloat16.")
-        return _Backend.XFORMERS
-
-    if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
-        logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.")
-        return _Backend.XFORMERS
-
-    if block_size % 16 != 0:
-        logger.info("Cannot use FlashAttention-2 backend for block size not "
-                    "divisible by 16.")
-        return _Backend.XFORMERS
-
-    if sliding_window is not None:
-        logger.info(
-            "Cannot use FlashAttention-2 backend due to sliding window.")
-        return _Backend.XFORMERS
-
-    try:
-        import vllm_flash_attn  # noqa: F401
-    except ImportError:
-        logger.info(
-            "Cannot use FlashAttention-2 backend because the vllm_flash_attn "
-            "package is not found. `pip install vllm-flash-attn` for better "
-            "performance.")
-        return _Backend.XFORMERS
-
-    backend_by_env_var = envs.VLLM_ATTENTION_BACKEND
-    if backend_by_env_var is not None:
-        return _Backend[backend_by_env_var]
-
-    # Default case.
-    return _Backend.FLASH_ATTN
+    # FlashAttn in NVIDIA GPUs.
+    if selected_backend == _Backend.FLASH_ATTN:
+        if torch.cuda.get_device_capability()[0] < 8:
+            # Volta and Turing NVIDIA GPUs.
+            logger.info(
+                "Cannot use FlashAttention-2 backend for Volta and Turing "
+                "GPUs.")
+            selected_backend = _Backend.XFORMERS
+        elif dtype not in (torch.float16, torch.bfloat16):
+            logger.info(
+                "Cannot use FlashAttention-2 backend for dtype other than "
+                "torch.float16 or torch.bfloat16.")
+            selected_backend = _Backend.XFORMERS
+        elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
+            logger.info(
+                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
+            selected_backend = _Backend.XFORMERS
+        elif block_size % 16 != 0:
+            logger.info(
+                "Cannot use FlashAttention-2 backend for block size not "
+                "divisible by 16.")
+            selected_backend = _Backend.XFORMERS
+        elif sliding_window is not None:
+            logger.info(
+                "Cannot use FlashAttention-2 backend due to sliding window.")
+            selected_backend = _Backend.XFORMERS
+
+    # FlashAttn is valid for the model, checking if the package is installed.
+    if selected_backend == _Backend.FLASH_ATTN:
+        try:
+            import vllm_flash_attn  # noqa: F401
+
+            from vllm.attention.backends.flash_attn import (  # noqa: F401
+                FlashAttentionBackend)
+
+            supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
+            if head_size not in supported_sizes:
+                logger.info(
+                    "Cannot use FlashAttention-2 backend for head size %d.",
+                    head_size)
+                selected_backend = _Backend.XFORMERS
+        except ImportError:
+            logger.info(
+                "Cannot use FlashAttention-2 backend because the "
+                "vllm_flash_attn package is not found. "
+                "`pip install vllm-flash-attn` for better performance.")
+            selected_backend = _Backend.XFORMERS
+
+    return selected_backend
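
Taken together, the refactored which_attn_to_use starts from the user's preference and only afterwards lets platform and model constraints downgrade it, whereas the old code consulted the environment variable last. A simplified standalone sketch of that precedence, using hypothetical names (select_backend, on_cpu, on_hip, flash_attn_supported) that are not part of vLLM; the real function additionally logs why a preference was overridden:

import os
from enum import Enum, auto


class Backend(Enum):
    FLASH_ATTN = auto()
    XFORMERS = auto()
    ROCM_FLASH = auto()
    TORCH_SDPA = auto()
    FLASHINFER = auto()


def select_backend(on_cpu: bool, on_hip: bool,
                   flash_attn_supported: bool) -> Backend:
    """Illustrative only: mirrors the selection order introduced above."""
    # 1. Default, then let the (case-sensitive) user preference override it.
    selected = Backend.FLASH_ATTN
    env = os.environ.get("VLLM_ATTENTION_BACKEND")
    if env is not None:
        if env not in Backend.__members__:
            raise ValueError(f"Invalid attention backend '{env}'.")
        selected = Backend[env]

    # 2. Platform constraints win over the preference.
    if on_cpu:
        return Backend.TORCH_SDPA
    if on_hip:
        # The HIP path always ends in ROCM_FLASH, as in the diff above.
        return Backend.ROCM_FLASH

    # 3. Model/feature constraints can still downgrade FLASH_ATTN.
    if selected == Backend.FLASH_ATTN and not flash_attn_supported:
        selected = Backend.XFORMERS
    return selected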
