File tree Expand file tree Collapse file tree 2 files changed +10
-2
lines changed Expand file tree Collapse file tree 2 files changed +10
-2
lines changed Original file line number Diff line number Diff line change @@ -26,7 +26,7 @@ def clear_cache():
2626@pytest .mark .parametrize ("device" , ["cpu" , "hip" , "cuda" ])
2727def test_mha_attn_platform (device : str ):
2828 """
29- Test that the attention selector between different platform and device.
29+ Test the attention selector between different platform and device.
3030 """
3131 torch .set_default_dtype (torch .float16 )
3232
@@ -41,7 +41,7 @@ def test_mha_attn_platform(device: str):
4141 else :
4242 with patch ("vllm.attention.selector.current_platform" , CudaPlatform ()):
4343 attn = MultiHeadAttention (16 , 64 , scale = 1 )
44- assert attn .attn_backend == _Backend .FLASH_ATTN
44+ assert attn .attn_backend == _Backend .XFORMERS
4545
4646 with patch ("vllm.attention.selector.current_platform" , CudaPlatform ()):
4747 attn = MultiHeadAttention (16 , 72 , scale = 1 )
Original file line number Diff line number Diff line change @@ -210,6 +210,9 @@ def __init__(
210210 self .scale = scale
211211 self .num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
212212
213+ assert self .num_heads % self .num_kv_heads == 0
214+ self .num_queries_per_kv = self .num_heads // self .num_kv_heads
215+
213216 dtype = torch .get_default_dtype ()
214217 attn_backend = get_attn_backend (head_size ,
215218 dtype ,
@@ -240,6 +243,11 @@ def forward(
240243 key = key .view (bsz , kv_len , self .num_kv_heads , self .head_size )
241244 value = value .view (bsz , kv_len , self .num_kv_heads , self .head_size )
242245
246+ if (num_repeat := self .num_queries_per_kv ) > 1 :
247+ # Handle MQA and GQA
248+ key = torch .repeat_interleave (key , num_repeat , dim = 2 )
249+ value = torch .repeat_interleave (value , num_repeat , dim = 2 )
250+
243251 if self .attn_backend == _Backend .XFORMERS :
244252 from xformers import ops as xops
245253
You can’t perform that action at this time.
0 commit comments