
Commit 925a57d

[kernel] Add flash decoding triton kernel for blocked kv cache (#5249)
* add flash decoding unpad triton kernel
* rename flash decoding kernel
* add kernel testing (draft)
* revise pytest
* support kv group (GQA)
* (trivial) fix api and pytest
* (trivial) func renaming
* (trivial) func/file renaming
* refactor pytest for attention
* (trivial) format and consistent vars of context/decode attn
* (trivial) remove test redundancy
1 parent 6b90fa5 commit 925a57d

File tree

6 files changed, +577 -154 lines changed


colossalai/kernel/triton/__init__.py (+2)

@@ -9,13 +9,15 @@
 # There may exist import error even if we have triton installed.
 if HAS_TRITON:
     from .context_attn_unpad import context_attention_unpadded
+    from .flash_decoding import flash_decoding_fwd
     from .fused_layernorm import layer_norm
     from .gptq_triton import gptq_fused_linear_triton
     from .no_pad_rotary_embedding import rotary_embedding
     from .softmax import softmax

     __all__ = [
         "context_attention_unpadded",
+        "flash_decoding_fwd",
         "softmax",
         "layer_norm",
         "gptq_fused_linear_triton",

colossalai/kernel/triton/context_attn_unpad.py (+43 -45)
@@ -42,7 +42,7 @@ def _fwd_context_paged_attention_kernel(
     sm_scale,
     KV_GROUPS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
-    BLOCK_DMODEL: tl.constexpr,
+    HEAD_DIM: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
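KV_GROUPS here is the grouped-query-attention (GQA) factor mentioned in the commit message: several query heads share one KV head, and only one head per group writes the KV cache (see the cur_head_idx % KV_GROUPS == 0 guard further down). A minimal host-side sketch of that mapping, assuming KV_GROUPS = num_heads // num_kv_heads:

# Illustrative GQA head grouping (plain Python, not part of the kernel).
num_heads, num_kv_heads = 32, 8
KV_GROUPS = num_heads // num_kv_heads            # 4 query heads share one KV head
for head_idx in range(num_heads):
    kv_head_idx = head_idx // KV_GROUPS          # KV head this query head attends with
    writes_kv_cache = head_idx % KV_GROUPS == 0  # one writer per group, as in the kernel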
@@ -66,38 +66,38 @@ def _fwd_context_paged_attention_kernel(
     for i in range(0, cur_seq_idx):
         prev_seq_len_sum += tl.load(context_lengths + i)

-    q_offset = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh
-    kv_offset = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh
+    offset_q = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh
+    offset_kv = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh
     Q_block_ptr = tl.make_block_ptr(
-        base=Q + q_offset,
-        shape=(cur_seq_len, BLOCK_DMODEL),
+        base=Q + offset_q,
+        shape=(cur_seq_len, HEAD_DIM),
         strides=(stride_qt, stride_qd),
         offsets=(block_start_m * BLOCK_M, 0),
-        block_shape=(BLOCK_M, BLOCK_DMODEL),
+        block_shape=(BLOCK_M, HEAD_DIM),
         order=(1, 0),
     )
     K_block_ptr = tl.make_block_ptr(
-        base=K + kv_offset,
-        shape=(BLOCK_DMODEL, cur_seq_len),
+        base=K + offset_kv,
+        shape=(HEAD_DIM, cur_seq_len),
         strides=(stride_kd, stride_kt),
         offsets=(0, 0),
-        block_shape=(BLOCK_DMODEL, BLOCK_N),
+        block_shape=(HEAD_DIM, BLOCK_N),
         order=(0, 1),
     )
     V_block_ptr = tl.make_block_ptr(
-        base=V + kv_offset,
-        shape=(cur_seq_len, BLOCK_DMODEL),
+        base=V + offset_kv,
+        shape=(cur_seq_len, HEAD_DIM),
         strides=(stride_vt, stride_vd),
         offsets=(0, 0),
-        block_shape=(BLOCK_N, BLOCK_DMODEL),
+        block_shape=(BLOCK_N, HEAD_DIM),
         order=(1, 0),
     )
     O_block_ptr = tl.make_block_ptr(
-        base=O + q_offset,
-        shape=(cur_seq_len, BLOCK_DMODEL),
+        base=O + offset_q,
+        shape=(cur_seq_len, HEAD_DIM),
         strides=(stride_ot, stride_od),
         offsets=(block_start_m * BLOCK_M, 0),
-        block_shape=(BLOCK_M, BLOCK_DMODEL),
+        block_shape=(BLOCK_M, HEAD_DIM),
         order=(1, 0),
     )
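All four block pointers above now carry HEAD_DIM as their model dimension. For context, block pointers built with tl.make_block_ptr are normally stepped along the sequence axis inside the attention loop with tl.advance; a minimal sketch of that pattern, matching the K/V layouts above (the loop itself lies outside this hunk and is an assumption, not a quote of the file):

# Sketch: walk K (HEAD_DIM, seq) and V (seq, HEAD_DIM) in BLOCK_N-sized steps.
for _ in range(0, tl.cdiv(cur_seq_len, BLOCK_N)):
    k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero")
    v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero")
    # ... attention scores and online-softmax update on (q, k, v) ...
    K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
    V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))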

@@ -108,13 +108,13 @@ def _fwd_context_paged_attention_kernel(
     # as we have BLOCK_M the same size as the block size.
     cur_block_table_idx = block_start_m
     cur_block_id = tl.load(block_table_ptr + cur_block_table_idx * stride_btb)
-    kvcache_offset = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
+    offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh

     offsets_m = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
     offsets_n = tl.arange(0, BLOCK_N)
     m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
-    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)

     if block_start_m * BLOCK_M >= cur_seq_len:
         return
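m_i, l_i, and acc above are the running maximum, softmax denominator, and un-normalized output of the online-softmax recurrence used by flash attention. A minimal NumPy sketch of that recurrence for a single query row (names follow the kernel; everything else is illustrative):

import numpy as np

def online_softmax_attention(score_blocks, value_blocks, head_dim):
    # score_blocks: iterable of (BLOCK_N,) score tiles; value_blocks: matching (BLOCK_N, head_dim) V tiles.
    m_i, l_i = float("-inf"), 0.0
    acc = np.zeros(head_dim)
    for s_blk, v_blk in zip(score_blocks, value_blocks):
        m_new = max(m_i, float(s_blk.max()))
        p = np.exp(s_blk - m_new)        # probabilities rescaled by the new running max
        alpha = np.exp(m_i - m_new)      # rescale factor for everything accumulated so far
        l_i = alpha * l_i + p.sum()
        acc = alpha * acc + p @ v_blk
        m_i = m_new
    return acc / l_i                     # normalize once at the end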
@@ -152,43 +152,41 @@ def _fwd_context_paged_attention_kernel(

     if cur_head_idx % KV_GROUPS == 0:
         # Copy k to corresponding cache block
-        kd_offsets = tl.arange(0, BLOCK_DMODEL)
-        kt_offsets = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-        k_offsets = K + kv_offset + kd_offsets[:, None] * stride_kd + kt_offsets[None, :] * stride_kt
-        k = tl.load(k_offsets, mask=kt_offsets[None, :] < cur_seq_len, other=0.0)
-        kcached_offsets = tl.arange(0, BLOCK_DMODEL)
-        kcachebs_offsets = tl.arange(0, BLOCK_SIZE)
-        kcache_offsets = (
+        offsets_dmodel = tl.arange(0, HEAD_DIM)
+        offsets_kt = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+        offsets_k = K + offset_kv + offsets_dmodel[:, None] * stride_kd + offsets_kt[None, :] * stride_kt
+        k = tl.load(offsets_k, mask=offsets_kt[None, :] < cur_seq_len, other=0.0)
+        offsets_kcachebs = tl.arange(0, BLOCK_SIZE)
+        offsets_kcache = (
             KCache
-            + kvcache_offset
-            + kcached_offsets[:, None] * stride_cached
-            + kcachebs_offsets[None, :] * stride_cachebs
+            + offset_kvcache
+            + offsets_dmodel[:, None] * stride_cached
+            + offsets_kcachebs[None, :] * stride_cachebs
         )
-        tl.store(kcache_offsets, k, mask=kcachebs_offsets[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
+        tl.store(offsets_kcache, k, mask=offsets_kcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
         # Copy v to corresponding cache block
-        vd_offsets = kd_offsets
-        vt_offsets = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N)
-        v_offsets = V + kv_offset + vt_offsets[:, None] * stride_vt + vd_offsets[None, :] * stride_vd
-        v = tl.load(v_offsets, mask=vt_offsets[:, None] < cur_seq_len, other=0.0)
-        vcached_offsets = kcached_offsets
-        vcachebs_offsets = kcachebs_offsets
-        vcache_offsets = (
+        offsets_vd = offsets_dmodel
+        offsets_vt = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N)
+        offsets_v = V + offset_kv + offsets_vt[:, None] * stride_vt + offsets_vd[None, :] * stride_vd
+        v = tl.load(offsets_v, mask=offsets_vt[:, None] < cur_seq_len, other=0.0)
+        offsets_vcachebs = offsets_kcachebs  # same block size range, just to notify here
+        offsets_vcache = (
             VCache
-            + kvcache_offset
-            + vcachebs_offsets[:, None] * stride_cachebs
-            + vcached_offsets[None, :] * stride_cached
+            + offset_kvcache
+            + offsets_vcachebs[:, None] * stride_cachebs
+            + offsets_dmodel[None, :] * stride_cached
         )
-        tl.store(vcache_offsets, v, mask=vcachebs_offsets[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
+        tl.store(offsets_vcache, v, mask=offsets_vcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)

     return


 def context_attention_unpadded(
-    q: torch.Tensor,  # [num_tokens, num_heads, head_size]
-    k: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
-    v: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
-    k_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_size, block_size]
-    v_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_size, block_size]
+    q: torch.Tensor,  # [num_tokens, num_heads, head_dim]
+    k: torch.Tensor,  # [num_tokens, num_kv_heads, head_dim]
+    v: torch.Tensor,  # [num_tokens, num_kv_heads, head_dim]
+    k_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_dim, block_size]
+    v_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_dim, block_size]
     context_lengths: torch.Tensor,  # [num_seqs]
     block_tables: torch.Tensor,  # [num_seqs, max_blocks_per_sequence],
     block_size: int,
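The shape comments in the new signature spell out the blocked (paged) KV-cache layout this kernel fills: [num_blocks, num_kv_heads, head_dim, block_size], addressed through block_tables of shape [num_seqs, max_blocks_per_sequence]. A minimal PyTorch sketch of how one token's K vector lands in that layout (the helper is hypothetical, not part of this file):

import torch

def write_k_to_blocked_cache(k_cache, k_token, block_tables, seq_idx, token_pos, block_size):
    # k_cache: [num_blocks, num_kv_heads, head_dim, block_size]; k_token: [num_kv_heads, head_dim]
    block_id = block_tables[seq_idx, token_pos // block_size]  # logical position -> physical block
    slot = token_pos % block_size                              # offset within that block
    k_cache[block_id, :, :, slot] = k_token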
@@ -254,7 +252,7 @@ def context_attention_unpadded(
         sm_scale,
         num_kv_group,
         block_size,
-        BLOCK_DMODEL=Lk,
+        HEAD_DIM=Lk,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
     )
