@@ -23,6 +23,10 @@ static __global__ void flash_attn_vec_ext_f16(
         float  * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -58,6 +62,18 @@ static __global__ void flash_attn_vec_ext_f16(
     const int stride_KV  = nb11 / sizeof(half);
     const int stride_KV2 = nb11 / sizeof(half2);
 
+    half slopeh = __float2half(1.0f);
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const int h = blockIdx.y;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slopeh = __float2half(powf(base, exph));
+    }
+
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
     const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
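Note: the block above derives the per-head ALiBi slope entirely from max_bias, m0, m1, and n_head_log2. Below is a standalone host-side sketch of the same scheme for reference; the helper name alibi_slope and the demo loop are illustrative, not part of this patch.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Same slope scheme as the kernel: heads below the largest power of two not
// exceeding n_head use base m0 with exponents 1, 2, 3, ...; the remaining
// heads use base m1 with odd exponents 1, 3, 5, ...
static float alibi_slope(float max_bias, uint32_t h, uint32_t n_head) {
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

    return powf(base, (float) exph);
}

int main() {
    for (uint32_t h = 0; h < 12; ++h) {
        printf("head %2u: slope %.6f\n", h, alibi_slope(8.0f, h, 12));
    }
    return 0;
}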
@@ -141,7 +157,7 @@ static __global__ void flash_attn_vec_ext_f16(
         for (int j = 0; j < ncols; ++j) {
             sum2[j] = warp_reduce_sum(sum2[j]);
             half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
-            sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+            sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
 
             if (ncols == 1) {
                 kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
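Note: the only functional change in this hunk is the slopeh* factor. The mask buffer is expected to carry the position-dependent bias values (and -INFINITY for masked-out positions, which survives the multiply), so scaling it by the head's slope yields the ALiBi term without a separate bias tensor. A scalar reference, with illustrative stand-in names:

// kq, maskval, and slope stand in for one KQ dot product, one mask entry,
// and the head's ALiBi slope, respectively.
static float masked_logit(float kq, float maskval, float slope, bool has_mask) {
    return kq + (has_mask ? slope*maskval : 0.0f);  // -INF stays -INF
}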
@@ -249,6 +265,10 @@ static __global__ void flash_attn_ext_f16(
         float  * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -305,6 +325,20 @@ static __global__ void flash_attn_ext_f16(
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
 
+    half  slopeh = __float2half(1.0f);
+    half2 slope2 = make_half2(1.0f, 1.0f);
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const int h = blockIdx.y;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slopeh = __float2half(powf(base, exph));
+        slope2 = make_half2(slopeh, slopeh);
+    }
+
     frag_b Q_b[D/16][ncols/frag_n];
 
     // A single buffer for temporarily holding tiles of KQ and VKQ parts:
@@ -421,7 +455,7 @@ static __global__ void flash_attn_ext_f16(
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
                 const int k = k0 + threadIdx.x;
 
-                KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+                KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
                 KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
             }
             KQ_max_new = warp_reduce_max(KQ_max_new);
@@ -464,7 +498,7 @@ static __global__ void flash_attn_ext_f16(
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
                 const int k = k0 + threadIdx.x;
 
-                KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+                KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
                 KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
             }
             KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
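Note: in this half2 path, slope2 packs the same slope into both 16-bit lanes, so a single multiply scales two adjacent mask values at once. A minimal device-side sketch of the trick; the kernel and buffer names are illustrative, and __hmul2 is the intrinsic behind half2's operator*:

#include <cuda_fp16.h>

// Scales pairs of mask values by a per-head slope, two elements per thread.
__global__ void scale_mask_half2(const half2 * mask2, half2 * out, const half slope, const int n2) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= n2) {
        return;
    }
    const half2 slope2 = make_half2(slope, slope);  // same slope in both lanes
    out[i] = __hmul2(slope2, mask2[i]);
}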
@@ -710,8 +744,17 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
     const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
     const int  shmem = 0;
 
-    float scale;
-    memcpy(&scale, KQV->op_params, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
         <<<blocks_num, block_dim, shmem, main_stream>>>(
@@ -720,7 +763,7 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
         (const char *) V->data,
         mask ? ((const char *) mask->data) : nullptr,
         parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-        scale,
+        scale, max_bias, m0, m1, n_head_log2,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
         mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
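Note: worked numbers for the m0/m1 setup above, assuming n_head = 10 and max_bias = 8.0: n_head_log2 = 2^floor(log2(10)) = 8, so m0 = 2^(-8/8) = 0.5 and m1 = 2^(-4/8) ≈ 0.7071. Heads 0..7 then get slopes 0.5^1 .. 0.5^8, and the two overflow heads 8 and 9 get 0.7071^1 ≈ 0.7071 and 0.7071^3 ≈ 0.3536, matching the interleaved scheme the ALiBi paper uses for non-power-of-two head counts.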
@@ -761,8 +804,17 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
     const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
     const int  shmem = 0;
 
-    float scale;
-    memcpy(&scale, KQV->op_params, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
         <<<blocks_num, block_dim, shmem, main_stream>>>(
@@ -771,7 +823,7 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
         (const char *) V->data,
         mask ? ((const char *) mask->data) : nullptr,
         (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-        scale,
+        scale, max_bias, m0, m1, n_head_log2,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
         mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
@@ -837,7 +889,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const int cc  = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 
-    const int32_t precision = KQV->op_params[1];
+    const int32_t precision = KQV->op_params[2];
 
     if (!fp16_mma_available(cc)) {
         GGML_ASSERT(precision == GGML_PREC_DEFAULT);
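Note: precision moves from op_params[1] to op_params[2] because max_bias now occupies slot 1. A sketch of the layout this code assumes; the helper is illustrative, not a ggml API:

#include <cstdint>
#include <cstring>

// Assumed op_params layout after this change:
//   slot 0: scale (float), slot 1: max_bias (float), slot 2: precision (int32).
static void set_fattn_op_params(int32_t op_params[], float scale, float max_bias, int32_t precision) {
    memcpy((float *) op_params + 0, &scale,    sizeof(float));
    memcpy((float *) op_params + 1, &max_bias, sizeof(float));
    op_params[2] = precision;
}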