@@ -6513,9 +6513,9 @@ static __global__ void flash_attn_ext_f16(
6513
6513
half S[Q];
6514
6514
half M[Q];
6515
6515
6516
- for(int i = 0; i < Q; i++ ) {
6516
+ for (int i = 0; i < Q; ++i ) {
6517
6517
S[i] = __float2half(0.0f);
6518
- M[i] = __float2half(-INFINITY) ;
6518
+ M[i] = CUDART_MIN_DENORM_FP16 ;
6519
6519
}
6520
6520
6521
6521
// assume K and V are same shape
@@ -6609,69 +6609,44 @@ static __global__ void flash_attn_ext_f16(
6609
6609
half smax = __float2half(-INFINITY);
6610
6610
6611
6611
// online softmax
6612
- if (C == 32) {
6613
- for (int j = 0; j < Q; ++j) {
6614
- const int p = lane_id;
6615
-
6616
- const half m = M[j];
6617
- const half s = ss[j*T + p];
6618
-
6619
- smax = warp_reduce_max(__hmax(smax, s));
6620
- M[j] = warp_reduce_max(__hmax(M[j], s));
6621
-
6622
- const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);
6623
- const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
6612
+ for (int j = 0; j < Q; ++j) {
6613
+ const half m = M[j];
6624
6614
6625
- S[j] = S[j]*ms + warp_reduce_sum(vs);
6615
+ for (int p0 = 0; p0 < C; p0 += NW) {
6616
+ const int p = p0 + lane_id;
6626
6617
6627
- // create a QxQ diagonal matrix for rescaling the output
6628
- if (p == j) {
6629
- ss[j*T + C + j] = ms;
6630
- }
6618
+ const half s = ss[j*T + p];
6631
6619
6632
- // the P matrix from the paper (Q rows, C columns)
6633
- ss[j*T + p ] = vs ;
6620
+ smax = __hmax(smax, s);
6621
+ M[j ] = __hmax(M[j], s) ;
6634
6622
}
6635
- } else {
6636
- for (int j = 0; j < Q; ++j) {
6637
- const half m = M[j];
6638
-
6639
- for (int p0 = 0; p0 < C; p0 += NW) {
6640
- const int p = p0 + lane_id;
6641
6623
6642
- const half s = ss[j*T + p] ;
6624
+ M[j] = warp_reduce_max(M[j]) ;
6643
6625
6644
- smax = __hmax(smax, s);
6645
- M[j] = __hmax(M[j], s);
6646
- }
6647
-
6648
- M[j] = warp_reduce_max(M[j]);
6649
-
6650
- const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);
6626
+ const half ms = hexp(m - M[j]);
6651
6627
6652
- // create a QxQ diagonal matrix for rescaling the output
6653
- if (lane_id == j) {
6654
- ss[j*T + C + j] = ms;
6655
- }
6656
-
6657
- // local sum
6658
- half ls = 0.0f;
6628
+ // create a QxQ diagonal matrix for rescaling the output
6629
+ if (lane_id == j) {
6630
+ ss[j*T + C + j] = ms;
6631
+ }
6659
6632
6660
- for (int p0 = 0; p0 < C; p0 += NW) {
6661
- const int p = p0 + lane_id ;
6633
+ // local sum
6634
+ half ls = 0.0f ;
6662
6635
6663
- const half s = ss[j*T + p];
6636
+ for (int p0 = 0; p0 < C; p0 += NW) {
6637
+ const int p = p0 + lane_id;
6664
6638
6665
- const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]) ;
6639
+ const half s = ss[j*T + p] ;
6666
6640
6667
- ls += vs ;
6641
+ const half vs = hexp(s - M[j]) ;
6668
6642
6669
- // the P matrix from the paper (Q rows, C columns)
6670
- ss[j*T + p] = vs;
6671
- }
6643
+ ls += vs;
6672
6644
6673
- S[j] = S[j]*ms + warp_reduce_sum(ls);
6645
+ // the P matrix from the paper (Q rows, C columns)
6646
+ ss[j*T + p] = vs;
6674
6647
}
6648
+
6649
+ S[j] = S[j]*ms + warp_reduce_sum(ls);
6675
6650
}
6676
6651
6677
6652
smax = warp_reduce_max(smax);
@@ -6736,7 +6711,7 @@ static __global__ void flash_attn_ext_f16(
6736
6711
// reduce the warps sequentially
6737
6712
for (int sg = 1; sg < num_warps; ++sg) {
6738
6713
half S = __float2half(0.0f);
6739
- half M = __float2half(-INFINITY) ;
6714
+ half M = CUDART_MIN_DENORM_FP16 ;
6740
6715
6741
6716
__syncthreads();
6742
6717
@@ -6762,8 +6737,8 @@ static __global__ void flash_attn_ext_f16(
6762
6737
6763
6738
M = __hmax(M0, M1);
6764
6739
6765
- const half ms0 = __hisinf(M0) == -1 ? __float2half(0.0f) : hexp(M0 - M);
6766
- const half ms1 = __hisinf(M1) == -1 ? __float2half(0.0f) : hexp(M1 - M);
6740
+ const half ms0 = hexp(M0 - M);
6741
+ const half ms1 = hexp(M1 - M);
6767
6742
6768
6743
S = S0*ms0 + S1*ms1;
6769
6744
0 commit comments