
Commit b237fa5

enable alibi in pagedattention
1 parent 8d2b8b9 commit b237fa5

6 files changed, +91 -23 lines changed

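In short, the commit threads an optional per-head alibi_slopes tensor through the CUDA flash_decoding_attention path: the kernel, its launcher, and the pybind declaration gain an optional argument, the Baichuan decode path passes self.alibi_slopes, the LLaMA path passes None, and the benchmark and CUDA tests are extended to cover both cases. The bias itself is the standard ALiBi term: for the single query token of a decoding step, head h adds slope[h] * (token_idx - context_len + 1) to each pre-softmax score, which is exactly the expression added to the kernel below. A minimal PyTorch reference of that bias, with illustrative names not taken from the repo:

import torch

def alibi_decode_bias(slopes: torch.Tensor, context_len: int) -> torch.Tensor:
    # Reference ALiBi bias for one decoding step.
    # slopes: float32 tensor of shape [num_heads]; returns [num_heads, context_len].
    # Mirrors the kernel's qk += alibi_slope * (token_idx - context_len + 1):
    # zero for the newest token, increasingly negative for older tokens.
    token_idx = torch.arange(context_len, dtype=torch.float32)
    return slopes.float()[:, None] * (token_idx - context_len + 1)[None, :]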

colossalai/inference/modeling/models/nopadding_baichuan.py

Lines changed: 30 additions & 15 deletions
@@ -253,6 +253,21 @@ def forward(
                 inference_ops.decode_kv_cache_memcpy(
                     key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
                 )
+                inference_ops.flash_decoding_attention(
+                    output_tensor,
+                    query_states,
+                    k_cache,
+                    v_cache,
+                    sequence_lengths,
+                    block_tables,
+                    block_size,
+                    kv_seq_len,
+                    fd_inter_tensor.mid_output,
+                    fd_inter_tensor.mid_output_lse,
+                    self.alibi_slopes,
+                    sm_scale,
+                )
+                attn_output = output_tensor
             else:
                 if not is_verifier and not self.use_alibi_attn:
                     decoding_fused_rotary_embedding(
@@ -276,21 +291,21 @@ def forward(
                         value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
                     )
 
-            attn_output = flash_decoding_attention(
-                q=query_states,
-                k_cache=k_cache,
-                v_cache=v_cache,
-                kv_seq_len=sequence_lengths,
-                block_tables=block_tables,
-                block_size=block_size,
-                max_seq_len_in_batch=kv_seq_len,
-                output=output_tensor,
-                mid_output=fd_inter_tensor.mid_output,
-                mid_output_lse=fd_inter_tensor.mid_output_lse,
-                alibi_slopes=self.alibi_slopes,
-                sm_scale=sm_scale,
-                q_len=q_len,
-            )
+                attn_output = flash_decoding_attention(
+                    q=query_states,
+                    k_cache=k_cache,
+                    v_cache=v_cache,
+                    kv_seq_len=sequence_lengths,
+                    block_tables=block_tables,
+                    block_size=block_size,
+                    max_seq_len_in_batch=kv_seq_len,
+                    output=output_tensor,
+                    mid_output=fd_inter_tensor.mid_output,
+                    mid_output_lse=fd_inter_tensor.mid_output_lse,
+                    alibi_slopes=self.alibi_slopes,
+                    sm_scale=sm_scale,
+                    q_len=q_len,
+                )
 
         attn_output = attn_output.view(-1, self.hidden_size)
         attn_output = torch.mm(attn_output, self.o_proj_weight)
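For reference, self.alibi_slopes used above, and the get_alibi_slopes helper that the new test imports from nopadding_baichuan, follow the standard ALiBi slope schedule: a geometric series over heads, with an interleaved extension when the head count is not a power of two. The commit does not show that helper, so the sketch below illustrates the usual recipe rather than the repo's exact code:

import math

import torch

def get_alibi_slopes_sketch(num_heads: int, device: torch.device) -> torch.Tensor:
    # Standard ALiBi slope schedule (Press et al.): slopes form a geometric
    # series; heads beyond the nearest power of two reuse interleaved slopes.
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
    slopes = torch.pow(base, torch.arange(1, closest_power_of_2 + 1, device=device, dtype=torch.float32))
    if closest_power_of_2 != num_heads:
        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 4)))
        num_remaining = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra = torch.pow(extra_base, torch.arange(1, 2 * num_remaining + 1, 2, device=device, dtype=torch.float32))
        slopes = torch.cat([slopes, extra], dim=0)
    return slopes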

colossalai/inference/modeling/models/nopadding_llama.py

Lines changed: 1 addition & 0 deletions
@@ -597,6 +597,7 @@ def forward(
                     kv_seq_len,
                     fd_inter_tensor.mid_output,
                     fd_inter_tensor.mid_output_lse,
+                    None,
                     sm_scale,
                 )
                 attn_output = output_tensor
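On the C++ side the new argument is a c10::optional<torch::Tensor>, so models without ALiBi (LLaMA here) pass None in that positional slot, while ALiBi models pass a per-head slopes tensor; since the kernel reads the slopes through a const float*, a float32 tensor of shape [num_heads] is what the launcher expects. A thin wrapper sketch of the positional order used at these call sites (the wrapper itself is illustrative, not part of the repo):

from typing import Optional

import torch

def call_flash_decoding_attention(
    inference_ops,
    output: torch.Tensor,
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    block_size: int,
    max_context_len: int,
    mid_output: torch.Tensor,
    mid_output_lse: torch.Tensor,
    sm_scale: float,
    alibi_slopes: Optional[torch.Tensor] = None,
) -> None:
    # Mirrors the positional order used in nopadding_llama / nopadding_baichuan:
    # alibi_slopes sits between mid_output_lse and sm_scale and may be None.
    inference_ops.flash_decoding_attention(
        output, q, k_cache, v_cache, context_lens, block_tables,
        block_size, max_context_len, mid_output, mid_output_lse,
        alibi_slopes, sm_scale,
    )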

examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py

Lines changed: 2 additions & 0 deletions
@@ -113,6 +113,7 @@ def benchmark_flash_decoding_attention(
     kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE
     output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
     sm_scale = 1.0 / (HEAD_SIZE**0.5)
+    alibi_slopes = None
     kv_scale = 1.0
 
     mid_output = torch.empty(
@@ -166,6 +167,7 @@ def benchmark_flash_decoding_attention(
             max_seq_len_across_batch,
             mid_output,
             mid_output_lse,
+            alibi_slopes,
             sm_scale,
         )
     else:
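The benchmark pins alibi_slopes = None, so it times the unbiased branch of the CUDA kernel. To profile the ALiBi branch as well, one could swap in real slopes, for example by reusing the helper the new test imports (a hypothetical tweak, not part of this commit):

from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.utils import get_current_device

# Hypothetical benchmark setup that exercises the ALiBi branch instead of the
# alibi_slopes=None fast path; NUM_ATTN_HEADS matches the benchmark's head count.
NUM_ATTN_HEADS = 16
device = get_current_device()
alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device)  # per-head slopes, shape [NUM_ATTN_HEADS]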

extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu

Lines changed: 13 additions & 2 deletions
@@ -67,6 +67,7 @@ __global__ void flash_decoding_attention_kernel(
   const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, block_size, head_size]
   const int* __restrict__ context_lens,   // [num_tokens]
   const int* __restrict__ block_tables,   // [num_tokens, max_num_blocks_per_seq]
+  const float* __restrict__ alibi_slopes, // [num_heads]
   const int max_seq_len,
   const int num_kv_heads,
   const float scale,
@@ -105,6 +106,7 @@
   using FloatVecT = typename FloatVecTypeTrait<LVecT>::Type;
 
   const int context_len = context_lens[seq_idx];
+  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
   const int thread_group_offset = lane % NUM_THREADS_PER_X;
   const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
   const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
@@ -164,6 +166,7 @@
 
       if (thread_group_offset == 0 && lane < NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X) {
         const int token_idx = block_idx * BLOCK_SIZE + i * NUM_ROWS_PER_ROUNDS + lane / NUM_THREADS_PER_X;
+        qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
         const bool mask = token_idx >= context_len;
         logits[token_idx] = mask ? 0.f : qk;
         qk_max = mask ? qk_max : fmaxf(qk_max, qk);
@@ -261,6 +264,7 @@
     reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
     context_lens.data_ptr<int>(), \
     block_tables.data_ptr<int>(), \
+    alibi_slopes_ptr, \
     max_context_len, \
     num_kv_heads, \
     scale, \
@@ -282,7 +286,8 @@ void flash_decoding_attention_v1_launcher(
   torch::Tensor& context_lens,  // [num_tokens]
   torch::Tensor& block_tables,  // [num_tokens, max_num_blocks_per_seq]
   int max_context_len,
-  float scale) {
+  float scale,
+  const c10::optional<torch::Tensor>& alibi_slopes) {
   int num_tokens = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -304,6 +309,10 @@
   // Keep that in sync with the logic here!
   int shared_mem_size = std::max(logits_size, outputs_size) + DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4);
 
+  const float* alibi_slopes_ptr = alibi_slopes ?
+    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+    : nullptr;
+
   dim3 grid(num_heads, num_tokens, 1);
   dim3 block(NUM_THREADS);
   const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
@@ -336,7 +345,8 @@
     context_lens, \
     block_tables, \
     max_context_len, \
-    scale);
+    scale, \
+    alibi_slopes);
 
   // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
   // 1, 2, 4, 64, 128, 256.
@@ -367,6 +377,7 @@ void flash_decoding_attention(
   int max_context_len,
   torch::Tensor& tmp_out,      // [num_tokens, num_heads, max_num_partitions, head_size]
   torch::Tensor& tmp_out_lse,  // [num_tokens, num_heads, max_num_partitions]
+  const c10::optional<torch::Tensor>& alibi_slopes,
   float scale) {
 
 
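The kernel-side change is deliberately small: each thread group adds alibi_slope * (token_idx - context_len + 1) to its dot product before masking, softmax, and the value reduction, and alibi_slope is 0 when no slopes are supplied, so the non-ALiBi path is unchanged. A plain PyTorch sketch of what a single head then computes for one decoding token (ignoring the paged/blocked cache layout and the split-KV reduction the real kernel performs):

import torch

def decode_attention_reference(q, k, v, alibi_slope: float, scale: float) -> torch.Tensor:
    # q: [head_size]; k, v: [context_len, head_size]; returns [head_size].
    # Same math as one head of the kernel: scaled dot products, plus the ALiBi
    # term (0 at the newest token, negative for older ones), softmax, weighted sum.
    context_len = k.shape[0]
    token_idx = torch.arange(context_len, dtype=torch.float32)
    scores = scale * (k.float() @ q.float())
    scores = scores + alibi_slope * (token_idx - context_len + 1)
    probs = torch.softmax(scores, dim=-1)
    return probs @ v.float()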
extensions/pybind/inference/inference.cpp

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ void flash_decoding_attention(
     torch::Tensor&
         tmp_out,  // [num_tokens, num_heads, max_num_partitions, head_size]
     torch::Tensor& tmp_out_lse,  // [num_tokens, num_heads, max_num_partitions]
-    float scale);
+    const c10::optional<torch::Tensor>& alibi_slopes, float scale);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,

tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py

Lines changed: 44 additions & 5 deletions
@@ -4,8 +4,10 @@
 import pytest
 import torch
 
+from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.utils import get_current_device
+from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask
 
 inference_ops = InferenceOpsLoader().load()
 
@@ -60,8 +62,9 @@ def numpy_allclose(x, y, rtol, atol):
 @pytest.mark.parametrize("NUM_ATTN_HEADS", [16])
 @pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+@pytest.mark.parametrize("use_alibi_slopes", [True, False])
 def test_flash_decoding_attention(
-    BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype
+    BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype, use_alibi_slopes
 ):
     torch.manual_seed(123)
     torch.cuda.empty_cache()
@@ -73,6 +76,11 @@
     MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
     device = get_current_device()
 
+    if use_alibi_slopes:
+        alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device)
+    else:
+        alibi_slopes = None
+
     q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
         BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
     )
@@ -91,6 +99,15 @@
     v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
     torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device)
 
+    if use_alibi_slopes:
+        alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device)
+        torch_padding_mask = torch_padding_mask + alibi_mask
+
+        if len(torch_padding_mask.size()) == 4:
+            torch_padding_mask = torch_padding_mask[:, :, -1:, :]
+        else:
+            torch_padding_mask = torch_padding_mask[:, -1:, :]
+
     mid_output = torch.empty(
         size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device
     )
@@ -146,8 +163,14 @@
         max_seq_len_across_batch,
         mid_output,
         mid_output_lse,
+        alibi_slopes,
         sm_scale,
     )
+
+    # The alibi may introduce relatively large errors
+    if use_alibi_slopes:
+        rtol = 1e0
+
     numpy_allclose(out_ref, output, rtol=rtol, atol=atol)
 
 
@@ -168,8 +191,9 @@ def test_flash_decoding_attention(
 @pytest.mark.parametrize("NUM_ATTN_HEADS", [16])
 @pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+@pytest.mark.parametrize("use_alibi_slopes", [True, False])
 def test_vllm_flash_decoding_attention(
-    BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype
+    BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype, use_alibi_slopes
 ):
     torch.manual_seed(123)
     torch.cuda.empty_cache()
@@ -199,6 +223,18 @@
     v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
     torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device)
 
+    if use_alibi_slopes:
+        alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device)
+        alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device)
+        torch_padding_mask = torch_padding_mask + alibi_mask
+
+        if len(torch_padding_mask.size()) == 4:
+            torch_padding_mask = torch_padding_mask[:, :, -1:, :]
+        else:
+            torch_padding_mask = torch_padding_mask[:, -1:, :]
+    else:
+        alibi_slopes = None
+
     if dtype == torch.float16:
         rtol = 1e-3
         atol = 1e-3
@@ -236,8 +272,6 @@
         HEAD_SIZE,
     )
 
-    alibi_slopes = None
-
     vllm_ops.paged_attention_v1(
         output,
         q.squeeze(2),
@@ -253,6 +287,11 @@
         "auto",
         kv_scale,
     )
+
+    # The alibi may introduce relatively large errors
+    if use_alibi_slopes:
+        rtol = 1e0
+
     numpy_allclose(out_ref, output, rtol=rtol, atol=atol)
 
 
@@ -277,5 +316,5 @@
         dtype,
     ) in test_combinations:
         test_flash_decoding_attention(
-            batch_size, block_size, max_num_blocks_per_seq, head_size, num_attn_heads, kv_group_num, dtype
+            batch_size, block_size, max_num_blocks_per_seq, head_size, num_attn_heads, kv_group_num, dtype, True
         )
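For the reference computation, the tests add an ALiBi mask from generate_alibi_mask (imported from the Triton context-attention test) to the padding mask and then keep only the last query row, which is the position the decoding kernel actually computes. That helper is not shown in this diff; assuming the standard construction, it amounts to something like the sketch below, where entry [h, i, j] biases query i toward key j and future keys are masked out:

import torch

def generate_alibi_mask_sketch(slopes: torch.Tensor, num_heads: int, max_seq_len: int, device) -> torch.Tensor:
    # Hypothetical stand-in for the imported generate_alibi_mask helper:
    # an additive [num_heads, max_seq_len, max_seq_len] bias of slope * (j - i)
    # for keys j <= i, with -inf above the diagonal (future positions).
    pos = torch.arange(max_seq_len, device=device)
    rel = (pos[None, :] - pos[:, None]).float()  # j - i
    bias = slopes.to(device).view(num_heads, 1, 1) * rel.unsqueeze(0)
    return bias.masked_fill(rel.unsqueeze(0) > 0, float("-inf"))

Sliced to the last row (i = context_len - 1), this gives slope * (j - context_len + 1), matching the term the CUDA kernel adds.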
