@@ -6455,6 +6455,8 @@ static __global__ void flash_attn_ext_f16(
     half  * sq  = (half  *) (__flash_attn_f16_shmem +  0*D); // holds the query data
     half2 * sq2 = (half2 *) (__flash_attn_f16_shmem +  0*D); // same as above but in half2
     half  * ss  = (half  *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix
+
+    half16x16_acc zr;
     half16x16_acc lo[Q16][D16];

     // load heads from Q to shared memory
@@ -6470,6 +6472,8 @@ static __global__ void flash_attn_ext_f16(
         }
     }

+    nvcuda::wmma::fill_fragment(zr, 0.0);
+
     // zero out lo
     for (int64_t j = 0; j < Q16; ++j) {
         for (int64_t i = 0; i < D16; ++i) {
@@ -6487,12 +6491,12 @@ static __global__ void flash_attn_ext_f16(
     __syncthreads();

     {
-        float S[Q];
-        float M[Q];
+        half S[Q];
+        half M[Q];

         for (int i = 0; i < Q; i++) {
-            S[i] = 0.0f;
-            M[i] = -INFINITY;
+            S[i] = __float2half(0.0f);
+            M[i] = __float2half(-INFINITY);
         }

         // assume K and V are same shape
@@ -6526,11 +6530,16 @@ static __global__ void flash_attn_ext_f16(
             }
         }

-        const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
-
         // pointer to the mask
         const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;

+        // prepare diagonal scale matrix
+        half16x16_b mscale;
+        for (int i = 0; i < 16; ++i) {
+            ss[i*T + i] = __float2half(scale);
+        }
+        nvcuda::wmma::load_matrix_sync(mscale, ss, T);
+
         // loop over the KV cache
         // each simdgroup handles blocks of Q rows and C columns
         for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) {
@@ -6555,111 +6564,129 @@ static __global__ void flash_attn_ext_f16(

                     // mqk = mqk*scale + mask
                     for (int64_t j = 0; j < Q16; ++j) {
-                        for (uint32_t i = 0; i < mqk[j].num_elements; i++) {
-                            // TODO: process mask
-                            mqk[j].x[i] = __float2half(scale) * mqk[j].x[i];
-                        }
+                        half16x16_a   mqka;
+                        half16x16_acc mm;
+
+                        // convert accumulator to matrix_a
+                        nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
+                        nvcuda::wmma::load_matrix_sync(mqka, ss + 16*j*T + 16*cc, T);
+
+                        nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
+                        nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mm);

                         nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
                     }
                 }
             }

             // used to detect blocks full of -INF
-            float smax = -INFINITY;
+            half smax = __float2half(-INFINITY);

             // online softmax
             if (C == 32) {
                 for (int64_t j = 0; j < Q; ++j) {
                     const int64_t p = lane_id;

-                    const float m = M[j];
-                    const float s = __half2float(ss[j*T + p]);
+                    const half m = M[j];
+                    const half s = ss[j*T + p];

-                    smax = warp_reduce_max(max(smax, s));
-                    M[j] = warp_reduce_max(max(M[j], s));
+                    smax = warp_reduce_max(__hmax(smax, s));
+                    M[j] = warp_reduce_max(__hmax(M[j], s));

-                    const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
-                    const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
+                    const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);
+                    const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);

                     S[j] = S[j]*ms + warp_reduce_sum(vs);

                     // create a QxQ diagonal matrix for rescaling the output
                     if (p == j) {
-                        ss[j*T + C + j] = __float2half(ms);
+                        ss[j*T + C + j] = ms;
                     }

                     // the P matrix from the paper (Q rows, C columns)
-                    ss[j*T + p] = __float2half(vs);
+                    ss[j*T + p] = vs;
                 }
             } else {
                 for (int64_t j = 0; j < Q; ++j) {
-                    const float m = M[j];
+                    const half m = M[j];

                     for (int64_t p = lane_id; p < C; p += NW) {
-                        const float s = __half2float(ss[j*T + p]);
+                        const half s = ss[j*T + p];

-                        smax = warp_reduce_max(max(smax, s));
-                        M[j] = warp_reduce_max(max(M[j], s));
+                        smax = __hmax(smax, s);
+                        M[j] = __hmax(M[j], s);
                     }

-                    const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
+                    smax = warp_reduce_max(smax);
+                    M[j] = warp_reduce_max(M[j]);

-                    S[j] = S[j]*ms;
+                    const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);

                     // create a QxQ diagonal matrix for rescaling the output
                     if (lane_id == j) {
-                        ss[j*T + C + j] = __float2half(ms);
+                        ss[j*T + C + j] = ms;
                     }

+                    // local sum
+                    half ls = 0.0f;
+
                     for (int64_t p = lane_id; p < C; p += NW) {
-                        const float s = __half2float(ss[j*T + p]);
+                        const half s = ss[j*T + p];

-                        const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
+                        const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);

-                        S[j] = S[j] + warp_reduce_sum(vs);
+                        ls += vs;

                         // the P matrix from the paper (Q rows, C columns)
-                        ss[j*T + p] = __float2half(vs);
+                        ss[j*T + p] = vs;
                     }
+
+                    S[j] = S[j]*ms + warp_reduce_sum(ls);
                 }
             }

             // skip -INF blocks
-            if (smax == -INFINITY) {
+            if (__hisinf(smax)) {
                 continue;
             }

             // O = diag(ms)*O
             for (int64_t j = 0; j < Q16; ++j) {
-                //half16x16_a mm;
-                //half16x16_b zro;
+                half16x16_a mm;
+                half16x16_b lob;

-                //nvcuda::wmma::fill_fragment(zro, 0.0);
-                //nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
+                nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);

                 for (int64_t i = 0; i < D16; ++i) {
-                    //nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]);
-                    for (uint32_t k = 0; k < 16*16; k++) {
-                        half tmp = ss[(16*j + k%16)*T + C + 16*j + k%16];
-                        lo[j][i].x[k] = tmp * lo[j][i].x[k];
-                    }
+                    // convert accumulator to matrix_b
+                    nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
+                    nvcuda::wmma::load_matrix_sync(lob, ss + 16*j*T + C + 16*j, T);
+
+                    nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
+                    nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]);
                 }
+
+                // restore zeros
+                nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major);
             }

             // O = O + (Q*K^T)*V
             {
                 for (int cc = 0; cc < C/16; ++cc) {
                     const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));

+                    half16x16_b mk[D16];
                     for (int64_t i = 0; i < D16; ++i) {
-                        half16x16_b mk;
-                        nvcuda::wmma::load_matrix_sync(mk, pv + i*16, nb21/sizeof(half));
+                        nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half));
+                    }

-                        for (int64_t j = 0; j < Q16; ++j) {
-                            half16x16_a mv;
-                            nvcuda::wmma::load_matrix_sync(mv, ss + 16*j*T + 16*cc, T);
+                    half16x16_a mv[Q16];
+                    for (int64_t j = 0; j < Q16; ++j) {
+                        nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T);
+                    }

-                            nvcuda::wmma::mma_sync(lo[j][i], mv, mk, lo[j][i]);
+                    for (int64_t j = 0; j < Q16; ++j) {
+                        for (int64_t i = 0; i < D16; ++i) {
+                            nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]);
                         }
                     }
                 }
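The hunk above moves the online softmax bookkeeping to half precision (`__hmax`/`hexp`/`__hisinf`) and, in the generic branch, accumulates a per-lane local sum `ls` that is reduced once at the end. The recurrence itself is independent of precision; the following is a minimal scalar sketch of it in plain C++ with float for readability (illustration only, not the kernel's code; `scores` stands in for one row of masked, scaled Q·K^T values):

```cpp
#include <cmath>
#include <vector>

// Streaming (online) softmax accumulation: M is the running maximum of the
// scores seen so far, S the running sum of exp(score - M). When M grows,
// the previously accumulated S is rescaled by exp(M_old - M_new) -- the same
// factor the kernel stores in the QxQ diagonal matrix to rescale O.
static void online_softmax_row(const std::vector<float> & scores, float & S, float & M) {
    S = 0.0f;
    M = -INFINITY;
    for (const float s : scores) {
        const float M_new = fmaxf(M, s);
        const float ms = (M == -INFINITY) ? 0.0f : expf(M - M_new); // rescale what was accumulated so far
        const float vs = (s == -INFINITY) ? 0.0f : expf(s - M_new); // contribution of the new score
        S = S*ms + vs;
        M = M_new;
    }
}
```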
@@ -6669,16 +6696,16 @@ static __global__ void flash_attn_ext_f16(
         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
         for (int64_t j = 0; j < Q; ++j) {
             if (lane_id == 0) {
-                ss[j*T + 0] = __float2half(S[j]);
-                ss[j*T + 1] = __float2half(M[j]);
+                ss[j*T + 0] = S[j];
+                ss[j*T + 1] = M[j];
             }
         }
     }

     // reduce the warps sequentially
     for (int64_t sg = 1; sg < num_warps; ++sg) {
-        float S = 0.0f;
-        float M = -INFINITY;
+        half S = __float2half(0.0f);
+        half M = __float2half(-INFINITY);

         __syncthreads();

@@ -6696,25 +6723,25 @@ static __global__ void flash_attn_ext_f16(
         // the first simdgroup accumulates the results from the other simdgroups
         if (warp_id == 0) {
             for (int64_t j = 0; j < Q; ++j) {
-                const float S0 = __half2float(ss[j*T + 0]);
-                const float S1 = __half2float(ss[j*T + sg*SH + 0]);
+                const half S0 = ss[j*T + 0];
+                const half S1 = ss[j*T + sg*SH + 0];

-                const float M0 = __half2float(ss[j*T + 1]);
-                const float M1 = __half2float(ss[j*T + sg*SH + 1]);
+                const half M0 = ss[j*T + 1];
+                const half M1 = ss[j*T + sg*SH + 1];

-                M = max(M0, M1);
+                M = __hmax(M0, M1);

-                const float ms0 = M0 == -INFINITY ? 0.0f : expf(M0 - M);
-                const float ms1 = M1 == -INFINITY ? 0.0f : expf(M1 - M);
+                const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M);
+                const half ms1 = __hisinf(M1) ? __float2half(0.0f) : hexp(M1 - M);

                 S = S0*ms0 + S1*ms1;

                 if (lane_id == 0) {
-                    ss[j*T + 0] = __float2half(S);
-                    ss[j*T + 1] = __float2half(M);
+                    ss[j*T + 0] = S;
+                    ss[j*T + 1] = M;

-                    ss[j*T + C + j        ] = __float2half(ms0);
-                    ss[j*T + C + j + sg*SH] = __float2half(ms1);
+                    ss[j*T + C + j        ] = ms0;
+                    ss[j*T + C + j + sg*SH] = ms1;
                 }
             }

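This hunk is the per-pair merge of partial softmax states: warp 0 folds each other warp's `(S, M)` into its own and records `ms0`/`ms1` so the corresponding output blocks can be rescaled with WMMA afterwards. The merge rule on its own (same scalar, float-for-readability sketch as above; the helper name is hypothetical):

```cpp
#include <cmath>

// Merge two partial softmax states (S0, M0) and (S1, M1) into (S, M).
// ms0 and ms1 are the factors by which the two partial outputs O0 and O1
// must be rescaled before being summed -- mirroring ss[j*T + C + j] above.
static void merge_softmax_state(float S0, float M0, float S1, float M1,
                                float & S, float & M, float & ms0, float & ms1) {
    M   = fmaxf(M0, M1);
    ms0 = (M0 == -INFINITY) ? 0.0f : expf(M0 - M);
    ms1 = (M1 == -INFINITY) ? 0.0f : expf(M1 - M);
    S   = S0*ms0 + S1*ms1;
}
```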
@@ -6732,10 +6759,11 @@ static __global__ void flash_attn_ext_f16(
                     nvcuda::wmma::fill_fragment(t2, 0.0);
                     nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
                     nvcuda::wmma::mma_sync(t2, ms1, t, t2);
-                    // store temporally 'lo' data
-                    nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
-                    // load 'lo' data into t
-                    nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
+
+                    // convert accumulator to matrix_b
+                    nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
+                    nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
+
                     nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
                 }
             }
@@ -6751,15 +6779,13 @@ static __global__ void flash_attn_ext_f16(
             }
         }

-    // float2 * dst2 = (float2 *) dst;
-
     // final rescale with 1/S and store to global memory
     if (warp_id == 0) {
         for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
-            const float S = __half2float(ss[j*T + 0]);
+            const half S = ss[j*T + 0];

             for (int64_t i = lane_id; i < D; i += NW) {
-                dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i]) / S;
+                dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
             }
         }
     }
@@ -9618,7 +9644,7 @@ static void ggml_cuda_op_soft_max(

     const int64_t ne00    = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
+    const int64_t nrows_y = src1 ? src0->ne[1] : 1; // note: using number of queries since mask can be padded!

     float scale = 1.0f;
     memcpy(&scale, dst->op_params, sizeof(float));
@@ -10897,8 +10923,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *

     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
     GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU);
-    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 8) &&
-                "the Flash-Attention CUDA kernel requires the mask to be padded to 8 and at least n_queries big");
+    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
+                "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");

     ggml_cuda_set_device(g_main_device);
     const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
@@ -10912,19 +10938,25 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     float scale;
     memcpy(&scale, KQV->op_params, sizeof(float));

-    const int nqpb = 16; // queries per block
-    const int ncpw = 32; // cache values per warp (does not work for other values)
-    //const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4;
-    const int nwarps = 1;
+#define NQPB 16
+#define NCPW 128
+
+    const int nqpb = NQPB; // queries per block
+    const int ncpw = NCPW; // cache values per warp (does not work for other values)
+
+    const int nwarps_max = 8; // TODO: we don't want to launch too many warps. how many is too many?
+    // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
+    const int nwarps = Q->ne[1] <= nqpb ? MAX(2, MIN(K->ne[1]/ncpw, nwarps_max)) : 2;

     dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
     dim3 block_dim(32, nwarps, 1);

-    int shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
+    const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
+
     switch (Q->ne[0])
     {
         case 16:
-            flash_attn_ext_f16<16, 16, 32>
+            flash_attn_ext_f16<16, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                         (const char *) src0_extra->data_device[g_main_device], // Query
                         (const char *) src1_extra->data_device[g_main_device], // Key
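The launcher keeps the shared-memory formula `nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))` half elements, which appears to be: per query row, `Q->ne[0]` halves for the query itself plus, for each warp, `ncpw` attention scores and an `nqpb`-wide diagonal scratch block. Raising NCPW to 128 and running more than one warp therefore grows the allocation quickly. A quick numeric check of the formula (the values below are just an example configuration, not a measurement):

```cpp
#include <cstdio>

int main() {
    // Example only: head size D = 128, NQPB = 16, NCPW = 128, nwarps = 2.
    const int D = 128, nqpb = 16, ncpw = 128, nwarps = 2;
    const size_t shmem = nqpb*(D + nwarps*(ncpw + nqpb))*(sizeof(float)/2); // half = 2 bytes
    printf("shmem = %zu bytes\n", shmem); // 16*(128 + 2*144)*2 = 13312 bytes
    return 0;
}
```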
@@ -10941,7 +10973,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 64:
-            flash_attn_ext_f16<64, 16, 32>
+            flash_attn_ext_f16<64, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                         (const char *) src0_extra->data_device[g_main_device], // Query
                         (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10958,7 +10990,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 80:
-            flash_attn_ext_f16<80, 16, 32>
+            flash_attn_ext_f16<80, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                         (const char *) src0_extra->data_device[g_main_device], // Query
                         (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10975,7 +11007,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 128:
-            flash_attn_ext_f16<128, 16, 32>
+            flash_attn_ext_f16<128, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                         (const char *) src0_extra->data_device[g_main_device], // Query
                         (const char *) src1_extra->data_device[g_main_device], // Key