Commit 43f7156

Merge pull request #3 from ggerganov/flash-attn-cuda

cuda : fix flash_attn kernel to produce same results as CPU

2 parents fd878f7 + ac26f27

File tree

4 files changed, +129 -86 lines changed

ggml-cuda.cu

Lines changed: 113 additions & 81 deletions
@@ -6455,6 +6455,8 @@ static __global__ void flash_attn_ext_f16(
     half  * sq  = (half  *) (__flash_attn_f16_shmem +  0*D); // holds the query data
     half2 * sq2 = (half2 *) (__flash_attn_f16_shmem +  0*D); // same as above but in half2
     half  * ss  = (half  *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix
+
+    half16x16_acc zr;
 
     half16x16_acc lo[Q16][D16];
 
     // load heads from Q to shared memory
@@ -6470,6 +6472,8 @@ static __global__ void flash_attn_ext_f16(
         }
     }
 
+    nvcuda::wmma::fill_fragment(zr, 0.0);
+
     // zero out lo
     for (int64_t j = 0; j < Q16; ++j) {
         for (int64_t i = 0; i < D16; ++i) {
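Note: the new `zr` fragment is a WMMA accumulator filled with zeros once per block; further down it is stored over the 16x16 diagonal scratch region of `ss` to clear it after each rescale ("restore zeros"). A minimal standalone sketch of that pattern, with a hypothetical `clear_tile_f16` helper and `tile`/`ld` standing in for the kernel's `ss` and `T`:

    #include <cuda_fp16.h>
    #include <mma.h>

    // Clears a 16x16 half tile (leading dimension ld) using a zero-filled WMMA accumulator.
    // Must be executed by all 32 threads of one warp, like every wmma:: operation.
    __device__ void clear_tile_f16(half * tile, int ld) {
        nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> zr;
        nvcuda::wmma::fill_fragment(zr, 0.0);
        nvcuda::wmma::store_matrix_sync(tile, zr, ld, nvcuda::wmma::mem_row_major);
    }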
@@ -6487,12 +6491,12 @@ static __global__ void flash_attn_ext_f16(
     __syncthreads();
 
     {
-        float S[Q];
-        float M[Q];
+        half S[Q];
+        half M[Q];
 
         for(int i = 0; i < Q; i++) {
-            S[i] = 0.0f;
-            M[i] = -INFINITY;
+            S[i] = __float2half(0.0f);
+            M[i] = __float2half(-INFINITY);
        }
 
         // assume K and V are same shape
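Note: from here on the per-row softmax state S (running sum) and M (running maximum) is kept in half precision end to end instead of being converted back and forth through float, which is the core of this fix. The float operations are swapped for their cuda_fp16 counterparts; a small sketch of the correspondence, assuming the half-math intrinsics are available on the target architecture (they are used directly by the kernel below):

    #include <cuda_fp16.h>

    // float expression              ->  half expression used in the kernel
    //   max(a, b)                   ->  __hmax(a, b)
    //   expf(x)                     ->  hexp(x)
    //   x == -INFINITY              ->  __hisinf(x)          (only -inf can occur here)
    //   0.0f, -INFINITY literals    ->  __float2half(0.0f), __float2half(-INFINITY)

    // Hypothetical helper: the rescale factor exp(m_old - m_new), defined as 0 while the
    // old maximum is still the -inf sentinel (nothing accumulated yet).
    __device__ half softmax_rescale_factor(half m_old, half m_new) {
        return __hisinf(m_old) ? __float2half(0.0f) : hexp(m_old - m_new);
    }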
@@ -6526,11 +6530,16 @@ static __global__ void flash_attn_ext_f16(
             }
         }
 
-        const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
-
         // pointer to the mask
         const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;
 
+        // prepare diagonal scale matrix
+        half16x16_b mscale;
+        for (int i = 0; i < 16; ++i) {
+            ss[i*T + i] = __float2half(scale);
+        }
+        nvcuda::wmma::load_matrix_sync(mscale, ss, T);
+
         // loop over the KV cache
         // each simdgroup handles blocks of Q rows and C columns
         for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) {
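Note: the diagonal `mscale` matrix is what turns `mqk = mqk*scale + mask` into a single tensor-core mma_sync further down: multiplying a tile A on the right by diag(scale) yields scale*A, and the mask tile rides along as the accumulator. A standalone toy kernel showing the same trick on one 16x16 tile (names and launch shape are illustrative, not part of the commit); launch it with a single warp, e.g. scale_add_mask_16x16<<<1, 32>>>(a, m, d, scale):

    #include <cuda_fp16.h>
    #include <mma.h>

    using namespace nvcuda;

    // d = scale*a + m for one row-major 16x16 half tile, computed as d = a*diag(scale) + m
    __global__ void scale_add_mask_16x16(const half * a, const half * m, half * d, float scale) {
        __shared__ half diag[16*16];

        // build diag(scale) in shared memory; i % 17 == 0 hits exactly the 16 diagonal entries
        for (int i = threadIdx.x; i < 16*16; i += blockDim.x) {
            diag[i] = __float2half(i % 17 == 0 ? scale : 0.0f);
        }
        __syncthreads();

        wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> fa;
        wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::row_major> fs;
        wmma::fragment<wmma::accumulator, 16, 16, 16, half>                  acc;

        wmma::load_matrix_sync(fa,  a,    16);
        wmma::load_matrix_sync(fs,  diag, 16);
        wmma::load_matrix_sync(acc, m,    16, wmma::mem_row_major); // accumulator starts as the mask

        wmma::mma_sync(acc, fa, fs, acc);                           // acc = a*diag(scale) + m

        wmma::store_matrix_sync(d, acc, 16, wmma::mem_row_major);
    }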
@@ -6555,111 +6564,129 @@ static __global__ void flash_attn_ext_f16(
 
                     // mqk = mqk*scale + mask
                     for (int64_t j = 0; j < Q16; ++j) {
-                        for (uint32_t i = 0; i < mqk[j].num_elements; i++) {
-                            // TODO: process mask
-                            mqk[j].x[i] = __float2half(scale) * mqk[j].x[i];
-                        }
+                        half16x16_a   mqka;
+                        half16x16_acc mm;
+
+                        // convert accumulator to matrix_a
+                        nvcuda::wmma::store_matrix_sync(      ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
+                        nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T);
+
+                        nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
+                        nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mm);
                         nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
                     }
                 }
             }
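Note: the store/load pair commented "convert accumulator to matrix_a" works around a WMMA restriction: an accumulator fragment cannot be fed to mma_sync as one of the multiplicand operands, so the kernel bounces it through shared memory and reloads it in the role it needs (matrix_a here, matrix_b further down). The same round trip as a hypothetical helper, assuming a per-warp scratch tile:

    #include <cuda_fp16.h>
    #include <mma.h>

    // Re-types a 16x16 accumulator fragment as a matrix_a fragment by bouncing it through
    // shared memory. 'scratch' must be a tile this warp owns, with leading dimension ld.
    template <typename frag_a_t, typename frag_acc_t>
    __device__ void acc_to_matrix_a(frag_a_t & dst, const frag_acc_t & src, half * scratch, int ld) {
        nvcuda::wmma::store_matrix_sync(scratch, src, ld, nvcuda::wmma::mem_row_major);
        nvcuda::wmma::load_matrix_sync (dst, scratch, ld);
    }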
 
             // used to detect blocks full of -INF
-            float smax = -INFINITY;
+            half smax = __float2half(-INFINITY);
 
             // online softmax
             if (C == 32) {
                 for (int64_t j = 0; j < Q; ++j) {
                     const int64_t p = lane_id;
 
-                    const float m = M[j];
-                    const float s = __half2float(ss[j*T + p]);
+                    const half m = M[j];
+                    const half s = ss[j*T + p];
 
-                    smax = warp_reduce_max(max(smax, s));
-                    M[j] = warp_reduce_max(max(M[j], s));
+                    smax = warp_reduce_max(__hmax(smax, s));
+                    M[j] = warp_reduce_max(__hmax(M[j], s));
 
-                    const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
-                    const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
+                    const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);
+                    const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);
 
                     S[j] = S[j]*ms + warp_reduce_sum(vs);
 
                     // create a QxQ diagonal matrix for rescaling the output
                     if (p == j) {
-                        ss[j*T + C + j] = __float2half(ms);
+                        ss[j*T + C + j] = ms;
                     }
 
                     // the P matrix from the paper (Q rows, C columns)
-                    ss[j*T + p] = __float2half(vs);
+                    ss[j*T + p] = vs;
                 }
             } else {
                 for (int64_t j = 0; j < Q; ++j) {
-                    const float m = M[j];
+                    const half m = M[j];
 
                     for (int64_t p = lane_id; p < C; p += NW) {
-                        const float s = __half2float(ss[j*T + p]);
+                        const half s = ss[j*T + p];
 
-                        smax = warp_reduce_max(max(smax, s));
-                        M[j] = warp_reduce_max(max(M[j], s));
+                        smax = __hmax(smax, s);
+                        M[j] = __hmax(M[j], s);
                     }
 
-                    const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
+                    smax = warp_reduce_max(smax);
+                    M[j] = warp_reduce_max(M[j]);
 
-                    S[j] = S[j]*ms;
+                    const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);
 
                     // create a QxQ diagonal matrix for rescaling the output
                     if (lane_id == j) {
-                        ss[j*T + C + j] = __float2half(ms);
+                        ss[j*T + C + j] = ms;
                     }
 
+                    // local sum
+                    half ls = 0.0f;
+
                     for (int64_t p = lane_id; p < C; p += NW) {
-                        const float s = __half2float(ss[j*T + p]);
+                        const half s = ss[j*T + p];
 
-                        const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
+                        const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);
 
-                        S[j] = S[j] + warp_reduce_sum(vs);
+                        ls += vs;
 
                         // the P matrix from the paper (Q rows, C columns)
-                        ss[j*T + p] = __float2half(vs);
+                        ss[j*T + p] = vs;
                     }
+
+                    S[j] = S[j]*ms + warp_reduce_sum(ls);
                 }
             }
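Note: the else branch of the online softmax is restructured into two clean passes per query row: first the running maximum M[j] is finalized (per-lane __hmax, then one warp_reduce_max), and only then are the exponentials accumulated into a per-lane partial sum ls that is folded in with a single warp_reduce_sum, the same order the C == 32 fast path uses. A single-threaded host reference of this bookkeeping, useful for checking the kernel against the CPU results (illustrative code, not part of the commit):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Online softmax state for one query row: M = running maximum, S = running sum of exp(s - M).
    static void online_softmax_update(float & M, float & S, const std::vector<float> & block) {
        const float m_old = M;
        for (float s : block) M = std::max(M, s);                           // pass 1: new maximum

        const float ms = std::isinf(m_old) ? 0.0f : std::exp(m_old - M);    // rescales the old state

        float ls = 0.0f;                                                    // pass 2: local sum
        for (float s : block) ls += std::isinf(s) ? 0.0f : std::exp(s - M);

        S = S*ms + ls;   // matches S[j] = S[j]*ms + warp_reduce_sum(ls) above
    }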
 
             // skip -INF blocks
-            if (smax == -INFINITY) {
+            if (__hisinf(smax)) {
                 continue;
             }
 
             // O = diag(ms)*O
             for (int64_t j = 0; j < Q16; ++j) {
-                // half16x16_a mm;
-                // half16x16_b zro;
+                half16x16_a mm;
+                half16x16_b lob;
 
-                // nvcuda::wmma::fill_fragment(zro, 0.0);
-                // nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
+                nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
 
                 for (int64_t i = 0; i < D16; ++i) {
-                    //nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]);
-                    for (uint32_t k = 0; k < 16*16; k++) {
-                        half tmp = ss[(16*j + k%16)*T + C + 16*j + k%16];
-                        lo[j][i].x[k] = tmp * lo[j][i].x[k];
-                    }
+                    // convert accumulator to matrix_b
+                    nvcuda::wmma::store_matrix_sync(     ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
+                    nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);
+
+                    nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
+                    nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]);
                 }
+
+                // restore zeros
+                nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major);
             }
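Note: the output rescale O = diag(ms)*O no longer pokes at lo[j][i].x[k] directly; the mapping from fragment elements to matrix coordinates is not specified by the WMMA API, so the old elementwise version was layout-dependent. Instead the diagonal tile is loaded as matrix_a, the accumulator is re-typed as matrix_b through shared memory, and a regular mma_sync does the row scaling; the zero fragment `zr` then wipes the scratch diagonal so stale values cannot leak into the next iteration. For reference, what left-multiplication by diag(ms) computes, row by row (host-side sketch):

    // Left-multiplying a 16x16 tile by diag(ms) rescales row r by ms[r] -- the per-row
    // correction factor produced by the online softmax whenever the running maximum grows.
    static void diag_rescale_16x16(float O[16][16], const float ms[16]) {
        for (int r = 0; r < 16; ++r) {
            for (int c = 0; c < 16; ++c) {
                O[r][c] *= ms[r];
            }
        }
    }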
 
             // O = O + (Q*K^T)*V
             {
                 for (int cc = 0; cc < C/16; ++cc) {
                     const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
 
+                    half16x16_b mk[D16];
                     for (int64_t i = 0; i < D16; ++i) {
-                        half16x16_b mk;
-                        nvcuda::wmma::load_matrix_sync(mk, pv + i*16, nb21/sizeof(half));
+                        nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half));
+                    }
 
-                        for (int64_t j = 0; j < Q16; ++j) {
-                            half16x16_a mv;
-                            nvcuda::wmma::load_matrix_sync(mv, ss + 16*j*T + 16*cc, T);
+                    half16x16_a mv[Q16];
+                    for (int64_t j = 0; j < Q16; ++j) {
+                        nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T);
+                    }
 
-                            nvcuda::wmma::mma_sync(lo[j][i], mv, mk, lo[j][i]);
+                    for (int64_t j = 0; j < Q16; ++j) {
+                        for (int64_t i = 0; i < D16; ++i) {
+                            nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]);
                         }
                     }
                 }
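Note: the (Q*K^T)*V accumulation is reorganized so that the V tiles (mk[D16]) and the P tiles (mv[Q16]) are each loaded once per 16-column chunk and then reused across the full Q16 x D16 MMA loop, instead of reloading fragments inside the innermost loop. A generic sketch of this load-once/reuse pattern, assuming the kernel's half16x16_* typedefs (NJ/NI and the pointer arguments are stand-ins):

    // NJ + NI tile loads feed NJ*NI MMAs that then run entirely out of registers,
    // so no fragment is reloaded inside the innermost loop.
    template <int NJ, int NI>
    __device__ void tile_block_product(half16x16_acc (&acc)[NJ][NI],
                                       const half * a, int lda,   // NJ 16-row blocks of A
                                       const half * b, int ldb) { // NI 16-column blocks of B
        half16x16_a fa[NJ];
        half16x16_b fb[NI];

        for (int j = 0; j < NJ; ++j) nvcuda::wmma::load_matrix_sync(fa[j], a + 16*j*lda, lda);
        for (int i = 0; i < NI; ++i) nvcuda::wmma::load_matrix_sync(fb[i], b + 16*i,     ldb);

        for (int j = 0; j < NJ; ++j) {
            for (int i = 0; i < NI; ++i) {
                nvcuda::wmma::mma_sync(acc[j][i], fa[j], fb[i], acc[j][i]);
            }
        }
    }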
@@ -6669,16 +6696,16 @@ static __global__ void flash_attn_ext_f16(
         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
         for (int64_t j = 0; j < Q; ++j) {
             if (lane_id == 0) {
-                ss[j*T + 0] = __float2half(S[j]);
-                ss[j*T + 1] = __float2half(M[j]);
+                ss[j*T + 0] = S[j];
+                ss[j*T + 1] = M[j];
             }
         }
     }
 
     // reduce the warps sequentially
     for (int64_t sg = 1; sg < num_warps; ++sg) {
-        float S = 0.0f;
-        float M = -INFINITY;
+        half S = __float2half(0.0f);
+        half M = __float2half(-INFINITY);
 
         __syncthreads();
 
@@ -6696,25 +6723,25 @@ static __global__ void flash_attn_ext_f16(
         // the first simdgroup accumulates the results from the other simdgroups
         if (warp_id == 0) {
             for (int64_t j = 0; j < Q; ++j) {
-                const float S0 = __half2float(ss[j*T + 0]);
-                const float S1 = __half2float(ss[j*T + sg*SH + 0]);
+                const half S0 = ss[j*T + 0];
+                const half S1 = ss[j*T + sg*SH + 0];
 
-                const float M0 = __half2float(ss[j*T + 1]);
-                const float M1 = __half2float(ss[j*T + sg*SH + 1]);
+                const half M0 = ss[j*T + 1];
+                const half M1 = ss[j*T + sg*SH + 1];
 
-                M = max(M0, M1);
+                M = __hmax(M0, M1);
 
-                const float ms0 = M0 == -INFINITY ? 0.0f : expf(M0 - M);
-                const float ms1 = M1 == -INFINITY ? 0.0f : expf(M1 - M);
+                const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M);
+                const half ms1 = __hisinf(M1) ? __float2half(0.0f) : hexp(M1 - M);
 
                 S = S0*ms0 + S1*ms1;
 
                 if (lane_id == 0) {
-                    ss[j*T + 0] = __float2half(S);
-                    ss[j*T + 1] = __float2half(M);
+                    ss[j*T + 0] = S;
+                    ss[j*T + 1] = M;
 
-                    ss[j*T + C + j        ] = __float2half(ms0);
-                    ss[j*T + C + j + sg*SH] = __float2half(ms1);
+                    ss[j*T + C + j        ] = ms0;
+                    ss[j*T + C + j + sg*SH] = ms1;
                 }
             }
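Note: the cross-warp reduction merges two partial softmax states per row -- (S0, M0) from warp 0 and (S1, M1) from warp sg -- by moving both sums into the frame of the larger maximum; ms0/ms1 are stored back so the corresponding partial outputs can be rescaled by the same factors right after. A host-side reference of the merge (illustrative only):

    #include <algorithm>
    #include <cmath>

    struct softmax_state { float S; float M; };   // partial sum and maximum for one row

    static softmax_state merge(softmax_state a, softmax_state b) {
        softmax_state r;
        r.M = std::max(a.M, b.M);

        const float ms0 = std::isinf(a.M) ? 0.0f : std::exp(a.M - r.M);
        const float ms1 = std::isinf(b.M) ? 0.0f : std::exp(b.M - r.M);

        r.S = a.S*ms0 + b.S*ms1;   // ms0/ms1 also rescale the two partial output blocks
        return r;
    }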
 
@@ -6732,10 +6759,11 @@ static __global__ void flash_attn_ext_f16(
                     nvcuda::wmma::fill_fragment(t2, 0.0);
                     nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
                     nvcuda::wmma::mma_sync(t2, ms1, t, t2);
-                    // store temporally 'lo' data
-                    nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
-                    // load 'lo' data into t
-                    nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
+
+                    // convert accumulator to matrix_b
+                    nvcuda::wmma::store_matrix_sync(   sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
+                    nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T);
+
                     nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
                 }
             }
@@ -6751,15 +6779,13 @@ static __global__ void flash_attn_ext_f16(
         }
     }
 
-    // float2 * dst2 = (float2 *) dst;
-
     // final rescale with 1/S and store to global memory
     if (warp_id == 0) {
        for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
-            const float S = __half2float(ss[j*T + 0]);
+            const half S = ss[j*T + 0];
 
            for (int64_t i = lane_id; i < D; i += NW) {
-                dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i]) / S;
+                dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
            }
        }
     }
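Note: the final store changes from __half2float(sq[j*T + i]) / S to __half2float(sq[j*T + i] / S): the division now happens in half precision before widening, so the rounding of this last step stays consistent with the rest of the fp16 pipeline. The two forms are not bit-identical in general; a minimal sketch of the difference (hypothetical helper names):

    #include <cuda_fp16.h>

    // old: widen to float, then divide            new: divide in half, then widen
    __device__ float widen_then_divide(half v, half s) { return __half2float(v) / __half2float(s); }
    __device__ float divide_then_widen(half v, half s) { return __half2float(v / s); }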
@@ -9618,7 +9644,7 @@ static void ggml_cuda_op_soft_max(
 
     const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
+    const int64_t nrows_y = src1 ? src0->ne[1] : 1; // note: using number of queries since mask can be padded!
 
     float scale = 1.0f;
     memcpy(&scale, dst->op_params, sizeof(float));
@@ -10897,8 +10923,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
 
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
     GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU);
-    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 8) &&
-                "the Flash-Attention CUDA kernel requires the mask to be padded to 8 and at least n_queries big");
+    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
+                "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
 
     ggml_cuda_set_device(g_main_device);
     const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
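Note: the mask padding requirement moves from 8 to 16 because the mask is now consumed through 16x16 WMMA tiles (the `mm` fragment loaded from `mp` in the kernel), so its padded dimension must cover whole 16-row tiles. GGML_PAD rounds its first argument up to a multiple of the second; an illustrative stand-in (not ggml.h's exact macro) with the values the assert cares about:

    // rounds x up to the next multiple of n
    #define PAD_UP(x, n) (((x) + (n) - 1) / (n) * (n))

    // PAD_UP( 7, 16) == 16
    // PAD_UP(16, 16) == 16
    // PAD_UP(17, 16) == 32   -> mask->ne[1] must be at least this for 17 queries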
@@ -10912,19 +10938,25 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     float scale;
     memcpy(&scale, KQV->op_params, sizeof(float));
 
-    const int nqpb = 16; // queries per block
-    const int ncpw = 32; // cache values per warp (does not work for other values)
-    // const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4;
-    const int nwarps = 1;
+#define NQPB 16
+#define NCPW 128
+
+    const int nqpb = NQPB; // queries per block
+    const int ncpw = NCPW; // cache values per warp (does not work for other values)
+
+    const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
+    // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
+    const int nwarps = Q->ne[1] <= nqpb ? MAX(2, MIN(K->ne[1]/ncpw, nwarps_max)) : 2;
 
     dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
     dim3 block_dim(32, nwarps, 1);
 
-    int shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
+    const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
+
     switch (Q->ne[0])
     {
         case 16:
-            flash_attn_ext_f16<16, 16, 32>
+            flash_attn_ext_f16<16, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                         (const char *) src0_extra->data_device[g_main_device], // Query
                         (const char *) src1_extra->data_device[g_main_device], // Key
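Note: the dynamic shared memory request counts half elements per query row: D for the Q block (sq) plus nwarps*(ncpw + nqpb) of per-warp scratch (ss), times nqpb rows, times sizeof(float)/2 == sizeof(half) == 2 bytes. Back-of-the-envelope numbers, assuming the largest head size handled here (D = 128):

    // shmem = nqpb*(D + nwarps*(ncpw + nqpb)) * sizeof(half)
    //
    //   nwarps = 2 (batched queries):               16*(128 + 2*(128 + 16))*2 = 16* 416*2 = 13312 bytes (~13 KiB)
    //   nwarps = 8 (single query, long KV cache):   16*(128 + 8*(128 + 16))*2 = 16*1280*2 = 40960 bytes (~40 KiB)
    //
    // Both fit in the default 48 KiB dynamic shared memory limit per block, so no opt-in via
    // cudaFuncSetAttribute(cudaFuncAttributeMaxDynamicSharedMemorySize) is needed.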
@@ -10941,7 +10973,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 64:
-            flash_attn_ext_f16<64, 16, 32>
+            flash_attn_ext_f16<64, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                         (const char *) src0_extra->data_device[g_main_device], // Query
                         (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10958,7 +10990,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 80:
-            flash_attn_ext_f16<80, 16, 32>
+            flash_attn_ext_f16<80, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                         (const char *) src0_extra->data_device[g_main_device], // Query
                         (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10975,7 +11007,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 128:
-            flash_attn_ext_f16<128, 16, 32>
+            flash_attn_ext_f16<128, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                         (const char *) src0_extra->data_device[g_main_device], // Query
                         (const char *) src1_extra->data_device[g_main_device], // Key
