diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 86afb01338f73..e7bf95bd1d616 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -6455,6 +6455,8 @@ static __global__ void flash_attn_ext_f16(
     half  * sq  = (half  *) (__flash_attn_f16_shmem +              0*D); // holds the query data
     half2 * sq2 = (half2 *) (__flash_attn_f16_shmem +              0*D); // same as above but in half2
     half  * ss  = (half  *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix
+
+    half16x16_acc zr;
     half16x16_acc lo[Q16][D16];

     // load heads from Q to shared memory
@@ -6470,6 +6472,8 @@ static __global__ void flash_attn_ext_f16(
         }
     }

+    nvcuda::wmma::fill_fragment(zr, 0.0);
+
     // zero out lo
     for (int64_t j = 0; j < Q16; ++j) {
         for (int64_t i = 0; i < D16; ++i) {
@@ -6487,12 +6491,12 @@ static __global__ void flash_attn_ext_f16(
     __syncthreads();

     {
-        float S[Q];
-        float M[Q];
+        half S[Q];
+        half M[Q];

         for(int i = 0; i < Q; i++) {
-            S[i] = 0.0f;
-            M[i] = -INFINITY;
+            S[i] = __float2half(0.0f);
+            M[i] = __float2half(-INFINITY);
         }

         // assume K and V are same shape
@@ -6526,11 +6530,16 @@ static __global__ void flash_attn_ext_f16(
             }
         }

-        const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
-
         // pointer to the mask
         const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;

+        // prepare diagonal scale matrix
+        half16x16_b mscale;
+        for (int i = 0; i < 16; ++i) {
+            ss[i*T + i] = __float2half(scale);
+        }
+        nvcuda::wmma::load_matrix_sync(mscale, ss, T);
+
         // loop over the KV cache
         // each simdgroup handles blocks of Q rows and C columns
         for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) {
@@ -6555,95 +6564,109 @@ static __global__ void flash_attn_ext_f16(

                     // mqk = mqk*scale + mask
                     for (int64_t j = 0; j < Q16; ++j) {
-                        for (uint32_t i = 0; i < mqk[j].num_elements; i++) {
-                            // TODO: process mask
-                            mqk[j].x[i] = __float2half(scale) * mqk[j].x[i];
-                        }
+                        half16x16_a   mqka;
+                        half16x16_acc mm;
+
+                        // convert accumulator to matrix_a
+                        nvcuda::wmma::store_matrix_sync(      ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
+                        nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T);
+
+                        nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
+                        nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mm);
                         nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
                     }
                 }
             }

             // used to detect blocks full of -INF
-            float smax = -INFINITY;
+            half smax = __float2half(-INFINITY);

             // online softmax
             if (C == 32) {
                 for (int64_t j = 0; j < Q; ++j) {
                     const int64_t p = lane_id;

-                    const float m = M[j];
-                    const float s = __half2float(ss[j*T + p]);
+                    const half m = M[j];
+                    const half s = ss[j*T + p];

-                    smax = warp_reduce_max(max(smax, s));
-                    M[j] = warp_reduce_max(max(M[j], s));
+                    smax = warp_reduce_max(__hmax(smax, s));
+                    M[j] = warp_reduce_max(__hmax(M[j], s));

-                    const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
-                    const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
+                    const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);
+                    const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);

                     S[j] = S[j]*ms + warp_reduce_sum(vs);

                     // create a QxQ diagonal matrix for rescaling the output
                     if (p == j) {
-                        ss[j*T + C + j] = __float2half(ms);
+                        ss[j*T + C + j] = ms;
                     }

                     // the P matrix from the paper (Q rows, C columns)
-                    ss[j*T + p] = __float2half(vs);
+                    ss[j*T + p] = vs;
                 }
             } else {
                 for (int64_t j = 0; j < Q; ++j) {
-                    const float m = M[j];
+                    const half m = M[j];

                     for (int64_t p = lane_id; p < C; p += NW) {
-                        const float s = __half2float(ss[j*T + p]);
+                        const half s = ss[j*T + p];

-                        smax = warp_reduce_max(max(smax, s));
-                        M[j] = warp_reduce_max(max(M[j], s));
+                        smax = __hmax(smax, s);
+                        M[j] = __hmax(M[j], s);
                     }

-                    const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
+                    smax = warp_reduce_max(smax);
+                    M[j] = warp_reduce_max(M[j]);

-                    S[j] = S[j]*ms;
+                    const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);

                     // create a QxQ diagonal matrix for rescaling the output
                     if (lane_id == j) {
-                        ss[j*T + C + j] = __float2half(ms);
+                        ss[j*T + C + j] = ms;
                     }

+                    // local sum
+                    half ls = 0.0f;
+
                     for (int64_t p = lane_id; p < C; p += NW) {
-                        const float s = __half2float(ss[j*T + p]);
+                        const half s = ss[j*T + p];

-                        const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
+                        const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);

-                        S[j] = S[j] + warp_reduce_sum(vs);
+                        ls += vs;

                         // the P matrix from the paper (Q rows, C columns)
-                        ss[j*T + p] = __float2half(vs);
+                        ss[j*T + p] = vs;
                     }
+
+                    S[j] = S[j]*ms + warp_reduce_sum(ls);
                 }
             }

             // skip -INF blocks
-            if (smax == -INFINITY) {
+            if (__hisinf(smax)) {
                 continue;
             }

             // O = diag(ms)*O
             for (int64_t j = 0; j < Q16; ++j) {
-                // half16x16_a mm;
-                // half16x16_b zro;
+                half16x16_a mm;
+                half16x16_b lob;

-                // nvcuda::wmma::fill_fragment(zro, 0.0);
-                // nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
+                nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);

                 for (int64_t i = 0; i < D16; ++i) {
-                    //nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]);
-                    for (uint32_t k = 0; k < 16*16; k++) {
-                        half tmp = ss[(16*j + k%16)*T + C + 16*j + k%16];
-                        lo[j][i].x[k] = tmp * lo[j][i].x[k];
-                    }
+                    // convert accumulator to matrix_b
+                    nvcuda::wmma::store_matrix_sync(     ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
+                    nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);
+
+                    nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
+                    nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]);
                 }
+
+                // restore zeros
+                nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major);
             }

             // O = O + (Q*K^T)*V
@@ -6651,15 +6674,19 @@ static __global__ void flash_attn_ext_f16(
                 for (int cc = 0; cc < C/16; ++cc) {
                     const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));

+                    half16x16_b mk[D16];
                     for (int64_t i = 0; i < D16; ++i) {
-                        half16x16_b mk;
-                        nvcuda::wmma::load_matrix_sync(mk, pv + i*16, nb21/sizeof(half));
+                        nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half));
+                    }

-                        for (int64_t j = 0; j < Q16; ++j) {
-                            half16x16_a mv;
-                            nvcuda::wmma::load_matrix_sync(mv, ss + 16*j*T + 16*cc, T);
+                    half16x16_a mv[Q16];
+                    for (int64_t j = 0; j < Q16; ++j) {
+                        nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T);
+                    }

-                            nvcuda::wmma::mma_sync(lo[j][i], mv, mk, lo[j][i]);
+                    for (int64_t j = 0; j < Q16; ++j) {
+                        for (int64_t i = 0; i < D16; ++i) {
+                            nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]);
                         }
                     }
                 }
@@ -6669,16 +6696,16 @@ static __global__ void flash_attn_ext_f16(

             // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
             for (int64_t j = 0; j < Q; ++j) {
                 if (lane_id == 0) {
-                    ss[j*T + 0] = __float2half(S[j]);
-                    ss[j*T + 1] = __float2half(M[j]);
+                    ss[j*T + 0] = S[j];
+                    ss[j*T + 1] = M[j];
                 }
             }
         }

         // reduce the warps sequentially
         for (int64_t sg = 1; sg < num_warps; ++sg) {
-            float S = 0.0f;
-            float M = -INFINITY;
+            half S = __float2half(0.0f);
+            half M = __float2half(-INFINITY);

             __syncthreads();

@@ -6696,25 +6723,25 @@ static __global__ void flash_attn_ext_f16(
             // the first simdgroup accumulates the results from the other simdgroups
             if (warp_id == 0) {
                 for (int64_t j = 0; j < Q; ++j) {
-                    const float S0 = __half2float(ss[j*T +         0]);
-                    const float S1 = __half2float(ss[j*T + sg*SH + 0]);
+                    const half S0 = ss[j*T +         0];
+                    const half S1 = ss[j*T + sg*SH + 0];

-                    const float M0 = __half2float(ss[j*T +         1]);
-                    const float M1 = __half2float(ss[j*T + sg*SH + 1]);
+                    const half M0 = ss[j*T +         1];
+                    const half M1 = ss[j*T + sg*SH + 1];

-                    M = max(M0, M1);
+                    M = __hmax(M0, M1);

-                    const float ms0 = M0 == -INFINITY ? 0.0f : expf(M0 - M);
-                    const float ms1 = M1 == -INFINITY ? 0.0f : expf(M1 - M);
+                    const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M);
+                    const half ms1 = __hisinf(M1) ? __float2half(0.0f) : hexp(M1 - M);

                     S = S0*ms0 + S1*ms1;

                     if (lane_id == 0) {
-                        ss[j*T + 0] = __float2half(S);
-                        ss[j*T + 1] = __float2half(M);
+                        ss[j*T + 0] = S;
+                        ss[j*T + 1] = M;

-                        ss[j*T + C + j        ] = __float2half(ms0);
-                        ss[j*T + C + j + sg*SH] = __float2half(ms1);
+                        ss[j*T + C + j        ] = ms0;
+                        ss[j*T + C + j + sg*SH] = ms1;
                     }
                 }

@@ -6732,10 +6759,11 @@ static __global__ void flash_attn_ext_f16(
                         nvcuda::wmma::fill_fragment(t2, 0.0);
                         nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
                         nvcuda::wmma::mma_sync(t2, ms1, t, t2);
-                        // store temporally 'lo' data
-                        nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
-                        // load 'lo' data into t
-                        nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
+
+                        // convert accumulator to matrix_b
+                        nvcuda::wmma::store_matrix_sync(   sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
+                        nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T);
+
                         nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
                     }
                 }
@@ -6751,15 +6779,13 @@ static __global__ void flash_attn_ext_f16(
             }
         }

-        // float2 * dst2 = (float2 *) dst;
-
         // final rescale with 1/S and store to global memory
         if (warp_id == 0) {
             for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
-                const float S = __half2float(ss[j*T + 0]);
+                const half S = ss[j*T + 0];

                 for (int64_t i = lane_id; i < D; i += NW) {
-                    dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i]) / S;
+                    dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
                 }
             }
         }
@@ -9618,7 +9644,7 @@ static void ggml_cuda_op_soft_max(

     const int64_t ne00    = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
+    const int64_t nrows_y = src1 ? src0->ne[1] : 1; // note: using number of queries since mask can be padded!

     float scale = 1.0f;
     memcpy(&scale, dst->op_params, sizeof(float));

@@ -10897,8 +10923,8 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
     GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU);
-    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 8) &&
-                "the Flash-Attention CUDA kernel requires the mask to be padded to 8 and at least n_queries big");
+    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
+                "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");

     ggml_cuda_set_device(g_main_device);
     const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
@@ -10912,19 +10938,25 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     float scale;
     memcpy(&scale, KQV->op_params, sizeof(float));

-    const int nqpb = 16; // queries per block
-    const int ncpw = 32; // cache values per warp (does not work for other values)
-
-    // const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4;
-    const int nwarps = 1;
+#define NQPB 16
+#define NCPW 128
+
+    const int nqpb = NQPB; // queries per block
+    const int ncpw = NCPW; // cache values per warp (does not work for other values)
+
+    const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
+    // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
+    const int nwarps = Q->ne[1] <= nqpb ? MAX(2, MIN(K->ne[1]/ncpw, nwarps_max)) : 2;

     dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
     dim3 block_dim(32, nwarps, 1);

-    int shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
+    const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
+
     switch (Q->ne[0]) {
         case 16:
-            flash_attn_ext_f16<16, 16, 32>
+            flash_attn_ext_f16<16, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10941,7 +10973,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 64:
-            flash_attn_ext_f16<64, 16, 32>
+            flash_attn_ext_f16<64, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10958,7 +10990,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 80:
-            flash_attn_ext_f16<80, 16, 32>
+            flash_attn_ext_f16<80, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10975,7 +11007,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 128:
-            flash_attn_ext_f16<128, 16, 32>
+            flash_attn_ext_f16<128, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
diff --git a/ggml.c b/ggml.c
index 59a4c05a12ffe..ebd9c6b341080 100644
--- a/ggml.c
+++ b/ggml.c
@@ -5089,7 +5089,7 @@ static struct ggml_tensor * ggml_soft_max_impl(
         GGML_ASSERT(ggml_is_contiguous(mask));
         GGML_ASSERT(mask->ne[2] == 1);
         GGML_ASSERT(mask->ne[3] == 1);
-        GGML_ASSERT(ggml_can_repeat_rows(mask, a));
+        GGML_ASSERT(mask->ne[1] >= a->ne[1]);
     }

     bool is_node = false;
diff --git a/llama.cpp b/llama.cpp
index fe25839669efc..2330efff57bd3 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -6881,7 +6881,8 @@ static int llama_decode_internal(
         // a heuristic, to avoid attending the full cache if it is not yet utilized
         // after enough generations, the benefit from this heuristic disappears
         // if we start defragmenting the cache, the benefit from this will be more important
-        kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
+        // note: we pad the n_kv because certain GPU kernels require it (e.g. ggml_flash_attn_ext)
+        kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(128, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128)));
         //kv_self.n = llama_kv_cache_cell_max(kv_self);

         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index b1b30b91c9c6b..e23384eee27c2 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -572,9 +572,19 @@ struct test_case {
         // duplicate the op
         size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU
         int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1;
+#if 1
         for (int i = 1; i < n_runs; i++) {
             gf->nodes[gf->n_nodes++] = out;
         }
+#else
+        n_runs = 1000;
+        int n_nodes = gf->n_nodes;
+        for (int i = 1; i < n_runs; i++) {
+            for (int j = 0; j < n_nodes; j++) {
+                gf->nodes[gf->n_nodes++] = gf->nodes[j];
+            }
+        }
+#endif

         // calculate memory
         size_t mem = n_runs * op_size(out);
@@ -2199,8 +2209,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_pad());
     test_cases.emplace_back(new test_leaky_relu());

-#if 0
-    for (int hs : { 64, 80, 96, 112, 128, 256, }) {
+#if 1
+    for (int hs : { 128, 64, 80, }) {
         for (int nh : { 32, }) {
             for (int kv : { 512, 1024, 2048, 4096, }) {
                 for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) {
@@ -2214,7 +2224,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     for (int hs : { 128, }) {
         for (int nh : { 32, }) {
             for (int kv : { 512, 1024, }) {
-                for (int nb : { 1, 2, 4, 8, 512 }) {
+                for (int nb : { 1, 2, 4, 7, 8, 15, 16, 512 }) {
                     test_cases.emplace_back(new test_attn          (hs, nh, kv, nb));
                     test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
                 }
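For reference, the `ms`/`vs` factors in the online softmax section of the kernel follow the standard streaming-softmax recurrence: keep a running maximum `M` and running sum `S`, and whenever a new block of scores arrives, rescale the previously accumulated output by `exp(M_old - M_new)`. Below is a minimal host-side C++ sketch of that recurrence in float precision; the names `OnlineSoftmax` and `add_block` are illustrative only and are not part of this patch:

```cpp
// Streaming (online) softmax over blocks of scores, as used conceptually by
// flash_attn_ext_f16: M is the running max, S the running sum of exponentials,
// and O a toy scalar output accumulator standing in for the V rows.
// Illustrative sketch only, not code from the patch above.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct OnlineSoftmax {
    float M = -INFINITY;
    float S = 0.0f;
    float O = 0.0f;

    void add_block(const std::vector<float> & s, const std::vector<float> & v) {
        float m_new = M;
        for (float x : s) m_new = std::max(m_new, x);

        // rescale the previously accumulated state (the diag(ms) step in the kernel)
        const float ms = (M == -INFINITY) ? 0.0f : std::exp(M - m_new);
        S *= ms;
        O *= ms;

        // accumulate the new block (the P matrix entries)
        for (size_t i = 0; i < s.size(); ++i) {
            const float vs = (s[i] == -INFINITY) ? 0.0f : std::exp(s[i] - m_new);
            S += vs;
            O += vs * v[i];
        }

        M = m_new;
    }

    // final rescale with 1/S, as in the kernel epilogue
    float result() const { return O / S; }
};

int main() {
    OnlineSoftmax acc;
    acc.add_block({0.1f, 0.7f, -INFINITY}, {1.0f, 2.0f, 3.0f});
    acc.add_block({1.5f, 0.2f, 0.3f},      {4.0f, 5.0f, 6.0f});
    printf("%.6f\n", acc.result());
    return 0;
}
```

As a worked example of the launch parameters introduced above: with `Q->ne[0] = 128`, `NQPB = 16`, `NCPW = 128` and `nwarps = 8`, the shared-memory formula evaluates to `16*(128 + 8*(128 + 16))*2 = 40960` bytes (40 KiB) per block.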