From 674d5ac72d8d7233ee6df57f8f521d7445e3778c Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Sat, 3 Feb 2024 11:11:17 +0100 Subject: [PATCH] =?UTF-8?q?unroll=202=20loops,=20int64=5Ft=20->=20int,=203?= =?UTF-8?q?09=20=C2=B5s?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ggml-cuda.cu | 82 ++++++++++++++++++++++++++++++---------------------- 1 file changed, 47 insertions(+), 35 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c6605fe95ee8c..b811aefe8d63b 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6467,10 +6467,22 @@ static __global__ void flash_attn_ext_f16( half16x16_acc lo[Q16][D16]; // load heads from Q to shared memory - for (int64_t j = warp_id; j < Q; j += num_warps) { +#pragma unroll + for (int j0 = 0; j0 < Q; j0 += num_warps) { + const int j = j0 + warp_id; + if (j >= Q) { + break; + } + const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); - for (int64_t i = lane_id; i < D2; i += NW) { +#pragma unroll + for (int i0 = 0; i0 < D2; i0 += NW) { + const int i = i0 + lane_id; + if (i >= D2) { + break; + } + if (iq1 + j < ne01) { sq2[j*T2 + i] = __float22half2_rn(q2[i]); } else { @@ -6482,15 +6494,15 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::fill_fragment(zr, 0.0); // zero out lo - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::fill_fragment(lo[j][i], 0.0); } } // zero out shared memory SH - for (int64_t j = 0; j < Q; ++j) { - for (int64_t i = lane_id; i < SH; i += NW) { + for (int j = 0; j < Q; ++j) { + for (int i = lane_id; i < SH; i += NW) { ss[j*T + i] = 0.0; } } @@ -6531,8 +6543,8 @@ static __global__ void flash_attn_ext_f16( // load the queries from shared memory into local memory half16x16_a mq[Q16][D16]; - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T); } } @@ -6549,28 +6561,28 @@ static __global__ void flash_attn_ext_f16( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) { + for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) { // Q*K^T { for (int cc = 0; cc < C/16; ++cc) { half16x16_acc mqk[Q16]; - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::fill_fragment(mqk[j], 0); } const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { half16x16_bT mk; // transposed key nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]); } } // mqk = mqk*scale + mask - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { half16x16_a mqka; half16x16_acc mm; if(mp) { @@ -6592,8 +6604,8 @@ static __global__ void flash_attn_ext_f16( // online softmax if (C == 32) { - for (int64_t j = 0; j < Q; ++j) { - const int64_t p = lane_id; + for (int j = 0; j < Q; ++j) { + const int p = lane_id; const half m = M[j]; const half s = ss[j*T + p]; @@ -6615,10 +6627,10 @@ static __global__ void flash_attn_ext_f16( ss[j*T + p] = vs; } } else { - for (int64_t j = 0; j < Q; ++j) { + for (int j = 0; j < Q; ++j) { const half m = M[j]; - for (int64_t p = lane_id; p < C; p += NW) { + for (int p = lane_id; p < C; p += NW) { const half s = ss[j*T + p]; smax = __hmax(smax, s); @@ -6638,7 +6650,7 @@ static __global__ void flash_attn_ext_f16( // local sum half ls = 0.0f; - for (int64_t p = lane_id; p < C; p += NW) { + for (int p = lane_id; p < C; p += NW) { const half s = ss[j*T + p]; const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]); @@ -6659,13 +6671,13 @@ static __global__ void flash_attn_ext_f16( } // O = diag(ms)*O - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { half16x16_a mm; half16x16_b lob; nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { // convert accumulator to matrix_b nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); @@ -6684,17 +6696,17 @@ static __global__ void flash_attn_ext_f16( const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); half16x16_b mk[D16]; - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half)); } half16x16_a mv[Q16]; - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T); } - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]); } } @@ -6703,7 +6715,7 @@ static __global__ void flash_attn_ext_f16( } // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (int64_t j = 0; j < Q; ++j) { + for (int j = 0; j < Q; ++j) { if (lane_id == 0) { ss[j*T + 0] = S[j]; ss[j*T + 1] = M[j]; @@ -6712,7 +6724,7 @@ static __global__ void flash_attn_ext_f16( } // reduce the warps sequentially - for (int64_t sg = 1; sg < num_warps; ++sg) { + for (int sg = 1; sg < num_warps; ++sg) { half S = __float2half(0.0f); half M = __float2half(-INFINITY); @@ -6720,8 +6732,8 @@ static __global__ void flash_attn_ext_f16( // each simdgroup stores its output to shared memory, reusing sq if (warp_id == sg) { - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); } } @@ -6731,7 +6743,7 @@ static __global__ void flash_attn_ext_f16( // the first simdgroup accumulates the results from the other simdgroups if (warp_id == 0) { - for (int64_t j = 0; j < Q; ++j) { + for (int j = 0; j < Q; ++j) { const half S0 = ss[j*T + 0]; const half S1 = ss[j*T + sg*SH + 0]; @@ -6755,7 +6767,7 @@ static __global__ void flash_attn_ext_f16( } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (int64_t j = 0; j < Q16; ++j) { + for (int j = 0; j < Q16; ++j) { half16x16_a ms0; half16x16_a ms1; half16x16_b t; @@ -6764,7 +6776,7 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T); nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); - for (int64_t i = 0; i < D16; ++i) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::fill_fragment(t2, 0.0); nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); nvcuda::wmma::mma_sync(t2, ms1, t, t2); @@ -6781,8 +6793,8 @@ static __global__ void flash_attn_ext_f16( // store result to shared memory (reuse sq) if (warp_id == 0) { - for (int64_t j = 0; j < Q16; ++j) { - for (int64_t i = 0; i < D16; ++i) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); } } @@ -6790,10 +6802,10 @@ static __global__ void flash_attn_ext_f16( // final rescale with 1/S and store to global memory if (warp_id == 0) { - for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { + for (int j = 0; j < Q && iq1 + j < ne01; ++j) { const half S = ss[j*T + 0]; - for (int64_t i = lane_id; i < D; i += NW) { + for (int i = lane_id; i < D; i += NW) { dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); } }