@@ -30,24 +30,16 @@ def get_attn_backend(
     kv_cache_dtype: Optional[str],
     block_size: int,
 ) -> Type[AttentionBackend]:
-    backend = _which_attn_to_use(num_heads, head_size, num_kv_heads,
-                                 sliding_window, dtype, kv_cache_dtype,
-                                 block_size)
+    """Determine which attention backend to use and only import
+    the selected backend module.
+    """
+    backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
+                                sliding_window, dtype, kv_cache_dtype,
+                                block_size)
     if backend == _Backend.FLASH_ATTN:
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
-
-        # We check it here not in _which_attn_to_use because we cannot know
-        # the head size until we import FlashAttentionBackend.
-        supported_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
-        if head_size in supported_head_sizes:
-            logger.info("Using FlashAttention-2 backend.")
-            return FlashAttentionBackend
-        logger.info(
-            "Cannot use FlashAttention-2 backend for head size %d. "
-            "Using XFormers backend instead.", head_size)
-        backend = _Backend.XFORMERS
-
+        return FlashAttentionBackend
     if backend == _Backend.XFORMERS:
         logger.info("Using XFormers backend.")
         from vllm.attention.backends.xformers import (  # noqa: F401
@@ -64,14 +56,15 @@ def get_attn_backend(
         return TorchSDPABackend
     elif backend == _Backend.FLASHINFER:
         logger.info("Using Flashinfer backend.")
-        logger.warning("Eager mode is enforced for the Flashinfer backend.")
+        logger.warning("Eager mode is required for the Flashinfer backend. "
+                       "Please make sure --enforce-eager is set.")
         from vllm.attention.backends.flashinfer import FlashInferBackend
         return FlashInferBackend
     else:
         raise ValueError("Invalid attention backend.")


-def _which_attn_to_use(
+def which_attn_to_use(
     num_heads: int,
     head_size: int,
     num_kv_heads: int,
@@ -81,54 +74,84 @@ def _which_attn_to_use(
     block_size: int,
 ) -> _Backend:
     """Returns which flash attention backend to use."""
+
+    # Default case.
+    selected_backend = _Backend.FLASH_ATTN
+
+    # Check the environment variable and override if specified
+    backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
+    if backend_by_env_var is not None:
+        backend_members = _Backend.__members__
+        if backend_by_env_var not in backend_members:
+            raise ValueError(
+                f"Invalid attention backend '{backend_by_env_var}'. "
+                f"Available backends: {', '.join(backend_members)} "
+                "(case-sensitive).")
+        selected_backend = _Backend[backend_by_env_var]
+
     if is_cpu():
+        if selected_backend != _Backend.TORCH_SDPA:
+            logger.info("Cannot use %s backend on CPU.", selected_backend)
         return _Backend.TORCH_SDPA

     if is_hip():
         # AMD GPUs.
-        if torch.cuda.get_device_capability()[0] != 9:
-            # not Instinct series GPUs.
-            logger.info("flash_atten is not supported on NAVI GPUs.")
+        selected_backend = (_Backend.ROCM_FLASH if selected_backend
+                            == _Backend.FLASH_ATTN else selected_backend)
+        if selected_backend == _Backend.ROCM_FLASH:
+            if torch.cuda.get_device_capability()[0] != 9:
+                # not Instinct series GPUs.
+                logger.info("flash_attn is not supported on NAVI GPUs.")
+        else:
+            logger.info("%s is not supported in AMD GPUs.", selected_backend)
         return _Backend.ROCM_FLASH

-    # NVIDIA GPUs.
-    if torch.cuda.get_device_capability()[0] < 8:
-        # Volta and Turing NVIDIA GPUs.
-        logger.info("Cannot use FlashAttention-2 backend for Volta and Turing "
-                    "GPUs.")
-        return _Backend.XFORMERS
-
-    if dtype not in (torch.float16, torch.bfloat16):
-        logger.info("Cannot use FlashAttention-2 backend for dtype other than "
-                    "torch.float16 or torch.bfloat16.")
-        return _Backend.XFORMERS
-
-    if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
-        logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.")
-        return _Backend.XFORMERS
-
-    if block_size % 16 != 0:
-        logger.info("Cannot use FlashAttention-2 backend for block size not "
-                    "divisible by 16.")
-        return _Backend.XFORMERS
-
-    if sliding_window is not None:
-        logger.info(
-            "Cannot use FlashAttention-2 backend due to sliding window.")
-        return _Backend.XFORMERS
-
-    try:
-        import vllm_flash_attn  # noqa: F401
-    except ImportError:
-        logger.info(
-            "Cannot use FlashAttention-2 backend because the vllm_flash_attn "
-            "package is not found. `pip install vllm-flash-attn` for better "
-            "performance.")
-        return _Backend.XFORMERS
-
-    backend_by_env_var = envs.VLLM_ATTENTION_BACKEND
-    if backend_by_env_var is not None:
-        return _Backend[backend_by_env_var]
-
-    # Default case.
-    return _Backend.FLASH_ATTN
+    # FlashAttn in NVIDIA GPUs.
+    if selected_backend == _Backend.FLASH_ATTN:
+        if torch.cuda.get_device_capability()[0] < 8:
+            # Volta and Turing NVIDIA GPUs.
+            logger.info(
+                "Cannot use FlashAttention-2 backend for Volta and Turing "
+                "GPUs.")
+            selected_backend = _Backend.XFORMERS
+        elif dtype not in (torch.float16, torch.bfloat16):
+            logger.info(
+                "Cannot use FlashAttention-2 backend for dtype other than "
+                "torch.float16 or torch.bfloat16.")
+            selected_backend = _Backend.XFORMERS
+        elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
+            logger.info(
+                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
+            selected_backend = _Backend.XFORMERS
+        elif block_size % 16 != 0:
+            logger.info(
+                "Cannot use FlashAttention-2 backend for block size not "
+                "divisible by 16.")
+            selected_backend = _Backend.XFORMERS
+        elif sliding_window is not None:
+            logger.info(
+                "Cannot use FlashAttention-2 backend due to sliding window.")
+            selected_backend = _Backend.XFORMERS
+
+    # FlashAttn is valid for the model, checking if the package is installed.
+    if selected_backend == _Backend.FLASH_ATTN:
+        try:
+            import vllm_flash_attn  # noqa: F401
+
+            from vllm.attention.backends.flash_attn import (  # noqa: F401
+                FlashAttentionBackend)
+
+            supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
+            if head_size not in supported_sizes:
+                logger.info(
+                    "Cannot use FlashAttention-2 backend for head size %d.",
+                    head_size)
+                selected_backend = _Backend.XFORMERS
+        except ImportError:
+            logger.info(
+                "Cannot use FlashAttention-2 backend because the "
+                "vllm_flash_attn package is not found. "
+                "`pip install vllm-flash-attn` for better performance.")
+            selected_backend = _Backend.XFORMERS
+
+    return selected_backend
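
For reference, a minimal standalone sketch of the environment-variable override this patch adds: `VLLM_ATTENTION_BACKEND` must name a `_Backend` member exactly (case-sensitive), and unknown names now raise a `ValueError` up front. The enum below and the `backend_from_env` helper are illustrative stand-ins assumed for this sketch, not vLLM's actual API; only the member names and the validation behaviour are taken from the diff.

```python
import enum
import os


class _Backend(enum.Enum):
    # Stand-in for vllm's _Backend enum; member names are taken from the diff.
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()
    ROCM_FLASH = enum.auto()
    TORCH_SDPA = enum.auto()
    FLASHINFER = enum.auto()


def backend_from_env(default: _Backend = _Backend.FLASH_ATTN) -> _Backend:
    # Hypothetical helper mirroring the override logic in which_attn_to_use().
    name = os.environ.get("VLLM_ATTENTION_BACKEND")
    if name is None:
        # Unset: keep the default (FLASH_ATTN in the patched code).
        return default
    if name not in _Backend.__members__:
        # Invalid names raise immediately instead of failing later.
        raise ValueError(
            f"Invalid attention backend '{name}'. "
            f"Available backends: {', '.join(_Backend.__members__)} "
            "(case-sensitive).")
    return _Backend[name]


if __name__ == "__main__":
    os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"
    print(backend_from_env())  # _Backend.XFORMERS
```

With the patched selector, running with `VLLM_ATTENTION_BACKEND=XFORMERS` forces the XFormers path on NVIDIA GPUs even where FlashAttention-2 would otherwise be chosen, since the FlashAttention-specific eligibility checks only run when the selected backend is still `FLASH_ATTN`.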