@@ -14,6 +14,10 @@ static __global__ void flash_attn_vec_ext_f16(
1414 float * __restrict__ dst,
1515 float2 * __restrict__ dst_meta,
1616 const float scale,
17+ const float max_bias,
18+ const float m0,
19+ const float m1,
20+ const uint32_t n_head_log2,
1721 const int ne00,
1822 const int ne01,
1923 const int ne02,
@@ -49,6 +53,18 @@ static __global__ void flash_attn_vec_ext_f16(
4953 const int stride_KV = nb11 / sizeof (half);
5054 const int stride_KV2 = nb11 / sizeof (half2);
5155
56+ half slopeh = __float2half (1 .0f );
57+
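58+ // ALiBi: when max_bias > 0, compute this head's slope as base^exph; the slope later scales the mask value added to KQ (it stays 1.0 otherwise).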
59+ if (max_bias > 0 .0f ) {
60+ const int h = blockIdx .y ;
61+
62+ const float base = h < n_head_log2 ? m0 : m1;
63+ const int exph = h < n_head_log2 ? h + 1 : 2 *(h - n_head_log2) + 1 ;
64+
65+ slopeh = __float2half (powf (base, exph));
66+ }
67+
5268 static_assert (D % (2 *WARP_SIZE) == 0 , " D not divisible by 2*WARP_SIZE == 64." );
5369 constexpr int nwarps = D / WARP_SIZE;
5470 const int tid = WARP_SIZE*threadIdx .y + threadIdx .x ;
@@ -132,7 +148,7 @@ static __global__ void flash_attn_vec_ext_f16(
132148 for (int j = 0 ; j < ncols; ++j) {
133149 sum2[j] = warp_reduce_sum (sum2[j]);
134150 half sum = __low2half (sum2[j]) + __high2half (sum2[j]);
135- sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half (0 .0f );
151+ sum += mask ? slopeh* maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half (0 .0f );
136152
137153 if (ncols == 1 ) {
138154 kqmax_new = ggml_cuda_hmax (kqmax_new, sum);
@@ -244,8 +260,17 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
244260 const dim3 blocks_num (parallel_blocks*((Q->ne [1 ] + cols_per_block - 1 ) / cols_per_block), Q->ne [2 ], Q->ne [3 ]);
245261 const int shmem = 0 ;
246262
247- float scale;
248- memcpy (&scale, KQV->op_params , sizeof (float ));
263+ float scale = 1 .0f ;
264+ float max_bias = 0 .0f ;
265+
266+ memcpy (&scale, (float *) KQV->op_params + 0 , sizeof (float ));
267+ memcpy (&max_bias, (float *) KQV->op_params + 1 , sizeof (float ));
268+
269+ const uint32_t n_head = Q->ne [2 ];
270+ const uint32_t n_head_log2 = 1u << (uint32_t ) floorf (log2f ((float ) n_head));
271+
272+ const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
273+ const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
249274
250275 flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
251276 <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -254,7 +279,7 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
254279 (const char *) V->data ,
255280 mask ? ((const char *) mask->data ) : nullptr ,
256281 parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr , dst_tmp_meta.ptr ,
257- scale,
282+ scale, max_bias, m0, m1, n_head_log2,
258283 Q->ne [0 ], Q->ne [1 ], Q->ne [2 ], Q->ne [3 ],
259284 K->ne [0 ], K->ne [1 ], K->ne [2 ], K->ne [3 ],
260285 mask ? mask->ne [1 ] : 0 , mask ? mask->nb [1 ] : 0 ,
0 commit comments