@@ -42,7 +42,7 @@ def _fwd_context_paged_attention_kernel(
     sm_scale,
     KV_GROUPS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
-    BLOCK_DMODEL: tl.constexpr,
+    HEAD_DIM: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
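Note that HEAD_DIM (like the other upper-cased arguments) is a `tl.constexpr`, so the head dimension is fixed at compile time and the compiler can specialize block shapes on it. Below is a minimal, self-contained sketch, not code from this PR, of how such a constexpr parameter is supplied at the launch site (the real launch passes `HEAD_DIM=Lk`, as the last hunk shows):

```python
# Hypothetical toy kernel, for illustration only: HEAD_DIM is a compile-time
# constant, so tl.arange can use it as a (power-of-two) bound.
import torch
import triton
import triton.language as tl


@triton.jit
def _scale_rows_kernel(x_ptr, out_ptr, HEAD_DIM: tl.constexpr):
    row = tl.program_id(0)
    offsets = tl.arange(0, HEAD_DIM)  # HEAD_DIM must be a power of two here
    x = tl.load(x_ptr + row * HEAD_DIM + offsets)
    tl.store(out_ptr + row * HEAD_DIM + offsets, x * 2.0)


x = torch.randn(4, 128, device="cuda")
out = torch.empty_like(x)
# The head dimension is passed as a keyword constexpr, mirroring HEAD_DIM=Lk below.
_scale_rows_kernel[(x.shape[0],)](x, out, HEAD_DIM=x.shape[1])
assert torch.allclose(out, x * 2.0)
```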
@@ -66,38 +66,38 @@ def _fwd_context_paged_attention_kernel(
     for i in range(0, cur_seq_idx):
         prev_seq_len_sum += tl.load(context_lengths + i)
 
-    q_offset = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh
-    kv_offset = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh
+    offset_q = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh
+    offset_kv = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh
     Q_block_ptr = tl.make_block_ptr(
-        base=Q + q_offset,
-        shape=(cur_seq_len, BLOCK_DMODEL),
+        base=Q + offset_q,
+        shape=(cur_seq_len, HEAD_DIM),
         strides=(stride_qt, stride_qd),
         offsets=(block_start_m * BLOCK_M, 0),
-        block_shape=(BLOCK_M, BLOCK_DMODEL),
+        block_shape=(BLOCK_M, HEAD_DIM),
         order=(1, 0),
     )
     K_block_ptr = tl.make_block_ptr(
-        base=K + kv_offset,
-        shape=(BLOCK_DMODEL, cur_seq_len),
+        base=K + offset_kv,
+        shape=(HEAD_DIM, cur_seq_len),
         strides=(stride_kd, stride_kt),
         offsets=(0, 0),
-        block_shape=(BLOCK_DMODEL, BLOCK_N),
+        block_shape=(HEAD_DIM, BLOCK_N),
         order=(0, 1),
     )
     V_block_ptr = tl.make_block_ptr(
-        base=V + kv_offset,
-        shape=(cur_seq_len, BLOCK_DMODEL),
+        base=V + offset_kv,
+        shape=(cur_seq_len, HEAD_DIM),
         strides=(stride_vt, stride_vd),
         offsets=(0, 0),
-        block_shape=(BLOCK_N, BLOCK_DMODEL),
+        block_shape=(BLOCK_N, HEAD_DIM),
         order=(1, 0),
     )
     O_block_ptr = tl.make_block_ptr(
-        base=O + q_offset,
-        shape=(cur_seq_len, BLOCK_DMODEL),
+        base=O + offset_q,
+        shape=(cur_seq_len, HEAD_DIM),
         strides=(stride_ot, stride_od),
         offsets=(block_start_m * BLOCK_M, 0),
-        block_shape=(BLOCK_M, BLOCK_DMODEL),
+        block_shape=(BLOCK_M, HEAD_DIM),
         order=(1, 0),
     )
 
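The four `tl.make_block_ptr` calls above build strided tile views over Q, K, V and O; K is exposed transposed as `(HEAD_DIM, cur_seq_len)` with `order=(0, 1)` so the score matmul can consume it directly. A hedged, stand-alone sketch of the same block-pointer pattern, using a hypothetical copy kernel rather than anything from this file:

```python
# Illustration only: load a (BLOCK_M, HEAD_DIM) tile of a row-major
# (seq_len, head_dim) tensor via tl.make_block_ptr and write it back out,
# mirroring the Q/O block pointers in the hunk above.
import torch
import triton
import triton.language as tl


@triton.jit
def _copy_tiles_kernel(src, dst, seq_len, stride_t, stride_d,
                       BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr):
    start_m = tl.program_id(0)
    src_ptr = tl.make_block_ptr(
        base=src,
        shape=(seq_len, HEAD_DIM),
        strides=(stride_t, stride_d),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),  # dim 1 (head_dim) is the contiguous one
    )
    dst_ptr = tl.make_block_ptr(
        base=dst,
        shape=(seq_len, HEAD_DIM),
        strides=(stride_t, stride_d),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )
    tile = tl.load(src_ptr, boundary_check=(0,))  # guard the ragged last tile
    tl.store(dst_ptr, tile, boundary_check=(0,))


x = torch.randn(100, 64, device="cuda")
y = torch.empty_like(x)
BLOCK_M = 32
grid = (triton.cdiv(x.shape[0], BLOCK_M),)
_copy_tiles_kernel[grid](x, y, x.shape[0], x.stride(0), x.stride(1),
                         BLOCK_M=BLOCK_M, HEAD_DIM=x.shape[1])
assert torch.equal(x, y)
```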
@@ -108,13 +108,13 @@ def _fwd_context_paged_attention_kernel(
     # as we have BLOCK_M the same size as the block size.
     cur_block_table_idx = block_start_m
     cur_block_id = tl.load(block_table_ptr + cur_block_table_idx * stride_btb)
-    kvcache_offset = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
+    offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
 
     offsets_m = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
     offsets_n = tl.arange(0, BLOCK_N)
     m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
-    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
 
     if block_start_m * BLOCK_M >= cur_seq_len:
         return
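`m_i`, `l_i` and `acc` initialized above are the usual flash-attention style online-softmax state: the running row maximum, the running normalizer and the un-normalized output accumulator. The update loop itself falls outside this hunk, so the following plain PyTorch sketch is an assumption about its structure, shown only to clarify what the three buffers do:

```python
# Reference-style illustration (not Triton, not this kernel): scores arrive
# one BLOCK_N-sized block at a time, and the old state is rescaled by
# exp(m_old - m_new) so the final result equals an exact softmax.
import torch


def online_softmax_attention(q, k, v, block_n=16):
    m_i = torch.full((q.shape[0],), float("-inf"))   # running max per query row
    l_i = torch.zeros(q.shape[0])                    # running softmax denominator
    acc = torch.zeros(q.shape[0], v.shape[1])        # un-normalized output
    for start in range(0, k.shape[0], block_n):
        s = q @ k[start:start + block_n].T           # scores for this key block
        m_new = torch.maximum(m_i, s.max(dim=1).values)
        alpha = torch.exp(m_i - m_new)               # rescale the old state
        p = torch.exp(s - m_new[:, None])
        l_i = l_i * alpha + p.sum(dim=1)
        acc = acc * alpha[:, None] + p @ v[start:start + block_n]
        m_i = m_new
    return acc / l_i[:, None]


q, k, v = torch.randn(8, 64), torch.randn(32, 64), torch.randn(32, 64)
reference = torch.softmax(q @ k.T, dim=-1) @ v
assert torch.allclose(online_softmax_attention(q, k, v), reference, atol=1e-4)
```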
@@ -152,43 +152,41 @@ def _fwd_context_paged_attention_kernel(
 
     if cur_head_idx % KV_GROUPS == 0:
         # Copy k to corresponding cache block
-        kd_offsets = tl.arange(0, BLOCK_DMODEL)
-        kt_offsets = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-        k_offsets = K + kv_offset + kd_offsets[:, None] * stride_kd + kt_offsets[None, :] * stride_kt
-        k = tl.load(k_offsets, mask=kt_offsets[None, :] < cur_seq_len, other=0.0)
-        kcached_offsets = tl.arange(0, BLOCK_DMODEL)
-        kcachebs_offsets = tl.arange(0, BLOCK_SIZE)
-        kcache_offsets = (
+        offsets_dmodel = tl.arange(0, HEAD_DIM)
+        offsets_kt = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+        offsets_k = K + offset_kv + offsets_dmodel[:, None] * stride_kd + offsets_kt[None, :] * stride_kt
+        k = tl.load(offsets_k, mask=offsets_kt[None, :] < cur_seq_len, other=0.0)
+        offsets_kcachebs = tl.arange(0, BLOCK_SIZE)
+        offsets_kcache = (
             KCache
-            + kvcache_offset
-            + kcached_offsets[:, None] * stride_cached
-            + kcachebs_offsets[None, :] * stride_cachebs
+            + offset_kvcache
+            + offsets_dmodel[:, None] * stride_cached
+            + offsets_kcachebs[None, :] * stride_cachebs
         )
-        tl.store(kcache_offsets, k, mask=kcachebs_offsets[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
+        tl.store(offsets_kcache, k, mask=offsets_kcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
         # Copy v to corresponding cache block
-        vd_offsets = kd_offsets
-        vt_offsets = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N)
-        v_offsets = V + kv_offset + vt_offsets[:, None] * stride_vt + vd_offsets[None, :] * stride_vd
-        v = tl.load(v_offsets, mask=vt_offsets[:, None] < cur_seq_len, other=0.0)
-        vcached_offsets = kcached_offsets
-        vcachebs_offsets = kcachebs_offsets
-        vcache_offsets = (
+        offsets_vd = offsets_dmodel
+        offsets_vt = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N)
+        offsets_v = V + offset_kv + offsets_vt[:, None] * stride_vt + offsets_vd[None, :] * stride_vd
+        v = tl.load(offsets_v, mask=offsets_vt[:, None] < cur_seq_len, other=0.0)
+        offsets_vcachebs = offsets_kcachebs  # same block size range, just to notify here
+        offsets_vcache = (
             VCache
-            + kvcache_offset
-            + vcachebs_offsets[:, None] * stride_cachebs
-            + vcached_offsets[None, :] * stride_cached
+            + offset_kvcache
+            + offsets_vcachebs[:, None] * stride_cachebs
+            + offsets_dmodel[None, :] * stride_cached
         )
-        tl.store(vcache_offsets, v, mask=vcachebs_offsets[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
+        tl.store(offsets_vcache, v, mask=offsets_vcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
 
     return
 
 
 def context_attention_unpadded(
-    q: torch.Tensor,  # [num_tokens, num_heads, head_size]
-    k: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
-    v: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
-    k_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_size, block_size]
-    v_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_size, block_size]
+    q: torch.Tensor,  # [num_tokens, num_heads, head_dim]
+    k: torch.Tensor,  # [num_tokens, num_kv_heads, head_dim]
+    v: torch.Tensor,  # [num_tokens, num_kv_heads, head_dim]
+    k_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_dim, block_size]
+    v_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_dim, block_size]
     context_lengths: torch.Tensor,  # [num_seqs]
     block_tables: torch.Tensor,  # [num_seqs, max_blocks_per_sequence],
     block_size: int,
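In the cache-copy branch above, `offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh` addresses the physical block chosen by the block table inside the `[num_blocks, num_kv_heads, head_dim, block_size]` caches described by the signature comments. A rough PyTorch equivalent of that scatter, with made-up shapes, purely as an illustration:

```python
# Hypothetical sizes; the copy below mirrors what the kernel's cache-copy
# branch does for each BLOCK_SIZE-sized chunk of the current sequence.
import torch

num_blocks, num_kv_heads, head_dim, block_size = 8, 2, 64, 16
k_cache = torch.zeros(num_blocks, num_kv_heads, head_dim, block_size)
v_cache = torch.zeros(num_blocks, num_kv_heads, head_dim, block_size)

seq_len = 40
k = torch.randn(seq_len, num_kv_heads, head_dim)  # [num_tokens, num_kv_heads, head_dim]
v = torch.randn(seq_len, num_kv_heads, head_dim)
block_table = torch.tensor([3, 5, 1])              # logical block -> physical block id

for logical_block, block_id in enumerate(block_table.tolist()):
    start = logical_block * block_size
    n = min(block_size, seq_len - start)           # the last block may be partial
    if n <= 0:
        break
    # k/v chunk is [n, num_kv_heads, head_dim]; the cache stores it as
    # [num_kv_heads, head_dim, n], i.e. tokens laid out along the last axis.
    k_cache[block_id, :, :, :n] = k[start:start + n].permute(1, 2, 0)
    v_cache[block_id, :, :, :n] = v[start:start + n].permute(1, 2, 0)
```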
@@ -254,7 +252,7 @@ def context_attention_unpadded(
         sm_scale,
         num_kv_group,
         block_size,
-        BLOCK_DMODEL=Lk,
+        HEAD_DIM=Lk,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
     )