@@ -14,6 +14,10 @@ static __global__ void flash_attn_vec_ext_f16(
14
14
float * __restrict__ dst,
15
15
float2 * __restrict__ dst_meta,
16
16
const float scale,
17
+ const float max_bias,
18
+ const float m0,
19
+ const float m1,
20
+ const uint32_t n_head_log2,
17
21
const int ne00,
18
22
const int ne01,
19
23
const int ne02,
@@ -49,6 +53,18 @@ static __global__ void flash_attn_vec_ext_f16(
49
53
const int stride_KV = nb11 / sizeof (half);
50
54
const int stride_KV2 = nb11 / sizeof (half2);
51
55
56
+ half slopeh = __float2half (1 .0f );
57
+
58
+ // ALiBi
59
+ if (max_bias > 0 .0f ) {
60
+ const int h = blockIdx .y ;
61
+
62
+ const float base = h < n_head_log2 ? m0 : m1;
63
+ const int exph = h < n_head_log2 ? h + 1 : 2 *(h - n_head_log2) + 1 ;
64
+
65
+ slopeh = __float2half (powf (base, exph));
66
+ }
67
+
52
68
static_assert (D % (2 *WARP_SIZE) == 0 , " D not divisible by 2*WARP_SIZE == 64." );
53
69
constexpr int nwarps = D / WARP_SIZE;
54
70
const int tid = WARP_SIZE*threadIdx .y + threadIdx .x ;
@@ -132,7 +148,7 @@ static __global__ void flash_attn_vec_ext_f16(
132
148
for (int j = 0 ; j < ncols; ++j) {
133
149
sum2[j] = warp_reduce_sum (sum2[j]);
134
150
half sum = __low2half (sum2[j]) + __high2half (sum2[j]);
135
- sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half (0 .0f );
151
+ sum += mask ? slopeh* maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half (0 .0f );
136
152
137
153
if (ncols == 1 ) {
138
154
kqmax_new = ggml_cuda_hmax (kqmax_new, sum);
@@ -244,8 +260,17 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
244
260
const dim3 blocks_num (parallel_blocks*((Q->ne [1 ] + cols_per_block - 1 ) / cols_per_block), Q->ne [2 ], Q->ne [3 ]);
245
261
const int shmem = 0 ;
246
262
247
- float scale;
248
- memcpy (&scale, KQV->op_params , sizeof (float ));
263
+ float scale = 1 .0f ;
264
+ float max_bias = 0 .0f ;
265
+
266
+ memcpy (&scale, (float *) KQV->op_params + 0 , sizeof (float ));
267
+ memcpy (&max_bias, (float *) KQV->op_params + 1 , sizeof (float ));
268
+
269
+ const uint32_t n_head = Q->ne [2 ];
270
+ const uint32_t n_head_log2 = 1u << (uint32_t ) floorf (log2f ((float ) n_head));
271
+
272
+ const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
273
+ const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
249
274
250
275
flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
251
276
<<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -254,7 +279,7 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
254
279
(const char *) V->data ,
255
280
mask ? ((const char *) mask->data ) : nullptr ,
256
281
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr , dst_tmp_meta.ptr ,
257
- scale,
282
+ scale, max_bias, m0, m1, n_head_log2,
258
283
Q->ne [0 ], Q->ne [1 ], Q->ne [2 ], Q->ne [3 ],
259
284
K->ne [0 ], K->ne [1 ], K->ne [2 ], K->ne [3 ],
260
285
mask ? mask->ne [1 ] : 0 , mask ? mask->nb [1 ] : 0 ,
0 commit comments