|
13 | 13 |
|
14 | 14 | if HAS_TRITON: |
15 | 15 | """ |
16 | | - this function is modified from |
17 | | - https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10 |
| 16 | + these functions are modified from https://github.com/ModelTC/lightllm |
18 | 17 | """ |
19 | 18 |
|
| 19 | + # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py |
20 | 20 | @triton.jit |
21 | 21 | def _context_flash_attention_kernel( |
22 | 22 | Q, |
@@ -145,20 +145,16 @@ def _context_flash_attention_kernel( |
145 | 145 | tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) |
146 | 146 | return |
147 | 147 |
|
148 | | - |
149 | | - |
150 | 148 | @torch.no_grad() |
151 | 149 | def smooth_llama_context_attn_fwd( |
152 | 150 | q, k, v, o, q_input_scale, k_input_scale, v_input_scale, pv_output_scale, b_start_loc, b_seq_len, max_input_len |
153 | 151 | ): |
154 | | - |
155 | 152 | BLOCK = 128 |
156 | 153 | # shape constraints |
157 | 154 | Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] |
158 | 155 | assert Lq == Lk, "context process only supports equal query, key, value length" |
159 | 156 | assert Lk == Lv, "context process only supports equal query, key, value length" |
160 | 157 | assert Lk in {16, 32, 64, 128} |
161 | | - BLOCK_N = 128 |
162 | 158 | sm_scale = 1.0 / math.sqrt(Lk) |
163 | 159 | batch, head = b_seq_len.shape[0], q.shape[1] |
164 | 160 | grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) |
@@ -203,6 +199,7 @@ def smooth_llama_context_attn_fwd( |
203 | 199 | ) |
204 | 200 | return |
205 | 201 |
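For context on how this wrapper is meant to be driven, the sketch below is a hypothetical usage example, not part of the diff. It assumes q, k, v, and o are int8 CUDA tensors packed as (total_tokens, num_heads, head_dim) with head_dim in {16, 32, 64, 128}, that the four scale arguments are Python floats from SmoothQuant calibration, and that b_start_loc / b_seq_len hold per-sequence start offsets and lengths; those shapes and dtypes are inferred from the visible constraints, not confirmed by the hunk.

    import torch

    # Minimal sketch under the assumptions above; smooth_llama_context_attn_fwd is
    # assumed to be importable from this module and to run on a CUDA device with Triton.
    num_heads, head_dim = 32, 128                      # head_dim must be in {16, 32, 64, 128}
    b_seq_len = torch.tensor([5, 7], dtype=torch.int32, device="cuda")   # two packed sequences
    b_start_loc = torch.tensor([0, 5], dtype=torch.int32, device="cuda") # their start offsets
    total_tokens = int(b_seq_len.sum())
    max_input_len = int(b_seq_len.max())

    # Random int8 activations standing in for SmoothQuant-quantized Q/K/V
    q = torch.randint(-128, 128, (total_tokens, num_heads, head_dim), dtype=torch.int8, device="cuda")
    k = torch.randint_like(q, -128, 128)
    v = torch.randint_like(q, -128, 128)
    o = torch.empty_like(q)                            # output buffer; int8 layout assumed

    smooth_llama_context_attn_fwd(
        q, k, v, o,
        0.02, 0.02, 0.02, 0.05,                        # q/k/v input scales and pv output scale (assumed values)
        b_start_loc, b_seq_len, max_input_len,
    )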
|
| 202 | + # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py |
206 | 203 | @triton.jit |
207 | 204 | def _token_attn_1_kernel( |
208 | 205 | Q, |
@@ -264,6 +261,7 @@ def _token_attn_1_kernel( |
264 | 261 | tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) |
265 | 262 | return |
266 | 263 |
|
| 264 | + # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py |
267 | 265 | @triton.jit |
268 | 266 | def _token_attn_1_alibi_kernel( |
269 | 267 | Q, |
@@ -413,6 +411,7 @@ def token_attn_fwd_1( |
413 | 411 | ) |
414 | 412 | return |
415 | 413 |
|
| 414 | + # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py |
416 | 415 | @triton.jit |
417 | 416 | def _token_attn_softmax_fwd( |
418 | 417 | softmax_logics, |
@@ -479,6 +478,7 @@ def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, |
479 | 478 | ) |
480 | 479 | return |
481 | 480 |
|
| 481 | + # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py |
482 | 482 | @triton.jit |
483 | 483 | def _token_attn_2_kernel( |
484 | 484 | Prob, |
|