
Commit 183dfd1

enable alibi in pagedattention
Parent: 97c6134

File tree: 6 files changed (+72 / -20 lines)


colossalai/inference/modeling/models/nopadding_baichuan.py

Lines changed: 30 additions & 15 deletions
```diff
@@ -253,6 +253,21 @@ def forward(
             inference_ops.decode_kv_cache_memcpy(
                 key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
             )
+            inference_ops.flash_decoding_attention(
+                output_tensor,
+                query_states,
+                k_cache,
+                v_cache,
+                sequence_lengths,
+                block_tables,
+                block_size,
+                kv_seq_len,
+                fd_inter_tensor.mid_output,
+                fd_inter_tensor.mid_output_lse,
+                self.alibi_slopes,
+                sm_scale,
+            )
+            attn_output = output_tensor
         else:
             if not is_verifier and not self.use_alibi_attn:
                 decoding_fused_rotary_embedding(
@@ -276,21 +291,21 @@ def forward(
                 value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
             )
 
-        attn_output = flash_decoding_attention(
-            q=query_states,
-            k_cache=k_cache,
-            v_cache=v_cache,
-            kv_seq_len=sequence_lengths,
-            block_tables=block_tables,
-            block_size=block_size,
-            max_seq_len_in_batch=kv_seq_len,
-            output=output_tensor,
-            mid_output=fd_inter_tensor.mid_output,
-            mid_output_lse=fd_inter_tensor.mid_output_lse,
-            alibi_slopes=self.alibi_slopes,
-            sm_scale=sm_scale,
-            q_len=q_len,
-        )
+            attn_output = flash_decoding_attention(
+                q=query_states,
+                k_cache=k_cache,
+                v_cache=v_cache,
+                kv_seq_len=sequence_lengths,
+                block_tables=block_tables,
+                block_size=block_size,
+                max_seq_len_in_batch=kv_seq_len,
+                output=output_tensor,
+                mid_output=fd_inter_tensor.mid_output,
+                mid_output_lse=fd_inter_tensor.mid_output_lse,
+                alibi_slopes=self.alibi_slopes,
+                sm_scale=sm_scale,
+                q_len=q_len,
+            )
 
         attn_output = attn_output.view(-1, self.hidden_size)
         attn_output = torch.mm(attn_output, self.o_proj_weight)
```
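
The Baichuan decode path above now hands `self.alibi_slopes` directly to the CUDA paged-attention kernel instead of always falling back to the Triton kernel. For context, ALiBi slopes are fixed per-head constants drawn from a geometric sequence; below is a minimal sketch of that construction following the standard ALiBi recipe (the repo's own `get_alibi_slopes` may differ in details such as dtype or device handling):

```python
import math

import torch


def alibi_slopes_sketch(num_heads: int, device: torch.device) -> torch.Tensor:
    """Standard ALiBi slope construction: a geometric sequence per head,
    with interleaved extra slopes when num_heads is not a power of two."""
    closest_pow2 = 2 ** math.floor(math.log2(num_heads))
    base = 2 ** (-(2 ** -(math.log2(closest_pow2) - 3)))
    powers = torch.arange(1, closest_pow2 + 1, dtype=torch.float32, device=device)
    slopes = torch.pow(base, powers)
    if closest_pow2 != num_heads:
        # Fill the remaining heads with slopes interleaved from the next power of two.
        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_pow2) - 3)))
        num_extra = num_heads - closest_pow2
        extra_powers = torch.arange(1, 2 * num_extra + 1, 2, dtype=torch.float32, device=device)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes  # shape [num_heads], float32, all values in (0, 1)
```

A tensor shaped `[num_heads]` like this is what the new positional argument right before `sm_scale` expects.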

colossalai/inference/modeling/models/nopadding_llama.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -597,6 +597,7 @@ def forward(
                 kv_seq_len,
                 fd_inter_tensor.mid_output,
                 fd_inter_tensor.mid_output_lse,
+                None,
                 sm_scale,
             )
             attn_output = output_tensor
```

examples/inference/benchmark_ops/benchmark_flash_decoding_attention.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -113,6 +113,7 @@ def benchmark_flash_decoding_attention(
     kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE
     output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
     sm_scale = 1.0 / (HEAD_SIZE**0.5)
+    alibi_slopes = None
     kv_scale = 1.0
 
     mid_output = torch.empty(
@@ -166,6 +167,7 @@ def benchmark_flash_decoding_attention(
             max_seq_len_across_batch,
             mid_output,
             mid_output_lse,
+            alibi_slopes,
             sm_scale,
         )
     else:
```
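
The benchmark keeps `alibi_slopes = None`, so it still measures the bias-free path; passing a real `[num_heads]` float tensor instead would exercise the new ALiBi branch. A generic CUDA-event timing harness is enough to compare the two; this is a sketch, not the repo's benchmark utility:

```python
import torch


def time_cuda_ms(fn, iters: int = 100, warmup: int = 10) -> float:
    """Average milliseconds per call of `fn`, measured with CUDA events."""
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


if __name__ == "__main__" and torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda")
    # Placeholder workload; in practice wrap the flash_decoding_attention call
    # once with alibi_slopes=None and once with a real slope tensor.
    print(f"{time_cuda_ms(lambda: x @ x):.3f} ms/iter")
```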

extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu

Lines changed: 13 additions & 2 deletions
```diff
@@ -54,6 +54,7 @@ __global__ void flash_decoding_attention_kernel(
   const cache_t* __restrict__ v_cache,     // [num_blocks, num_kv_heads, block_size, head_size]
   const int* __restrict__ context_lens,    // [num_tokens]
   const int* __restrict__ block_tables,    // [num_tokens, max_num_blocks_per_seq]
+  const float* __restrict__ alibi_slopes,  // [num_heads]
   const int max_seq_len,
   const int num_kv_heads,
   const float scale,
@@ -90,6 +91,7 @@ __global__ void flash_decoding_attention_kernel(
   using Float_vec = typename FloatVecTypeTrait<L_vec>::Type;
 
   const int context_len = context_lens[seq_idx];
+  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
   const int thread_group_offset = lane % NUM_THREADS_PER_X;
   const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
   const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
@@ -149,6 +151,7 @@ __global__ void flash_decoding_attention_kernel(
 
     if (thread_group_offset == 0 && lane < NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X) {
       const int token_idx = block_idx * BLOCK_SIZE + i * NUM_ROWS_PER_ROUNDS + lane / NUM_THREADS_PER_X;
+      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
       const bool mask = token_idx >= context_len;
       logits[token_idx] = mask ? 0.f : qk;
       qk_max = mask ? qk_max : fmaxf(qk_max, qk);
@@ -246,6 +249,7 @@ __global__ void flash_decoding_attention_kernel(
     reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
     context_lens.data_ptr<int>(), \
     block_tables.data_ptr<int>(), \
+    alibi_slopes_ptr, \
     max_context_len, \
     num_kv_heads, \
     scale, \
@@ -267,7 +271,8 @@ void flash_decoding_attention_v1_launcher(
   torch::Tensor& context_lens,   // [num_tokens]
   torch::Tensor& block_tables,   // [num_tokens, max_num_blocks_per_seq]
   int max_context_len,
-  float scale) {
+  float scale,
+  const c10::optional<torch::Tensor>& alibi_slopes) {
   int num_tokens = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -289,6 +294,10 @@ void flash_decoding_attention_v1_launcher(
   // Keep that in sync with the logic here!
   int shared_mem_size = std::max(logits_size, outputs_size) + DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4);
 
+  const float* alibi_slopes_ptr = alibi_slopes ?
+      reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+      : nullptr;
+
   dim3 grid(num_heads, num_tokens, 1);
   dim3 block(NUM_THREADS);
   const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
@@ -321,7 +330,8 @@ void flash_decoding_attention_v1_launcher(
     context_lens, \
     block_tables, \
     max_context_len, \
-    scale);
+    scale, \
+    alibi_slopes);
 
   // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
   // 1, 2, 4, 64, 128, 256.
@@ -352,6 +362,7 @@ void flash_decoding_attention(
   int max_context_len,
   torch::Tensor& tmp_out,      // [num_tokens, num_heads, max_num_partitions, head_size]
   torch::Tensor& tmp_out_lse,  // [num_tokens, num_heads, max_num_partitions]
+  const c10::optional<torch::Tensor>& alibi_slopes,
   float scale) {
   switch (query.scalar_type()) {
     case at::ScalarType::Float:
```
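
The kernel-side change is deliberately small: each head loads its slope (or `0.f` when the pointer is null) and adds `alibi_slope * (token_idx - context_len + 1)` to the scaled q·k score before masking and softmax, so the newest token gets zero bias and older tokens are penalized in proportion to their distance. A tiny PyTorch reference of that bias term, for intuition only:

```python
import torch


def alibi_decode_bias(slopes: torch.Tensor, context_len: int) -> torch.Tensor:
    """Bias added to the decode-step attention scores:
    bias[h, t] = slopes[h] * (t - context_len + 1), i.e. 0 for the newest
    token (t = context_len - 1) and increasingly negative further back."""
    positions = torch.arange(context_len, dtype=torch.float32, device=slopes.device)
    return slopes.float().view(-1, 1) * (positions - context_len + 1)  # [num_heads, context_len]


# With a slope of 0.5 and context_len = 4 the bias row is [-1.5, -1.0, -0.5, 0.0];
# it is added to the scaled q.k scores before the softmax, like an additive mask.
print(alibi_decode_bias(torch.tensor([0.5]), context_len=4))
```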

extensions/pybind/inference/inference.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -73,7 +73,7 @@ void flash_decoding_attention(
     torch::Tensor&
         tmp_out,  // [num_tokens, num_heads, max_num_partitions, head_size]
     torch::Tensor& tmp_out_lse,  // [num_tokens, num_heads, max_num_partitions]
-    float scale);
+    const c10::optional<torch::Tensor>& alibi_slopes, float scale);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
```

tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py

Lines changed: 25 additions & 2 deletions
```diff
@@ -4,8 +4,10 @@
 import pytest
 import torch
 
+from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.utils import get_current_device
+from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask
 
 inference_ops = InferenceOpsLoader().load()
 
@@ -60,8 +62,9 @@ def numpy_allclose(x, y, rtol, atol):
 @pytest.mark.parametrize("NUM_ATTN_HEADS", [16])
 @pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+@pytest.mark.parametrize("use_alibi_slopes", [True, False])
 def test_flash_decoding_attention(
-    BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype
+    BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype, use_alibi_slopes
 ):
     torch.manual_seed(123)
     torch.cuda.empty_cache()
@@ -73,6 +76,11 @@ def test_flash_decoding_attention(
     MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
     device = get_current_device()
 
+    if use_alibi_slopes:
+        alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device)
+    else:
+        alibi_slopes = None
+
     q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
         BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
     )
@@ -91,6 +99,15 @@ def test_flash_decoding_attention(
     v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
     torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device)
 
+    if use_alibi_slopes:
+        alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device)
+        torch_padding_mask = torch_padding_mask + alibi_mask
+
+        if len(torch_padding_mask.size()) == 4:
+            torch_padding_mask = torch_padding_mask[:, :, -1:, :]
+        else:
+            torch_padding_mask = torch_padding_mask[:, -1:, :]
+
     mid_output = torch.empty(
         size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device
     )
@@ -146,8 +163,14 @@ def test_flash_decoding_attention(
         max_seq_len_across_batch,
         mid_output,
         mid_output_lse,
+        alibi_slopes,
         sm_scale,
     )
+
+    # The alibi may introduce relatively large errors
+    if use_alibi_slopes:
+        rtol = 1e0
+
     numpy_allclose(out_ref, output, rtol=rtol, atol=atol)
 
 
@@ -277,5 +300,5 @@ def test_vllm_flash_decoding_attention(
         dtype,
     ) in test_combinations:
         test_flash_decoding_attention(
-            batch_size, block_size, max_num_blocks_per_seq, head_size, num_attn_heads, kv_group_num, dtype
+            batch_size, block_size, max_num_blocks_per_seq, head_size, num_attn_heads, kv_group_num, dtype, True
         )
```
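
The test folds the ALiBi bias into the additive attention mask, keeps only the last query row (the decode step), and relaxes `rtol`, since the extra bias amplifies floating-point differences between the kernel and the reference. Conceptually the reference is just single-query attention with an additive mask; a self-contained sketch (not the repo's own torch reference helper):

```python
import torch


def ref_decode_attention(q, k, v, bias, sm_scale):
    """Single-step decode attention: q is [batch, heads, head_size], k/v are
    [batch, heads, kv_len, head_size], bias is broadcastable to
    [batch, heads, kv_len] and already contains padding mask + ALiBi bias."""
    scores = torch.einsum("bhd,bhsd->bhs", q.float(), k.float()) * sm_scale
    scores = scores + bias.float()  # additive mask: large negative for padding, ALiBi elsewhere
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("bhs,bhsd->bhd", probs, v.float())


# Tiny smoke check: with zero bias this reduces to ordinary attention.
q = torch.randn(2, 4, 8)
k = torch.randn(2, 4, 16, 8)
v = torch.randn(2, 4, 16, 8)
out = ref_decode_attention(q, k, v, torch.zeros(2, 4, 16), sm_scale=8 ** -0.5)
print(out.shape)  # torch.Size([2, 4, 8])
```

With `bias` set to the padding mask plus the per-head ALiBi bias, this is the quantity the kernel output is compared against.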
