From e4005821fc8c20be194d55cedf4362dc1b39bea3 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 18 Dec 2024 12:38:15 +0000 Subject: [PATCH 1/4] add alibi test Signed-off-by: NickLucche --- tests/kernels/test_attention.py | 67 ++++++++++++++++++++++------- vllm/attention/backends/xformers.py | 2 - 2 files changed, 51 insertions(+), 18 deletions(-) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index fc549d7a7c18..fa766c08509e 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -17,6 +17,8 @@ from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask + from vllm.attention.backends.xformers import _make_alibi_bias + FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. # - 512 as a buffer @@ -56,6 +58,7 @@ def ref_masked_attention( ) -> torch.Tensor: attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() if attn_mask is not None: + print("MASK", attn_mask.shape, attn_weights.shape) attn_weights = attn_weights + attn_mask.float() attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) out = torch.einsum("hqk,khd->qhd", attn_weights, value) @@ -345,20 +348,26 @@ def ref_multi_query_kv_attention( key: torch.Tensor, value: torch.Tensor, scale: float, + alibi_bias: Optional[list[torch.Tensor]], dtype: torch.dtype, ) -> torch.Tensor: num_seqs = len(cu_seq_lens) - 1 ref_outputs: list[torch.Tensor] = [] + if alibi_bias: + assert len(alibi_bias) == num_seqs for i in range(num_seqs): start_idx = cu_seq_lens[i] end_idx = cu_seq_lens[i + 1] seq_len = end_idx - start_idx - # Create attention mask. - attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), - diagonal=1) - attn_mask = attn_mask * torch.finfo(dtype).min - attn_mask = attn_mask.to(dtype=dtype) + # Create attention mask. ALiBi already includes a tril causal mask. + if alibi_bias: + attn_mask = alibi_bias[i] + else: + attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), + diagonal=1) + attn_mask = attn_mask * torch.finfo(dtype).min + attn_mask = attn_mask.to(dtype=dtype) ref_output = ref_masked_attention( query[start_idx:end_idx], @@ -372,10 +381,10 @@ def ref_multi_query_kv_attention( return torch.cat(ref_outputs, dim=0) -# TODO(woosuk): Add tests for USE_ALIBI=True. 
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("use_alibi", USE_ALIBI) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -386,6 +395,7 @@ def test_multi_query_kv_attention( num_seqs: int, num_heads: tuple[int, int], head_size: int, + use_alibi: bool, dtype: torch.dtype, seed: int, device: str, @@ -414,16 +424,40 @@ def test_multi_query_kv_attention( # Handle MQA and GQA key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) - attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens) - output = xops.memory_efficient_attention_forward( - query.unsqueeze(0), - key.unsqueeze(0), - value.unsqueeze(0), - attn_bias=attn_bias, - p=0.0, - scale=scale, - ) - output = output.squeeze(0) + alibi_bias = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) + attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, + seq_lens) + output = torch.empty_like(query) + start = 0 + # Dynamic sequence length not supported with custom attn_bias. + for i, seq_len in enumerate(seq_lens): + end = start + seq_len + out = xops.memory_efficient_attention_forward( + query[None, start:end], + key[None, start:end], + value[None, start:end], + attn_bias=attn_bias[i], + p=0.0, + scale=scale) + output[start:end].copy_(out.view_as(query[start:end])) + start += seq_len + # xformers.AttentionBias to Tensor for use in reference impl. + alibi_bias = [ + b.materialize(b.shape, device=device).squeeze() for b in attn_bias + ] + else: + attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens) + output = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + attn_bias=attn_bias, + p=0.0, + scale=scale, + ) + output = output.squeeze(0) cu_seq_lens = [0] for seq_len in seq_lens: @@ -434,6 +468,7 @@ def test_multi_query_kv_attention( key, value, scale, + alibi_bias, dtype, ) atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3 diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 9fa76634e1fc..14c94c9ac4ca 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -788,8 +788,6 @@ def _make_alibi_bias( dtype=dtype, )[:, :, :, :seq_len].copy_(bias) bias.mul_(alibi_slopes[:, None, None]) - if num_heads != num_kv_heads: - bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) attn_biases.append(LowerTriangularMaskWithTensorBias(bias)) return attn_biases From 6cc676d7ed5a97dbf03201a82668ea105de21dfd Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 18 Dec 2024 12:46:10 +0000 Subject: [PATCH 2/4] clean up Signed-off-by: NickLucche --- tests/kernels/test_attention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index fa766c08509e..d9daa7685165 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -58,7 +58,6 @@ def ref_masked_attention( ) -> torch.Tensor: attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() if attn_mask is not None: - print("MASK", attn_mask.shape, attn_weights.shape) attn_weights = attn_weights + attn_mask.float() attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) out = torch.einsum("hqk,khd->qhd", 
attn_weights, value) From 44b839746e8bbe7cd58ce3931900b2a426ab5899 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 22 Jan 2025 09:43:17 +0000 Subject: [PATCH 3/4] separate alibi test for lighter load Signed-off-by: NickLucche --- tests/kernels/test_attention.py | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index d9daa7685165..0d7898a900e4 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -383,7 +383,6 @@ def ref_multi_query_kv_attention( @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("use_alibi", USE_ALIBI) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -394,10 +393,10 @@ def test_multi_query_kv_attention( num_seqs: int, num_heads: tuple[int, int], head_size: int, - use_alibi: bool, dtype: torch.dtype, seed: int, device: str, + use_alibi: bool = False, ) -> None: current_platform.seed_everything(seed) torch.set_default_device(device) @@ -472,4 +471,32 @@ def test_multi_query_kv_attention( ) atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3 rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5 - torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) \ No newline at end of file + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.skipif(current_platform.is_rocm(), + reason="Xformers backend is not supported on ROCm.") +@torch.inference_mode() +def test_multi_query_kv_attention_with_alibi( + num_seqs: int, + num_heads: tuple[int, int], + head_size: int, + dtype: torch.dtype, + seed: int, + device: str, +) -> None: + return test_multi_query_kv_attention( + num_seqs, + num_heads, + head_size, + dtype, + seed, + device, + use_alibi=True, + ) From 230f022caaddd23b7df17293ab573908493ebf0b Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 5 Mar 2025 16:59:58 +0000 Subject: [PATCH 4/4] fix old use of alibi bias matrix Signed-off-by: NickLucche --- tests/kernels/test_prefix_prefill.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index c3ac6a37e717..f2c7f2c809e8 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -439,14 +439,16 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: # heads. # # see also: vllm/model_executor/layers/attention.py - query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv, - query.shape[-1]) key = key[:, :, None, :].expand(key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]) value = value[:, :, None, :].expand(value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]) - + # [seq, num_kv_heads, num_queries_per_kv, dk]=> + # [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the + # codebase. We save some time reshaping alibi matrix at runtime. 
+ key = key.reshape(key.shape[0], -1, key.shape[-1]) + value = value.reshape(value.shape[0], -1, value.shape[-1]) query = query.unsqueeze(0) key = key.unsqueeze(0) value = value.unsqueeze(0)
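
---

Note (not part of the patches above): the ALiBi bias consumed by ref_masked_attention in PATCH 1/4 is a per-sequence tensor of shape [num_heads, seq_len, seq_len] that already encodes the causal mask, obtained in the test via _make_alibi_bias plus b.materialize(b.shape, device=device).squeeze(). The sketch below is an illustrative, self-contained approximation of that construction for readers checking the expected shapes; the helper names are hypothetical, and the test itself uses torch.randn slopes rather than the canonical ALiBi schedule shown here.

import math

import torch


def alibi_slopes(num_heads: int) -> torch.Tensor:
    # Canonical ALiBi slopes: a geometric sequence 2^(-8/n), 2^(-16/n), ...
    # for the closest power of two n, padded with interleaved values when
    # num_heads is not a power of two.
    n = 2 ** math.floor(math.log2(num_heads))
    base = 2.0 ** (-8.0 / n)
    slopes = [base ** (i + 1) for i in range(n)]
    if n != num_heads:
        extra_base = 2.0 ** (-4.0 / n)
        slopes += [extra_base ** (2 * i + 1) for i in range(num_heads - n)]
    return torch.tensor(slopes, dtype=torch.float)


def alibi_bias(slopes: torch.Tensor, seq_len: int,
               dtype: torch.dtype) -> torch.Tensor:
    # distance[i, j] = j - i, scaled per head and masked above the diagonal,
    # giving the [num_heads, seq_len, seq_len] tensor that the reference
    # implementation adds to the attention scores.
    pos = torch.arange(seq_len)
    distance = pos[None, :] - pos[:, None]
    bias = slopes[:, None, None].float() * distance[None].float()
    causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool),
                        diagonal=1)
    return bias.masked_fill(causal, torch.finfo(dtype).min).to(dtype)

With hypothetical helpers like these, alibi_bias(alibi_slopes(num_heads), seq_len, dtype) produces a tensor comparable to the materialized per-sequence bias used in the test, which can be handy when debugging shape mismatches between the xformers path and the reference path.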