Commit c51f27c

cuda : avoid __hisinf branches
1 parent 92472ea commit c51f27c

File tree: 1 file changed, +29 −54 lines changed


ggml-cuda.cu

Lines changed: 29 additions & 54 deletions
@@ -6513,9 +6513,9 @@ static __global__ void flash_attn_ext_f16(
     half S[Q];
     half M[Q];
 
-    for(int i = 0; i < Q; i++) {
+    for (int i = 0; i < Q; ++i) {
         S[i] = __float2half(0.0f);
-        M[i] = __float2half(-INFINITY);
+        M[i] = CUDART_MIN_DENORM_FP16;
     }
 
     // assume K and V are same shape
@@ -6609,69 +6609,44 @@ static __global__ void flash_attn_ext_f16(
         half smax = __float2half(-INFINITY);
 
         // online softmax
-        if (C == 32) {
-            for (int j = 0; j < Q; ++j) {
-                const int p = lane_id;
-
-                const half m = M[j];
-                const half s = ss[j*T + p];
-
-                smax = warp_reduce_max(__hmax(smax, s));
-                M[j] = warp_reduce_max(__hmax(M[j], s));
-
-                const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);
-                const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
+        for (int j = 0; j < Q; ++j) {
+            const half m = M[j];
 
-                S[j] = S[j]*ms + warp_reduce_sum(vs);
+            for (int p0 = 0; p0 < C; p0 += NW) {
+                const int p = p0 + lane_id;
 
-                // create a QxQ diagonal matrix for rescaling the output
-                if (p == j) {
-                    ss[j*T + C + j] = ms;
-                }
+                const half s = ss[j*T + p];
 
-                // the P matrix from the paper (Q rows, C columns)
-                ss[j*T + p] = vs;
+                smax = __hmax(smax, s);
+                M[j] = __hmax(M[j], s);
             }
-        } else {
-            for (int j = 0; j < Q; ++j) {
-                const half m = M[j];
-
-                for (int p0 = 0; p0 < C; p0 += NW) {
-                    const int p = p0 + lane_id;
 
-                    const half s = ss[j*T + p];
+            M[j] = warp_reduce_max(M[j]);
 
-                    smax = __hmax(smax, s);
-                    M[j] = __hmax(M[j], s);
-                }
-
-                M[j] = warp_reduce_max(M[j]);
-
-                const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);
+            const half ms = hexp(m - M[j]);
 
-                // create a QxQ diagonal matrix for rescaling the output
-                if (lane_id == j) {
-                    ss[j*T + C + j] = ms;
-                }
-
-                // local sum
-                half ls = 0.0f;
+            // create a QxQ diagonal matrix for rescaling the output
+            if (lane_id == j) {
+                ss[j*T + C + j] = ms;
+            }
 
-                for (int p0 = 0; p0 < C; p0 += NW) {
-                    const int p = p0 + lane_id;
+            // local sum
+            half ls = 0.0f;
 
-                    const half s = ss[j*T + p];
+            for (int p0 = 0; p0 < C; p0 += NW) {
+                const int p = p0 + lane_id;
 
-                    const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
+                const half s = ss[j*T + p];
 
-                    ls += vs;
+                const half vs = hexp(s - M[j]);
 
-                    // the P matrix from the paper (Q rows, C columns)
-                    ss[j*T + p] = vs;
-                }
+                ls += vs;
 
-                S[j] = S[j]*ms + warp_reduce_sum(ls);
+                // the P matrix from the paper (Q rows, C columns)
+                ss[j*T + p] = vs;
             }
+
+            S[j] = S[j]*ms + warp_reduce_sum(ls);
         }
 
         smax = warp_reduce_max(smax);
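For reference, here is a scalar sketch (not from this commit) of the online-softmax bookkeeping that the rewritten loop performs for each query row j: keep a running maximum M and a running denominator S, and rescale the old S whenever M grows. The kernel consumes a tile of C scores per iteration and reduces across the warp; the sketch below processes one float score at a time on the host, which is enough to show the update rule.

// scalar reference for the per-row online softmax update (illustrative only)
#include <cmath>
#include <cstdio>

int main() {
    const float scores[] = { 1.0f, 3.5f, -2.0f, 4.0f }; // stand-in attention scores for one row

    float M = -INFINITY; // running maximum (the kernel now seeds this with CUDART_MIN_DENORM_FP16)
    float S = 0.0f;      // running softmax denominator

    for (float s : scores) {
        const float m  = M;           // previous maximum
        M = fmaxf(M, s);              // updated maximum
        const float ms = expf(m - M); // rescale factor for the old partial sum
        const float vs = expf(s - M); // contribution of the new score
        S = S*ms + vs;
    }

    printf("softmax denominator = %f\n", S); // equals sum_i exp(scores[i] - max(scores))
    return 0;
}

On the first iteration m is still -INFINITY while M has already become a real score, so expf(m - M) is exactly zero. The problem case is when m and M[j] are both -INFINITY (no unmasked score seen yet), where the difference becomes NaN; that appears to be the case the __hisinf guards handled and the new finite seed avoids.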
@@ -6736,7 +6711,7 @@ static __global__ void flash_attn_ext_f16(
     // reduce the warps sequentially
     for (int sg = 1; sg < num_warps; ++sg) {
         half S = __float2half(0.0f);
-        half M = __float2half(-INFINITY);
+        half M = CUDART_MIN_DENORM_FP16;
 
         __syncthreads();
 
@@ -6762,8 +6737,8 @@ static __global__ void flash_attn_ext_f16(
 
             M = __hmax(M0, M1);
 
-            const half ms0 = __hisinf(M0) == -1 ? __float2half(0.0f) : hexp(M0 - M);
-            const half ms1 = __hisinf(M1) == -1 ? __float2half(0.0f) : hexp(M1 - M);
+            const half ms0 = hexp(M0 - M);
+            const half ms1 = hexp(M1 - M);
 
             S = S0*ms0 + S1*ms1;
 
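The headline change is in how M is seeded: CUDART_MIN_DENORM_FP16, which I take to be the smallest positive fp16 subnormal (bit pattern 0x0001), replaces -INFINITY, so the argument of hexp stays finite and the __hisinf guards can go. Below is a minimal device-side sketch (not part of the commit) of the two forms, assuming hexp underflows to zero for large negative arguments and a GPU with native fp16 ops (sm_53+); the fallback #define is only a guard in case the constant is not provided by the local fp16 headers.

#include <cuda_fp16.h>
#include <cmath>
#include <cstdio>

#ifndef CUDART_MIN_DENORM_FP16
// smallest positive fp16 subnormal, used as a finite stand-in for -inf
#define CUDART_MIN_DENORM_FP16 __ushort_as_half((unsigned short)0x0001U)
#endif

__global__ void rescale_factor_demo() {
    // worst case for the old code: no real score seen yet, so m and M[j] are both the seed value
    const half m_inf    = __float2half(-INFINITY);  // old seed for M[j]
    const half m_denorm = CUDART_MIN_DENORM_FP16;   // new seed for M[j]

    // old form: without the guard, (-inf) - (-inf) is NaN and hexp(NaN) is NaN,
    // so every rescale paid for a __hisinf check and a branch
    const half ms_guarded    = __hisinf(m_inf) == -1 ? __float2half(0.0f) : hexp(m_inf - m_inf);

    // new form: the difference is always finite; here it is 0, so the factor is 1
    const half ms_branchless = hexp(m_denorm - m_denorm);

    printf("guarded = %f, branchless = %f\n",
           __half2float(ms_guarded), __half2float(ms_branchless));
}

int main() {
    rescale_factor_demo<<<1, 1>>>();
    cudaDeviceSynchronize();
    return 0;
}

In the branch-free form, a row or warp that has not yet seen a real score gets a rescale factor of 1 instead of 0, but since the partial sum it multiplies (S[j], or S0/S1 in the warp merge above) is still 0 at that point, the result should be unchanged.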