@@ -6491,8 +6491,8 @@ static __global__ void flash_attn_ext_f16(
     __syncthreads();
 
     {
-        float S[Q];
-        float M[Q];
+        half S[Q];
+        half M[Q];
 
         for (int i = 0; i < Q; i++) {
             S[i] = 0.0f;
@@ -6579,67 +6579,68 @@ static __global__ void flash_attn_ext_f16(
            }
 
            // used to detect blocks full of -INF
-           float smax = -INFINITY;
+           half smax = -INFINITY;
 
            // online softmax
            if (C == 32) {
                for (int64_t j = 0; j < Q; ++j) {
                    const int64_t p = lane_id;
 
-                   const float m = M[j];
-                   const float s = __half2float(ss[j*T + p]);
+                   const half m = M[j];
+                   const half s = ss[j*T + p];
 
-                   smax = warp_reduce_max(max(smax, s));
-                   M[j] = warp_reduce_max(max(M[j], s));
+                   smax = warp_reduce_max(__hmax(smax, s));
+                   M[j] = warp_reduce_max(__hmax(M[j], s));
 
-                   const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
-                   const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
+                   const half ms = __hisinf(m) ? 0.0f : expf(m - M[j]);
+                   const half vs = __hisinf(s) ? 0.0f : expf(s - M[j]);
 
                    S[j] = S[j]*ms + warp_reduce_sum(vs);
 
                    // create a QxQ diagonal matrix for rescaling the output
                    if (p == j) {
-                       ss[j*T + C + j] = __float2half(ms);
+                       ss[j*T + C + j] = ms;
                    }
 
                    // the P matrix from the paper (Q rows, C columns)
-                   ss[j*T + p] = __float2half(vs);
+                   ss[j*T + p] = vs;
                }
            } else {
                for (int64_t j = 0; j < Q; ++j) {
-                   const float m = M[j];
+                   const half m = M[j];
 
                    for (int64_t p = lane_id; p < C; p += NW) {
-                       const float s = __half2float(ss[j*T + p]);
+                       const half s = ss[j*T + p];
 
-                       smax = warp_reduce_max(max(smax, s));
-                       M[j] = warp_reduce_max(max(M[j], s));
+                       smax = warp_reduce_max(__hmax(smax, s));
+                       M[j] = warp_reduce_max(__hmax(M[j], s));
                    }
 
-                   const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
+                   const half ms = __hisinf(m) ? 0.0f : expf(m - M[j]);
 
                    S[j] = S[j]*ms;
 
                    // create a QxQ diagonal matrix for rescaling the output
                    if (lane_id == j) {
-                       ss[j*T + C + j] = __float2half(ms);
+                       ss[j*T + C + j] = ms;
                    }
 
                    for (int64_t p = lane_id; p < C; p += NW) {
-                       const float s = __half2float(ss[j*T + p]);
+                       const half s = ss[j*T + p];
 
-                       const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
+                       const half vs = __hisinf(s) ? 0.0f : expf(s - M[j]);
 
                        S[j] = S[j] + warp_reduce_sum(vs);
 
                        // the P matrix from the paper (Q rows, C columns)
-                       ss[j*T + p] = __float2half(vs);
+                       ss[j*T + p] = vs;
                    }
                }
            }
 
+
            // skip -INF blocks
-           if (smax == -INFINITY) {
+           if (__hisinf(smax)) {
                continue;
            }
 
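The half-precision rewrite above follows the usual online-softmax recurrence: each row j keeps a running maximum M[j] and a running sum S[j], and when a new block of C scores arrives, the old sum is rescaled by exp(M_old - M_new) before the new exponentials are added. As a reference, here is a minimal scalar sketch of that recurrence in plain C++ (illustrative only; SoftmaxState and online_softmax_update are not part of this patch, and the kernel performs the same update warp-parallel in half precision via warp_reduce_max/warp_reduce_sum):

// Illustrative scalar version of the per-row online softmax update.
#include <cmath>
#include <cstdio>
#include <vector>

struct SoftmaxState {
    float M = -INFINITY; // running maximum of the scores seen so far
    float S = 0.0f;      // running sum of exp(score - M)
};

// Process one block of C attention scores for a single query row.
// Returns the rescale factor ms = exp(M_old - M_new) that the kernel
// stores on the diagonal to rescale the partially accumulated output.
static float online_softmax_update(SoftmaxState & st, const float * s, int C) {
    float m_new = st.M;
    for (int p = 0; p < C; ++p) {
        m_new = fmaxf(m_new, s[p]);
    }
    const float ms = (st.M == -INFINITY) ? 0.0f : expf(st.M - m_new);

    float vsum = 0.0f;
    for (int p = 0; p < C; ++p) {
        vsum += (s[p] == -INFINITY) ? 0.0f : expf(s[p] - m_new);
    }

    st.S = st.S*ms + vsum; // rescale the old sum, then add the new block
    st.M = m_new;
    return ms;
}

int main() {
    SoftmaxState st;
    const std::vector<float> block0 = {0.5f, -1.0f, 2.0f, 0.0f};
    const std::vector<float> block1 = {3.0f, 1.0f, -INFINITY, 0.25f};
    online_softmax_update(st, block0.data(), (int) block0.size());
    online_softmax_update(st, block1.data(), (int) block1.size());
    printf("M = %f, S = %f\n", st.M, st.S);
    return 0;
}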
@@ -6686,16 +6687,16 @@ static __global__ void flash_attn_ext_f16(
         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
         for (int64_t j = 0; j < Q; ++j) {
             if (lane_id == 0) {
-                ss[j*T + 0] = __float2half(S[j]);
-                ss[j*T + 1] = __float2half(M[j]);
+                ss[j*T + 0] = S[j];
+                ss[j*T + 1] = M[j];
             }
         }
     }
 
     // reduce the warps sequentially
     for (int64_t sg = 1; sg < num_warps; ++sg) {
-        float S = 0.0f;
-        float M = -INFINITY;
+        half S = 0.0f;
+        half M = -INFINITY;
 
         __syncthreads();
 
@@ -6713,25 +6714,25 @@ static __global__ void flash_attn_ext_f16(
         // the first simdgroup accumulates the results from the other simdgroups
         if (warp_id == 0) {
             for (int64_t j = 0; j < Q; ++j) {
-                const float S0 = __half2float(ss[j*T +         0]);
-                const float S1 = __half2float(ss[j*T + sg*SH + 0]);
+                const half S0 = ss[j*T +         0];
+                const half S1 = ss[j*T + sg*SH + 0];
 
-                const float M0 = __half2float(ss[j*T +         1]);
-                const float M1 = __half2float(ss[j*T + sg*SH + 1]);
+                const half M0 = ss[j*T +         1];
+                const half M1 = ss[j*T + sg*SH + 1];
 
-                M = max(M0, M1);
+                M = __hmax(M0, M1);
 
-                const float ms0 = M0 == -INFINITY ? 0.0f : expf(M0 - M);
-                const float ms1 = M1 == -INFINITY ? 0.0f : expf(M1 - M);
+                const half ms0 = __hisinf(M0) ? 0.0f : expf(M0 - M);
+                const half ms1 = __hisinf(M1) ? 0.0f : expf(M1 - M);
 
                 S = S0*ms0 + S1*ms1;
 
                 if (lane_id == 0) {
-                    ss[j*T + 0] = __float2half(S);
-                    ss[j*T + 1] = __float2half(M);
+                    ss[j*T + 0] = S;
+                    ss[j*T + 1] = M;
 
-                    ss[j*T + C + j        ] = __float2half(ms0);
-                    ss[j*T + C + j + sg*SH] = __float2half(ms1);
+                    ss[j*T + C + j        ] = ms0;
+                    ss[j*T + C + j + sg*SH] = ms1;
                 }
             }
 
@@ -6774,10 +6775,10 @@ static __global__ void flash_attn_ext_f16(
     // final rescale with 1/S and store to global memory
     if (warp_id == 0) {
         for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
-            const float S = __half2float(ss[j*T + 0]);
+            const half S = ss[j*T + 0];
 
             for (int64_t i = lane_id; i < D; i += NW) {
-                dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i]) / S;
+                dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
             }
         }
     }
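The sequential warp reduction above combines per-warp partial results (S0, M0) and (S1, M1) into a single (S, M) and produces the factors ms0/ms1 used to rescale the two partially accumulated output tiles before they are summed and finally divided by S. A minimal scalar sketch of that merge step in plain C++ (merge_states is a hypothetical helper for illustration, not a function from the patch):

// Illustrative merge of two partial softmax states, mirroring what the
// first warp does with the per-warp results stored in the ss buffer.
#include <cmath>
#include <cstdio>

struct Partial {
    float S; // partial sum of exponentials
    float M; // partial running maximum
};

// Returns the merged state and the factors ms0/ms1 used to rescale the
// two partially accumulated output tiles before adding them together.
static Partial merge_states(Partial a, Partial b, float & ms0, float & ms1) {
    const float M = fmaxf(a.M, b.M);
    ms0 = (a.M == -INFINITY) ? 0.0f : expf(a.M - M);
    ms1 = (b.M == -INFINITY) ? 0.0f : expf(b.M - M);
    return { a.S*ms0 + b.S*ms1, M };
}

int main() {
    float ms0, ms1;
    const Partial merged = merge_states({2.5f, 1.0f}, {4.0f, 3.0f}, ms0, ms1);
    printf("S = %f, M = %f, ms0 = %f, ms1 = %f\n", merged.S, merged.M, ms0, ms1);
    return 0;
}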
@@ -10930,12 +10931,15 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     float scale;
     memcpy(&scale, KQV->op_params, sizeof(float));
 
-    const int nqpb = 16; // queries per block
-    const int ncpw = 32; // cache values per warp (does not work for other values)
+#define NQPB 16
+#define NCPW 32
+
+    const int nqpb = NQPB; // queries per block
+    const int ncpw = NCPW; // cache values per warp (does not work for other values)
 
     const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
     // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
-    const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, nwarps_max)) : 4;
+    const int nwarps = Q->ne[1] <= nqpb ? MAX(2, MIN(K->ne[1]/ncpw, nwarps_max)) : 2;
 
     dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
     dim3 block_dim(32, nwarps, 1);
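For a concrete feel of the launch geometry computed here, the sketch below plugs hypothetical tensor shapes into the same formulas (the ne values are made up for illustration, and MAX/MIN are replaced by std::max/std::min on the host):

// Worked example of the grid/block sizing used by the kernel launch above.
#include <algorithm>
#include <cstdio>

int main() {
    const int nqpb = 16; // NQPB: queries per block
    const int ncpw = 32; // NCPW: cache values per warp

    // hypothetical shapes: Q->ne = {128, 100, 32, 1}, K->ne[1] = 1024
    const int q_ne1 = 100, q_ne2 = 32, q_ne3 = 1;
    const int k_ne1 = 1024;

    const int nwarps_max = 8;
    const int nwarps = q_ne1 <= nqpb ? std::max(2, std::min(k_ne1/ncpw, nwarps_max)) : 2;

    const int blocks_x = (q_ne1 + nqpb - 1) / nqpb;
    printf("blocks_num = (%d, %d, %d), block_dim = (32, %d, 1)\n",
           blocks_x, q_ne2, q_ne3, nwarps);
    // With q_ne1 = 100 > nqpb, nwarps = 2 and blocks_num = (7, 32, 1).
    return 0;
}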
@@ -10945,7 +10949,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     switch (Q->ne[0])
     {
         case 16:
-            flash_attn_ext_f16<16, 16, 32>
+            flash_attn_ext_f16<16, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10962,7 +10966,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 64:
-            flash_attn_ext_f16<64, 16, 32>
+            flash_attn_ext_f16<64, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10979,7 +10983,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 80:
-            flash_attn_ext_f16<80, 16, 32>
+            flash_attn_ext_f16<80, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10996,7 +11000,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 128:
-            flash_attn_ext_f16<128, 16, 32>
+            flash_attn_ext_f16<128, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key