From e585cba7fd35fb8acb3ac30bc055384195cd0fbd Mon Sep 17 00:00:00 2001 From: rmatif Date: Thu, 31 Jul 2025 12:26:06 +0000 Subject: [PATCH 1/6] add F16/F16 fa support --- ggml/src/ggml-opencl/CMakeLists.txt | 3 + ggml/src/ggml-opencl/ggml-opencl.cpp | 245 +++++++++++++ .../src/ggml-opencl/kernels/flash_attn_f16.cl | 333 +++++++++++++++++ .../src/ggml-opencl/kernels/flash_attn_f32.cl | 333 +++++++++++++++++ .../ggml-opencl/kernels/flash_attn_f32_f16.cl | 336 ++++++++++++++++++ 5 files changed, 1250 insertions(+) create mode 100644 ggml/src/ggml-opencl/kernels/flash_attn_f16.cl create mode 100644 ggml/src/ggml-opencl/kernels/flash_attn_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 3adea83615437..d6fad9e966d2c 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -109,6 +109,9 @@ set(GGML_OPENCL_KERNELS mul_mat_f16_f32 conv2d conv2d_f16_f32 + flash_attn_f32_f16 + flash_attn_f16 + flash_attn_f32 ) foreach (K ${GGML_OPENCL_KERNELS}) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 984d35a2ecf76..064f140074c74 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -420,6 +421,13 @@ struct ggml_backend_opencl_context { cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8; cl_kernel kernel_soft_max, kernel_soft_max_4; cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16; + std::map, cl_kernel> kernels_flash_attn_f16; + std::map, cl_kernel> kernels_flash_attn_f16_q1; + std::map, cl_kernel> kernels_flash_attn_f32; + std::map, cl_kernel> kernels_flash_attn_f32_q1; + std::map, cl_kernel> kernels_flash_attn_f32_f16; + std::map, cl_kernel> kernels_flash_attn_f32_f16_q1; + std::map, int> kernels_flash_attn_bm; cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0; cl_kernel kernel_set_rows_f32, kernel_set_rows_f16; cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16; @@ -1263,6 +1271,75 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // flash_attn + { + #ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_f16 { + #include "flash_attn_f16.cl.h" + }; + const std::string kernel_src_f32 { + #include "flash_attn_f32.cl.h" + }; + const std::string kernel_src_f32_f16 { + #include "flash_attn_f32_f16.cl.h" + }; + #else + const std::string kernel_src_f16 = read_file("flash_attn_f16.cl"); + const std::string kernel_src_f32 = read_file("flash_attn_f32.cl"); + const std::string kernel_src_f32_f16 = read_file("flash_attn_f32_f16.cl"); + #endif + + if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) { + const struct { int dk; int dv; int bm; int bn; } fa_dims[] = { + { 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32}, + {112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16}, + {192, 192, 16, 16}, {256, 256, 16, 16}, + }; + + for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) { + const int dk = fa_dims[i].dk; + const int dv = fa_dims[i].dv; + const int bm = fa_dims[i].bm; + const int bn = fa_dims[i].bn; + std::string OPTS = compile_opts + + " -D DK=" + std::to_string(dk) + + " -D DV=" + std::to_string(dv) + + " -D BLOCK_M=" + std::to_string(bm) + + " -D BLOCK_N=" + std::to_string(bn); + + cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS); + cl_kernel k_f16, k_f16_q1; + CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err)); + CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err)); + GGML_ASSERT(k_f16 != NULL && k_f16_q1 != NULL); + backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16; + backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1; + CL_CHECK(clReleaseProgram(prog_f16)); + + cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS); + cl_kernel k_f32, k_f32_q1; + CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err)); + CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err)); + GGML_ASSERT(k_f32 != NULL && k_f32_q1 != NULL); + backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32; + backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1; + CL_CHECK(clReleaseProgram(prog_f32)); + + cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS); + cl_kernel k_f32_f16, k_f32_f16_q1; + CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err)); + CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err)); + GGML_ASSERT(k_f32_f16 != NULL && k_f32_f16_q1 != NULL); + backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16; + backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1; + CL_CHECK(clReleaseProgram(prog_f32_f16)); + + backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm; + } + GGML_LOG_CONT("."); + } + } + // argsort { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2553,6 +2630,41 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SUM_ROWS: return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); + case GGML_OP_FLASH_ATTN_EXT: + { + const ggml_tensor * q = op->src[0]; + const ggml_tensor * k = op->src[1]; + const ggml_tensor * v = op->src[2]; + + const int dk = q->ne[0]; + const int dv = v->ne[0]; + + const struct { int dk; int dv; } supported_dims[] = { + { 64, 64}, { 80, 80}, { 96, 96}, + {112, 112}, {128, 128}, {192, 128}, + {192, 192}, {256, 256}, + }; + + bool dims_supported = false; + for (size_t i = 0; i < sizeof(supported_dims)/sizeof(supported_dims[0]); ++i) { + if (supported_dims[i].dk == dk && supported_dims[i].dv == dv) { + dims_supported = true; + break; + } + } + if (!dims_supported) { + return false; + } + + const bool is_f32_f32 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F32 && + v->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + const bool is_f16_f16 = q->type == GGML_TYPE_F16 && k->type == GGML_TYPE_F16 && + v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16; + const bool is_f32_f16 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && + v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F32; + + return is_f32_f32 || is_f16_f16 || is_f32_f16; + } default: return false; } @@ -5193,6 +5305,133 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst); } +static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) { + const ggml_tensor * v = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + GGML_ASSERT(q->extra); + GGML_ASSERT(k->extra); + GGML_ASSERT(v->extra); + GGML_ASSERT(dst->extra); + if (mask) { + GGML_ASSERT(mask->extra); + } + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + const int n_q = q->ne[1]; + const int n_kv = k->ne[1]; + const int d_head_q = q->ne[0]; + const int d_head_v = v->ne[0]; + const int n_head = q->ne[2]; + const int n_head_kv = k->ne[2]; + const int n_batch = q->ne[3]; + + cl_kernel kernel = NULL; + + const bool is_f16 = q->type == GGML_TYPE_F16; + const bool is_mixed = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16; + const std::pair dk_dv = {d_head_q, d_head_v}; + + if (n_q == 1) { + if (is_mixed) { + kernel = backend_ctx->kernels_flash_attn_f32_f16_q1.at(dk_dv); + } else if (is_f16) { + kernel = backend_ctx->kernels_flash_attn_f16_q1.at(dk_dv); + } else { + kernel = backend_ctx->kernels_flash_attn_f32_q1.at(dk_dv); + } + } else { + if (is_mixed) { + kernel = backend_ctx->kernels_flash_attn_f32_f16.at(dk_dv); + } else if (is_f16) { + kernel = backend_ctx->kernels_flash_attn_f16.at(dk_dv); + } else { + kernel = backend_ctx->kernels_flash_attn_f32.at(dk_dv); + } + } + GGML_ASSERT(kernel != NULL); + + ggml_tensor_extra_cl * extra_q = (ggml_tensor_extra_cl *)q->extra; + ggml_tensor_extra_cl * extra_k = (ggml_tensor_extra_cl *)k->extra; + ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra; + ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL; + + cl_ulong offset_q = extra_q->offset + q->view_offs; + cl_ulong offset_k = extra_k->offset + k->view_offs; + cl_ulong offset_v = extra_v->offset + v->view_offs; + cl_ulong offset_o = extra_o->offset + dst->view_offs; + cl_mem mask_buffer = extra_mask ? extra_mask->data_device : NULL; + cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0; + + const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3]; + const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3]; + const cl_ulong v_nb1 = v->nb[1], v_nb2 = v->nb[2], v_nb3 = v->nb[3]; + const cl_ulong o_nb1 = dst->nb[1], o_nb2 = dst->nb[2], o_nb3 = dst->nb[3]; + const cl_ulong mask_nb1 = mask ? mask->nb[1] : 0; + const cl_ulong mask_nb2 = mask ? mask->nb[2] : 0; + const cl_ulong mask_nb3 = mask ? mask->nb[3] : 0; + const int mask_ne2 = mask ? mask->ne[2] : 0; + const int mask_ne3 = mask ? mask->ne[3] : 0; + + float scale, max_bias, logit_softcap; + const float * params = (const float *)dst->op_params; + scale = params[0]; + max_bias = params[1]; + logit_softcap = params[2]; + + const int is_causal = (mask == NULL && n_q > 1 && n_q == n_kv); + + const int n_head_log2_val = n_head > 0 ? 1u << (int)floorf(log2f((float)n_head)) : 0; + const float n_head_log2_f = n_head_log2_val > 0 ? (float)n_head_log2_val : 1.0f; + const float m0 = powf(2.0f, -(max_bias) / n_head_log2_f); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2_f); + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_q->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset_q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_k->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset_k)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra_v->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset_v)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extra_o->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offset_o)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(float), &scale)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &n_q)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &n_kv)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &is_causal)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &n_head)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &q_nb1)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &q_nb2)); CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &q_nb3)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &k_nb1)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &k_nb2)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &k_nb3)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &v_nb1)); CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &v_nb2)); CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &v_nb3)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &o_nb1)); CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &o_nb2)); CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &o_nb3)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(float), &max_bias)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(float), &m0)); + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(float), &m1)); + CL_CHECK(clSetKernelArg(kernel, 28, sizeof(int), &n_head_log2_val)); + CL_CHECK(clSetKernelArg(kernel, 29, sizeof(float), &logit_softcap)); + CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int), &n_head_kv)); + CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_mem), &mask_buffer)); + CL_CHECK(clSetKernelArg(kernel, 32, sizeof(cl_ulong), &offset_mask)); + CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_ulong), &mask_nb1)); + CL_CHECK(clSetKernelArg(kernel, 34, sizeof(cl_ulong), &mask_nb2)); + CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3)); + CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int), &mask_ne2)); + CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int), &mask_ne3)); + + if (n_q == 1) { + const size_t wg_size = 64; + size_t local_work_size[] = { wg_size, 1 }; + size_t global_work_size[] = { wg_size, (size_t)(n_head * n_batch) }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); + } else { + const int block_m = backend_ctx->kernels_flash_attn_bm.at(dk_dv); + const size_t wg_size = block_m; + size_t local_work_size[] = { wg_size, 1 }; + size_t global_work_size[] = { (size_t)((n_q + block_m - 1) / block_m) * wg_size, (size_t)(n_head * n_batch) }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); + } +} + static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -7239,6 +7478,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_sum_rows; break; + case GGML_OP_FLASH_ATTN_EXT: + if (!any_on_device) { + return false; + } + ggml_cl_flash_attn(backend, tensor->src[0], tensor->src[1], tensor); + return true; default: return false; } diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl new file mode 100644 index 0000000000000..1cec0c5004537 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl @@ -0,0 +1,333 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define ACC_TYPE float +#define ACC_TYPE4 float4 +#define DATA_TYPE half +#define DATA_TYPE4 half4 +#define CONVERT_ACC4(x) convert_float4(x) +#define CONVERT_DATA4(x) convert_half4(x) + +#define DK_VEC (DK/4) +#define DV_VEC (DV/4) +#define WG_SIZE (BLOCK_M) +#define Q1_WG_SIZE 64 + +__kernel void flash_attn_f16( + const global void * q_void, ulong q_offset, + const global void * k_void, ulong k_offset, + const global void * v_void, ulong v_offset, + global void * o_void, ulong o_offset, + const float scale, + const int n_q, + const int n_kv, + const int is_causal, + const int n_head, + const ulong q_nb1, const ulong q_nb2, const ulong q_nb3, + const ulong k_nb1, const ulong k_nb2, const ulong k_nb3, + const ulong v_nb1, const ulong v_nb2, const ulong v_nb3, + const ulong o_nb1, const ulong o_nb2, const ulong o_nb3, + const float max_bias, + const float m0, + const float m1, + const int n_head_log2, + const float logit_softcap, + const int n_head_kv, + const global void* mask_void, + const ulong mask_offset, + const ulong mask_nb1, + const ulong mask_nb2, + const ulong mask_nb3, + const int mask_ne2, + const int mask_ne3 +) { + const int tid = get_local_id(0); + const int block_q_idx = get_group_id(0); + const int head_batch_idx = get_global_id(1); + + const int my_query_row = block_q_idx * BLOCK_M + tid; + + const int batch_idx = head_batch_idx / n_head; + const int head_idx = head_batch_idx % n_head; + + const int gqa_ratio = n_head / n_head_kv; + const int head_kv_idx = head_idx / gqa_ratio; + + const global char* q_base = (const global char*)q_void + q_offset; + const global char* k_base = (const global char*)k_void + k_offset; + const global char* v_base = (const global char*)v_void + v_offset; + global char* o_base = (global char*)o_void + o_offset; + + const global char* mask_base = NULL; + if (mask_void != NULL) { + const int mask_head_idx = head_idx % mask_ne2; + const int mask_batch_idx = batch_idx % mask_ne3; + mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2; + } + + ACC_TYPE4 q_priv[DK_VEC]; + if (my_query_row < n_q) { + const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1; + const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset); + for (int i = 0; i < DK_VEC; ++i) { + q_priv[i] = CONVERT_ACC4(q_ptr[i]); + } + } + + ACC_TYPE4 o_acc[DV_VEC]; + for (int i = 0; i < DV_VEC; ++i) { + o_acc[i] = (ACC_TYPE4)(0.0f); + } + ACC_TYPE m_i = -INFINITY; + ACC_TYPE l_i = 0.0f; + + float slope = 1.0f; + if (max_bias > 0.0f) { + int h = head_idx; + if (h < n_head_log2) { + slope = pow(m0, (float)(h + 1)); + } else { + slope = pow(m1, (float)(2 * (h - n_head_log2) + 1)); + } + } + + __local DATA_TYPE4 l_k[BLOCK_N][DK_VEC]; + __local DATA_TYPE4 l_v[BLOCK_N][DV_VEC]; + + for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) { + for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) { + const int row = i / DK_VEC; + const int col = i % DK_VEC; + const int k_row_idx = k_start + row; + if (k_row_idx < n_kv) { + const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1; + l_k[row][col] = ((__global DATA_TYPE4*)(k_base + k_row_offset))[col]; + } + } + for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) { + const int row = i / DV_VEC; + const int col = i % DV_VEC; + const int v_row_idx = k_start + row; + if (v_row_idx < n_kv) { + const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1; + l_v[row][col] = ((__global DATA_TYPE4*)(v_base + v_row_offset))[col]; + } + } + barrier(CLK_LOCAL_MEM_FENCE); + + if (my_query_row >= n_q) { + continue; + } + + for (int j = 0; j < BLOCK_N; j += 2) { + const int k_row0 = k_start + j; + const int k_row1 = k_start + j + 1; + + ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f); + ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f); + for (int k = 0; k < DK_VEC; k++) { + dot_acc0 = fma(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0); + dot_acc1 = fma(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1); + } + ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale; + ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale; + + if (is_causal) { + if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY; + if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY; + } + + if (k_row0 >= n_kv) score0 = -INFINITY; + if (k_row1 >= n_kv) score1 = -INFINITY; + + if (mask_base != NULL) { + const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1); + if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0]; + if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1]; + } + + if (logit_softcap > 0.0f) { + score0 = logit_softcap * tanh(score0 / logit_softcap); + score1 = logit_softcap * tanh(score1 / logit_softcap); + } + + const ACC_TYPE m_new = max(m_i, max(score0, score1)); + const ACC_TYPE p0 = exp(score0 - m_new); + const ACC_TYPE p1 = exp(score1 - m_new); + const ACC_TYPE scale_prev = exp(m_i - m_new); + + for (int i = 0; i < DV_VEC; ++i) { + o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]); + } + l_i = l_i * scale_prev + p0 + p1; + m_i = m_new; + } + } + + if (my_query_row < n_q) { + const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1; + global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset); + if (l_i > 0.0f) { + const ACC_TYPE l_inv = 1.0f / l_i; + for (int i = 0; i < DV_VEC; ++i) { + o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv); + } + } else { + for (int i = 0; i < DV_VEC; ++i) { + o_row[i] = (DATA_TYPE4)(0.0f); + } + } + } +} + +__kernel void flash_attn_f16_q1( + const global void * q_void, ulong q_offset, + const global void * k_void, ulong k_offset, + const global void * v_void, ulong v_offset, + global void * o_void, ulong o_offset, + const float scale, + const int n_q, + const int n_kv, + const int is_causal, + const int n_head, + const ulong q_nb1, const ulong q_nb2, const ulong q_nb3, + const ulong k_nb1, const ulong k_nb2, const ulong k_nb3, + const ulong v_nb1, const ulong v_nb2, const ulong v_nb3, + const ulong o_nb1, const ulong o_nb2, const ulong o_nb3, + const float max_bias, + const float m0, + const float m1, + const int n_head_log2, + const float logit_softcap, + const int n_head_kv, + const global void* mask_void, + const ulong mask_offset, + const ulong mask_nb1, + const ulong mask_nb2, + const ulong mask_nb3, + const int mask_ne2, + const int mask_ne3 +) { + const int tid = get_local_id(0); + const int head_batch_idx = get_global_id(1); + + const int batch_idx = head_batch_idx / n_head; + const int head_idx = head_batch_idx % n_head; + + const int gqa_ratio = n_head / n_head_kv; + const int head_kv_idx = head_idx / gqa_ratio; + + const global char* q_base = (const global char*)q_void + q_offset; + const global char* k_base = (const global char*)k_void + k_offset; + const global char* v_base = (const global char*)v_void + v_offset; + global char* o_base = (global char*)o_void + o_offset; + + const global char* mask_base = NULL; + if (mask_void != NULL) { + const int mask_head_idx = head_idx % mask_ne2; + const int mask_batch_idx = batch_idx % mask_ne3; + mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2; + } + + ACC_TYPE4 q_priv[DK_VEC]; + const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2; + const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset); + for (int i = 0; i < DK_VEC; ++i) { + q_priv[i] = CONVERT_ACC4(q_ptr[i]); + } + + float slope = 1.0f; + if (max_bias > 0.0f) { + int h = head_idx; + if (h < n_head_log2) { + slope = pow(m0, (float)(h + 1)); + } else { + slope = pow(m1, (float)(2 * (h - n_head_log2) + 1)); + } + } + + ACC_TYPE m_i = -INFINITY; + for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) { + const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; + const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset); + ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); + for (int k = 0; k < DK_VEC; k++) { + dot_acc = fma(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc); + } + ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale; + if (mask_base != NULL) { + const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base); + score += slope * (ACC_TYPE)mask_ptr[k_idx]; + } + if (logit_softcap > 0.0f) { + score = logit_softcap * tanh(score / logit_softcap); + } + m_i = max(m_i, score); + } + + __local ACC_TYPE local_m[Q1_WG_SIZE]; + local_m[tid] = m_i; + barrier(CLK_LOCAL_MEM_FENCE); + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]); + barrier(CLK_LOCAL_MEM_FENCE); + } + const ACC_TYPE m_final = local_m[0]; + + ACC_TYPE4 o_acc[DV_VEC]; + for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f); + ACC_TYPE l_i = 0.0f; + + for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) { + const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; + const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1; + const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset); + const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset); + ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); + for (int k = 0; k < DK_VEC; k++) { + dot_acc = fma(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc); + } + ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale; + if (mask_base != NULL) { + const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base); + score += slope * (ACC_TYPE)mask_ptr[k_idx]; + } + if (logit_softcap > 0.0f) { + score = logit_softcap * tanh(score / logit_softcap); + } + const ACC_TYPE p = exp(score - m_final); + l_i += p; + for (int i = 0; i < DV_VEC; i++) { + o_acc[i] = fma(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]); + } + } + + __local ACC_TYPE local_l[Q1_WG_SIZE]; + __local ACC_TYPE4 local_o_comp[Q1_WG_SIZE]; + local_l[tid] = l_i; + barrier(CLK_LOCAL_MEM_FENCE); + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_l[tid] += local_l[tid + s]; + barrier(CLK_LOCAL_MEM_FENCE); + } + + const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1; + global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset); + const ACC_TYPE l_final = local_l[0]; + + if (l_final > 0.0f) { + const ACC_TYPE l_inv = 1.0f / l_final; + for (int i = 0; i < DV_VEC; i++) { + local_o_comp[tid] = o_acc[i]; + barrier(CLK_LOCAL_MEM_FENCE); + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_o_comp[tid] += local_o_comp[tid + s]; + barrier(CLK_LOCAL_MEM_FENCE); + } + if (tid == 0) { + o_row[i] = CONVERT_DATA4(local_o_comp[0] * l_inv); + } + } + } else if (tid == 0) { + for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f); + } +} diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl new file mode 100644 index 0000000000000..ad6067a6e94ae --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl @@ -0,0 +1,333 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define ACC_TYPE float +#define ACC_TYPE4 float4 +#define DATA_TYPE float +#define DATA_TYPE4 float4 +#define CONVERT_ACC4(x) (x) +#define CONVERT_DATA4(x) (x) + +#define DK_VEC (DK/4) +#define DV_VEC (DV/4) +#define WG_SIZE (BLOCK_M) +#define Q1_WG_SIZE 64 + +__kernel void flash_attn_f32( + const global void * q_void, ulong q_offset, + const global void * k_void, ulong k_offset, + const global void * v_void, ulong v_offset, + global void * o_void, ulong o_offset, + const float scale, + const int n_q, + const int n_kv, + const int is_causal, + const int n_head, + const ulong q_nb1, const ulong q_nb2, const ulong q_nb3, + const ulong k_nb1, const ulong k_nb2, const ulong k_nb3, + const ulong v_nb1, const ulong v_nb2, const ulong v_nb3, + const ulong o_nb1, const ulong o_nb2, const ulong o_nb3, + const float max_bias, + const float m0, + const float m1, + const int n_head_log2, + const float logit_softcap, + const int n_head_kv, + const global void* mask_void, + const ulong mask_offset, + const ulong mask_nb1, + const ulong mask_nb2, + const ulong mask_nb3, + const int mask_ne2, + const int mask_ne3 +) { + const int tid = get_local_id(0); + const int block_q_idx = get_group_id(0); + const int head_batch_idx = get_global_id(1); + + const int my_query_row = block_q_idx * BLOCK_M + tid; + + const int batch_idx = head_batch_idx / n_head; + const int head_idx = head_batch_idx % n_head; + + const int gqa_ratio = n_head / n_head_kv; + const int head_kv_idx = head_idx / gqa_ratio; + + const global char* q_base = (const global char*)q_void + q_offset; + const global char* k_base = (const global char*)k_void + k_offset; + const global char* v_base = (const global char*)v_void + v_offset; + global char* o_base = (global char*)o_void + o_offset; + + const global char* mask_base = NULL; + if (mask_void != NULL) { + const int mask_head_idx = head_idx % mask_ne2; + const int mask_batch_idx = batch_idx % mask_ne3; + mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2; + } + + ACC_TYPE4 q_priv[DK_VEC]; + if (my_query_row < n_q) { + const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1; + const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset); + for (int i = 0; i < DK_VEC; ++i) { + q_priv[i] = CONVERT_ACC4(q_ptr[i]); + } + } + + ACC_TYPE4 o_acc[DV_VEC]; + for (int i = 0; i < DV_VEC; ++i) { + o_acc[i] = (ACC_TYPE4)(0.0f); + } + ACC_TYPE m_i = -INFINITY; + ACC_TYPE l_i = 0.0f; + + float slope = 1.0f; + if (max_bias > 0.0f) { + int h = head_idx; + if (h < n_head_log2) { + slope = pow(m0, (float)(h + 1)); + } else { + slope = pow(m1, (float)(2 * (h - n_head_log2) + 1)); + } + } + + __local DATA_TYPE4 l_k[BLOCK_N][DK_VEC]; + __local DATA_TYPE4 l_v[BLOCK_N][DV_VEC]; + + for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) { + for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) { + const int row = i / DK_VEC; + const int col = i % DK_VEC; + const int k_row_idx = k_start + row; + if (k_row_idx < n_kv) { + const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1; + l_k[row][col] = ((__global DATA_TYPE4*)(k_base + k_row_offset))[col]; + } + } + for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) { + const int row = i / DV_VEC; + const int col = i % DV_VEC; + const int v_row_idx = k_start + row; + if (v_row_idx < n_kv) { + const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1; + l_v[row][col] = ((__global DATA_TYPE4*)(v_base + v_row_offset))[col]; + } + } + barrier(CLK_LOCAL_MEM_FENCE); + + if (my_query_row >= n_q) { + continue; + } + + for (int j = 0; j < BLOCK_N; j += 2) { + const int k_row0 = k_start + j; + const int k_row1 = k_start + j + 1; + + ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f); + ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f); + for (int k = 0; k < DK_VEC; k++) { + dot_acc0 = fma(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0); + dot_acc1 = fma(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1); + } + ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale; + ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale; + + if (is_causal) { + if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY; + if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY; + } + + if (k_row0 >= n_kv) score0 = -INFINITY; + if (k_row1 >= n_kv) score1 = -INFINITY; + + if (mask_base != NULL) { + const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1); + if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0]; + if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1]; + } + + if (logit_softcap > 0.0f) { + score0 = logit_softcap * tanh(score0 / logit_softcap); + score1 = logit_softcap * tanh(score1 / logit_softcap); + } + + const ACC_TYPE m_new = max(m_i, max(score0, score1)); + const ACC_TYPE p0 = exp(score0 - m_new); + const ACC_TYPE p1 = exp(score1 - m_new); + const ACC_TYPE scale_prev = exp(m_i - m_new); + + for (int i = 0; i < DV_VEC; ++i) { + o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]); + } + l_i = l_i * scale_prev + p0 + p1; + m_i = m_new; + } + } + + if (my_query_row < n_q) { + const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1; + global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset); + if (l_i > 0.0f) { + const ACC_TYPE l_inv = 1.0f / l_i; + for (int i = 0; i < DV_VEC; ++i) { + o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv); + } + } else { + for (int i = 0; i < DV_VEC; ++i) { + o_row[i] = (DATA_TYPE4)(0.0f); + } + } + } +} + +__kernel void flash_attn_f32_q1( + const global void * q_void, ulong q_offset, + const global void * k_void, ulong k_offset, + const global void * v_void, ulong v_offset, + global void * o_void, ulong o_offset, + const float scale, + const int n_q, + const int n_kv, + const int is_causal, + const int n_head, + const ulong q_nb1, const ulong q_nb2, const ulong q_nb3, + const ulong k_nb1, const ulong k_nb2, const ulong k_nb3, + const ulong v_nb1, const ulong v_nb2, const ulong v_nb3, + const ulong o_nb1, const ulong o_nb2, const ulong o_nb3, + const float max_bias, + const float m0, + const float m1, + const int n_head_log2, + const float logit_softcap, + const int n_head_kv, + const global void* mask_void, + const ulong mask_offset, + const ulong mask_nb1, + const ulong mask_nb2, + const ulong mask_nb3, + const int mask_ne2, + const int mask_ne3 +) { + const int tid = get_local_id(0); + const int head_batch_idx = get_global_id(1); + + const int batch_idx = head_batch_idx / n_head; + const int head_idx = head_batch_idx % n_head; + + const int gqa_ratio = n_head / n_head_kv; + const int head_kv_idx = head_idx / gqa_ratio; + + const global char* q_base = (const global char*)q_void + q_offset; + const global char* k_base = (const global char*)k_void + k_offset; + const global char* v_base = (const global char*)v_void + v_offset; + global char* o_base = (global char*)o_void + o_offset; + + const global char* mask_base = NULL; + if (mask_void != NULL) { + const int mask_head_idx = head_idx % mask_ne2; + const int mask_batch_idx = batch_idx % mask_ne3; + mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2; + } + + ACC_TYPE4 q_priv[DK_VEC]; + const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2; + const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset); + for (int i = 0; i < DK_VEC; ++i) { + q_priv[i] = CONVERT_ACC4(q_ptr[i]); + } + + float slope = 1.0f; + if (max_bias > 0.0f) { + int h = head_idx; + if (h < n_head_log2) { + slope = pow(m0, (float)(h + 1)); + } else { + slope = pow(m1, (float)(2 * (h - n_head_log2) + 1)); + } + } + + ACC_TYPE m_i = -INFINITY; + for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) { + const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; + const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset); + ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); + for (int k = 0; k < DK_VEC; k++) { + dot_acc = fma(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc); + } + ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale; + if (mask_base != NULL) { + const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base); + score += slope * (ACC_TYPE)mask_ptr[k_idx]; + } + if (logit_softcap > 0.0f) { + score = logit_softcap * tanh(score / logit_softcap); + } + m_i = max(m_i, score); + } + + __local ACC_TYPE local_m[Q1_WG_SIZE]; + local_m[tid] = m_i; + barrier(CLK_LOCAL_MEM_FENCE); + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]); + barrier(CLK_LOCAL_MEM_FENCE); + } + const ACC_TYPE m_final = local_m[0]; + + ACC_TYPE4 o_acc[DV_VEC]; + for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f); + ACC_TYPE l_i = 0.0f; + + for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) { + const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; + const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1; + const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset); + const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset); + ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); + for (int k = 0; k < DK_VEC; k++) { + dot_acc = fma(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc); + } + ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale; + if (mask_base != NULL) { + const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base); + score += slope * (ACC_TYPE)mask_ptr[k_idx]; + } + if (logit_softcap > 0.0f) { + score = logit_softcap * tanh(score / logit_softcap); + } + const ACC_TYPE p = exp(score - m_final); + l_i += p; + for (int i = 0; i < DV_VEC; i++) { + o_acc[i] = fma(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]); + } + } + + __local ACC_TYPE local_l[Q1_WG_SIZE]; + __local ACC_TYPE4 local_o_comp[Q1_WG_SIZE]; + local_l[tid] = l_i; + barrier(CLK_LOCAL_MEM_FENCE); + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_l[tid] += local_l[tid + s]; + barrier(CLK_LOCAL_MEM_FENCE); + } + + const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1; + global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset); + const ACC_TYPE l_final = local_l[0]; + + if (l_final > 0.0f) { + const ACC_TYPE l_inv = 1.0f / l_final; + for (int i = 0; i < DV_VEC; i++) { + local_o_comp[tid] = o_acc[i]; + barrier(CLK_LOCAL_MEM_FENCE); + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_o_comp[tid] += local_o_comp[tid + s]; + barrier(CLK_LOCAL_MEM_FENCE); + } + if (tid == 0) { + o_row[i] = CONVERT_DATA4(local_o_comp[0] * l_inv); + } + } + } else if (tid == 0) { + for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f); + } +} diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl new file mode 100644 index 0000000000000..dd68e1fc56e80 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl @@ -0,0 +1,336 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define ACC_TYPE float +#define ACC_TYPE4 float4 +#define Q_DATA_TYPE4 float4 +#define KV_DATA_TYPE4 half4 +#define O_DATA_TYPE4 float4 +#define MASK_DATA_TYPE half +#define CONVERT_Q_ACC4(x) (x) +#define CONVERT_KV_ACC4(x) convert_float4(x) +#define CONVERT_O_DATA4(x) (x) + +#define DK_VEC (DK/4) +#define DV_VEC (DV/4) +#define WG_SIZE (BLOCK_M) +#define Q1_WG_SIZE 64 + +__kernel void flash_attn_f32_f16( + const global void * q_void, ulong q_offset, + const global void * k_void, ulong k_offset, + const global void * v_void, ulong v_offset, + global void * o_void, ulong o_offset, + const float scale, + const int n_q, + const int n_kv, + const int is_causal, + const int n_head, + const ulong q_nb1, const ulong q_nb2, const ulong q_nb3, + const ulong k_nb1, const ulong k_nb2, const ulong k_nb3, + const ulong v_nb1, const ulong v_nb2, const ulong v_nb3, + const ulong o_nb1, const ulong o_nb2, const ulong o_nb3, + const float max_bias, + const float m0, + const float m1, + const int n_head_log2, + const float logit_softcap, + const int n_head_kv, + const global void* mask_void, + const ulong mask_offset, + const ulong mask_nb1, + const ulong mask_nb2, + const ulong mask_nb3, + const int mask_ne2, + const int mask_ne3 +) { + const int tid = get_local_id(0); + const int block_q_idx = get_group_id(0); + const int head_batch_idx = get_global_id(1); + + const int my_query_row = block_q_idx * BLOCK_M + tid; + + const int batch_idx = head_batch_idx / n_head; + const int head_idx = head_batch_idx % n_head; + + const int gqa_ratio = n_head / n_head_kv; + const int head_kv_idx = head_idx / gqa_ratio; + + const global char* q_base = (const global char*)q_void + q_offset; + const global char* k_base = (const global char*)k_void + k_offset; + const global char* v_base = (const global char*)v_void + v_offset; + global char* o_base = (global char*)o_void + o_offset; + + const global char* mask_base = NULL; + if (mask_void != NULL) { + const int mask_head_idx = head_idx % mask_ne2; + const int mask_batch_idx = batch_idx % mask_ne3; + mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2; + } + + ACC_TYPE4 q_priv[DK_VEC]; + if (my_query_row < n_q) { + const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1; + const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset); + for (int i = 0; i < DK_VEC; ++i) { + q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]); + } + } + + ACC_TYPE4 o_acc[DV_VEC]; + for (int i = 0; i < DV_VEC; ++i) { + o_acc[i] = (ACC_TYPE4)(0.0f); + } + ACC_TYPE m_i = -INFINITY; + ACC_TYPE l_i = 0.0f; + + float slope = 1.0f; + if (max_bias > 0.0f) { + int h = head_idx; + if (h < n_head_log2) { + slope = pow(m0, (float)(h + 1)); + } else { + slope = pow(m1, (float)(2 * (h - n_head_log2) + 1)); + } + } + + __local KV_DATA_TYPE4 l_k[BLOCK_N][DK_VEC]; + __local KV_DATA_TYPE4 l_v[BLOCK_N][DV_VEC]; + + for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) { + for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) { + const int row = i / DK_VEC; + const int col = i % DK_VEC; + const int k_row_idx = k_start + row; + if (k_row_idx < n_kv) { + const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1; + l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_base + k_row_offset))[col]; + } + } + for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) { + const int row = i / DV_VEC; + const int col = i % DV_VEC; + const int v_row_idx = k_start + row; + if (v_row_idx < n_kv) { + const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1; + l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_base + v_row_offset))[col]; + } + } + barrier(CLK_LOCAL_MEM_FENCE); + + if (my_query_row >= n_q) { + continue; + } + + for (int j = 0; j < BLOCK_N; j += 2) { + const int k_row0 = k_start + j; + const int k_row1 = k_start + j + 1; + + ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f); + ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f); + for (int k = 0; k < DK_VEC; k++) { + dot_acc0 = fma(q_priv[k], CONVERT_KV_ACC4(l_k[j][k]), dot_acc0); + dot_acc1 = fma(q_priv[k], CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1); + } + ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale; + ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale; + + if (is_causal) { + if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY; + if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY; + } + + if (k_row0 >= n_kv) score0 = -INFINITY; + if (k_row1 >= n_kv) score1 = -INFINITY; + + if (mask_base != NULL) { + const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1); + if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0]; + if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1]; + } + + if (logit_softcap > 0.0f) { + score0 = logit_softcap * tanh(score0 / logit_softcap); + score1 = logit_softcap * tanh(score1 / logit_softcap); + } + + const ACC_TYPE m_new = max(m_i, max(score0, score1)); + const ACC_TYPE p0 = exp(score0 - m_new); + const ACC_TYPE p1 = exp(score1 - m_new); + const ACC_TYPE scale_prev = exp(m_i - m_new); + + for (int i = 0; i < DV_VEC; ++i) { + o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_KV_ACC4(l_v[j][i]) + p1 * CONVERT_KV_ACC4(l_v[j+1][i]); + } + l_i = l_i * scale_prev + p0 + p1; + m_i = m_new; + } + } + + if (my_query_row < n_q) { + const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1; + global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset); + if (l_i > 0.0f) { + const ACC_TYPE l_inv = 1.0f / l_i; + for (int i = 0; i < DV_VEC; ++i) { + o_row[i] = CONVERT_O_DATA4(o_acc[i] * l_inv); + } + } else { + for (int i = 0; i < DV_VEC; ++i) { + o_row[i] = (O_DATA_TYPE4)(0.0f); + } + } + } +} + +__kernel void flash_attn_f32_f16_q1( + const global void * q_void, ulong q_offset, + const global void * k_void, ulong k_offset, + const global void * v_void, ulong v_offset, + global void * o_void, ulong o_offset, + const float scale, + const int n_q, + const int n_kv, + const int is_causal, + const int n_head, + const ulong q_nb1, const ulong q_nb2, const ulong q_nb3, + const ulong k_nb1, const ulong k_nb2, const ulong k_nb3, + const ulong v_nb1, const ulong v_nb2, const ulong v_nb3, + const ulong o_nb1, const ulong o_nb2, const ulong o_nb3, + const float max_bias, + const float m0, + const float m1, + const int n_head_log2, + const float logit_softcap, + const int n_head_kv, + const global void* mask_void, + const ulong mask_offset, + const ulong mask_nb1, + const ulong mask_nb2, + const ulong mask_nb3, + const int mask_ne2, + const int mask_ne3 +) { + const int tid = get_local_id(0); + const int head_batch_idx = get_global_id(1); + + const int batch_idx = head_batch_idx / n_head; + const int head_idx = head_batch_idx % n_head; + + const int gqa_ratio = n_head / n_head_kv; + const int head_kv_idx = head_idx / gqa_ratio; + + const global char* q_base = (const global char*)q_void + q_offset; + const global char* k_base = (const global char*)k_void + k_offset; + const global char* v_base = (const global char*)v_void + v_offset; + global char* o_base = (global char*)o_void + o_offset; + + const global char* mask_base = NULL; + if (mask_void != NULL) { + const int mask_head_idx = head_idx % mask_ne2; + const int mask_batch_idx = batch_idx % mask_ne3; + mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2; + } + + ACC_TYPE4 q_priv[DK_VEC]; + const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2; + const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset); + for (int i = 0; i < DK_VEC; ++i) { + q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]); + } + + float slope = 1.0f; + if (max_bias > 0.0f) { + int h = head_idx; + if (h < n_head_log2) { + slope = pow(m0, (float)(h + 1)); + } else { + slope = pow(m1, (float)(2 * (h - n_head_log2) + 1)); + } + } + + ACC_TYPE m_i = -INFINITY; + for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) { + const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; + const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset); + ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); + for (int k = 0; k < DK_VEC; k++) { + dot_acc = fma(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc); + } + ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale; + if (mask_base != NULL) { + const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base); + score += slope * (ACC_TYPE)mask_ptr[k_idx]; + } + if (logit_softcap > 0.0f) { + score = logit_softcap * tanh(score / logit_softcap); + } + m_i = max(m_i, score); + } + + __local ACC_TYPE local_m[Q1_WG_SIZE]; + local_m[tid] = m_i; + barrier(CLK_LOCAL_MEM_FENCE); + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]); + barrier(CLK_LOCAL_MEM_FENCE); + } + const ACC_TYPE m_final = local_m[0]; + + ACC_TYPE4 o_acc[DV_VEC]; + for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f); + ACC_TYPE l_i = 0.0f; + + for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) { + const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; + const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1; + const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset); + const global KV_DATA_TYPE4* v_ptr = (const global KV_DATA_TYPE4*)(v_base + v_row_offset); + ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); + for (int k = 0; k < DK_VEC; k++) { + dot_acc = fma(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc); + } + ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale; + if (mask_base != NULL) { + const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base); + score += slope * (ACC_TYPE)mask_ptr[k_idx]; + } + if (logit_softcap > 0.0f) { + score = logit_softcap * tanh(score / logit_softcap); + } + const ACC_TYPE p = exp(score - m_final); + l_i += p; + for (int i = 0; i < DV_VEC; i++) { + o_acc[i] = fma(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]); + } + } + + __local ACC_TYPE local_l[Q1_WG_SIZE]; + __local ACC_TYPE4 local_o_comp[Q1_WG_SIZE]; + local_l[tid] = l_i; + barrier(CLK_LOCAL_MEM_FENCE); + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_l[tid] += local_l[tid + s]; + barrier(CLK_LOCAL_MEM_FENCE); + } + + const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1; + global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset); + const ACC_TYPE l_final = local_l[0]; + + if (l_final > 0.0f) { + const ACC_TYPE l_inv = 1.0f / l_final; + for (int i = 0; i < DV_VEC; i++) { + local_o_comp[tid] = o_acc[i]; + barrier(CLK_LOCAL_MEM_FENCE); + for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) local_o_comp[tid] += local_o_comp[tid + s]; + barrier(CLK_LOCAL_MEM_FENCE); + } + if (tid == 0) { + o_row[i] = CONVERT_O_DATA4(local_o_comp[0] * l_inv); + } + } + } else if (tid == 0) { + for (int i = 0; i < DV_VEC; ++i) o_row[i] = (O_DATA_TYPE4)(0.0f); + } +} From d3f049bbe838bcac226d264a75fbbe1b6abc87c7 Mon Sep 17 00:00:00 2001 From: rmatif Date: Thu, 31 Jul 2025 14:02:53 +0000 Subject: [PATCH 2/6] fix kernel init --- ggml/src/ggml-opencl/ggml-opencl.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 064f140074c74..b5b0a37e31b5e 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -428,6 +428,7 @@ struct ggml_backend_opencl_context { std::map, cl_kernel> kernels_flash_attn_f32_f16; std::map, cl_kernel> kernels_flash_attn_f32_f16_q1; std::map, int> kernels_flash_attn_bm; + std::map, int> kernels_flash_attn_bn; cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0; cl_kernel kernel_set_rows_f32, kernel_set_rows_f16; cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16; @@ -1311,7 +1312,6 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve cl_kernel k_f16, k_f16_q1; CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err)); CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err)); - GGML_ASSERT(k_f16 != NULL && k_f16_q1 != NULL); backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16; backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1; CL_CHECK(clReleaseProgram(prog_f16)); @@ -1320,7 +1320,6 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve cl_kernel k_f32, k_f32_q1; CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err)); CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err)); - GGML_ASSERT(k_f32 != NULL && k_f32_q1 != NULL); backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32; backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1; CL_CHECK(clReleaseProgram(prog_f32)); @@ -1329,12 +1328,12 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve cl_kernel k_f32_f16, k_f32_f16_q1; CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err)); CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err)); - GGML_ASSERT(k_f32_f16 != NULL && k_f32_f16_q1 != NULL); backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16; backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1; CL_CHECK(clReleaseProgram(prog_f32_f16)); backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm; + backend_ctx->kernels_flash_attn_bn[{dk, dv}] = bn; } GGML_LOG_CONT("."); } From 65910d30b2e6f7e5ecc4fafa6625d9d0dec955e0 Mon Sep 17 00:00:00 2001 From: rmatif Date: Fri, 1 Aug 2025 12:05:26 +0000 Subject: [PATCH 3/6] use mad instead of fma --- ggml/src/ggml-opencl/kernels/flash_attn_f16.cl | 10 +++++----- ggml/src/ggml-opencl/kernels/flash_attn_f32.cl | 10 +++++----- ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl | 10 +++++----- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl index 1cec0c5004537..de045a11c883e 100644 --- a/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +++ b/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl @@ -125,8 +125,8 @@ __kernel void flash_attn_f16( ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f); ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f); for (int k = 0; k < DK_VEC; k++) { - dot_acc0 = fma(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0); - dot_acc1 = fma(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1); + dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0); + dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1); } ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale; ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale; @@ -251,7 +251,7 @@ __kernel void flash_attn_f16_q1( const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset); ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); for (int k = 0; k < DK_VEC; k++) { - dot_acc = fma(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc); + dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc); } ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale; if (mask_base != NULL) { @@ -284,7 +284,7 @@ __kernel void flash_attn_f16_q1( const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset); ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); for (int k = 0; k < DK_VEC; k++) { - dot_acc = fma(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc); + dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc); } ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale; if (mask_base != NULL) { @@ -297,7 +297,7 @@ __kernel void flash_attn_f16_q1( const ACC_TYPE p = exp(score - m_final); l_i += p; for (int i = 0; i < DV_VEC; i++) { - o_acc[i] = fma(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]); + o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]); } } diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl index ad6067a6e94ae..866088068eebf 100644 --- a/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +++ b/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl @@ -125,8 +125,8 @@ __kernel void flash_attn_f32( ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f); ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f); for (int k = 0; k < DK_VEC; k++) { - dot_acc0 = fma(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0); - dot_acc1 = fma(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1); + dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0); + dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1); } ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale; ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale; @@ -251,7 +251,7 @@ __kernel void flash_attn_f32_q1( const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset); ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); for (int k = 0; k < DK_VEC; k++) { - dot_acc = fma(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc); + dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc); } ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale; if (mask_base != NULL) { @@ -284,7 +284,7 @@ __kernel void flash_attn_f32_q1( const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset); ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); for (int k = 0; k < DK_VEC; k++) { - dot_acc = fma(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc); + dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc); } ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale; if (mask_base != NULL) { @@ -297,7 +297,7 @@ __kernel void flash_attn_f32_q1( const ACC_TYPE p = exp(score - m_final); l_i += p; for (int i = 0; i < DV_VEC; i++) { - o_acc[i] = fma(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]); + o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]); } } diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl index dd68e1fc56e80..273dc05d08e80 100644 --- a/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +++ b/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl @@ -128,8 +128,8 @@ __kernel void flash_attn_f32_f16( ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f); ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f); for (int k = 0; k < DK_VEC; k++) { - dot_acc0 = fma(q_priv[k], CONVERT_KV_ACC4(l_k[j][k]), dot_acc0); - dot_acc1 = fma(q_priv[k], CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1); + dot_acc0 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][k]), dot_acc0); + dot_acc1 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1); } ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale; ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale; @@ -254,7 +254,7 @@ __kernel void flash_attn_f32_f16_q1( const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset); ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); for (int k = 0; k < DK_VEC; k++) { - dot_acc = fma(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc); + dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc); } ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale; if (mask_base != NULL) { @@ -287,7 +287,7 @@ __kernel void flash_attn_f32_f16_q1( const global KV_DATA_TYPE4* v_ptr = (const global KV_DATA_TYPE4*)(v_base + v_row_offset); ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); for (int k = 0; k < DK_VEC; k++) { - dot_acc = fma(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc); + dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc); } ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale; if (mask_base != NULL) { @@ -300,7 +300,7 @@ __kernel void flash_attn_f32_f16_q1( const ACC_TYPE p = exp(score - m_final); l_i += p; for (int i = 0; i < DV_VEC; i++) { - o_acc[i] = fma(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]); + o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]); } } From 1b06404176eeb92973bb44786f1c6c794c582808 Mon Sep 17 00:00:00 2001 From: rmatif Date: Thu, 14 Aug 2025 19:12:17 +0000 Subject: [PATCH 4/6] use inline function --- .../src/ggml-opencl/kernels/flash_attn_f16.cl | 31 ++++++++----------- .../src/ggml-opencl/kernels/flash_attn_f32.cl | 31 ++++++++----------- .../ggml-opencl/kernels/flash_attn_f32_f16.cl | 31 ++++++++----------- 3 files changed, 39 insertions(+), 54 deletions(-) diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl index de045a11c883e..d015cf742ee90 100644 --- a/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +++ b/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl @@ -12,6 +12,17 @@ #define WG_SIZE (BLOCK_M) #define Q1_WG_SIZE 64 +inline float get_alibi_slope( + const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1 +) { + if (max_bias <= 0.0f) { + return 1.0f; + } + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + return pow(base, exph); +} __kernel void flash_attn_f16( const global void * q_void, ulong q_offset, const global void * k_void, ulong k_offset, @@ -80,15 +91,7 @@ __kernel void flash_attn_f16( ACC_TYPE m_i = -INFINITY; ACC_TYPE l_i = 0.0f; - float slope = 1.0f; - if (max_bias > 0.0f) { - int h = head_idx; - if (h < n_head_log2) { - slope = pow(m0, (float)(h + 1)); - } else { - slope = pow(m1, (float)(2 * (h - n_head_log2) + 1)); - } - } + float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1); __local DATA_TYPE4 l_k[BLOCK_N][DK_VEC]; __local DATA_TYPE4 l_v[BLOCK_N][DV_VEC]; @@ -235,15 +238,7 @@ __kernel void flash_attn_f16_q1( q_priv[i] = CONVERT_ACC4(q_ptr[i]); } - float slope = 1.0f; - if (max_bias > 0.0f) { - int h = head_idx; - if (h < n_head_log2) { - slope = pow(m0, (float)(h + 1)); - } else { - slope = pow(m1, (float)(2 * (h - n_head_log2) + 1)); - } - } + float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1); ACC_TYPE m_i = -INFINITY; for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) { diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl index 866088068eebf..4a585b2aea130 100644 --- a/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +++ b/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl @@ -12,6 +12,17 @@ #define WG_SIZE (BLOCK_M) #define Q1_WG_SIZE 64 +inline float get_alibi_slope( + const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1 +) { + if (max_bias <= 0.0f) { + return 1.0f; + } + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + return pow(base, exph); +} __kernel void flash_attn_f32( const global void * q_void, ulong q_offset, const global void * k_void, ulong k_offset, @@ -80,15 +91,7 @@ __kernel void flash_attn_f32( ACC_TYPE m_i = -INFINITY; ACC_TYPE l_i = 0.0f; - float slope = 1.0f; - if (max_bias > 0.0f) { - int h = head_idx; - if (h < n_head_log2) { - slope = pow(m0, (float)(h + 1)); - } else { - slope = pow(m1, (float)(2 * (h - n_head_log2) + 1)); - } - } + float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1); __local DATA_TYPE4 l_k[BLOCK_N][DK_VEC]; __local DATA_TYPE4 l_v[BLOCK_N][DV_VEC]; @@ -235,15 +238,7 @@ __kernel void flash_attn_f32_q1( q_priv[i] = CONVERT_ACC4(q_ptr[i]); } - float slope = 1.0f; - if (max_bias > 0.0f) { - int h = head_idx; - if (h < n_head_log2) { - slope = pow(m0, (float)(h + 1)); - } else { - slope = pow(m1, (float)(2 * (h - n_head_log2) + 1)); - } - } + float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1); ACC_TYPE m_i = -INFINITY; for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) { diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl index 273dc05d08e80..25297b9b89d31 100644 --- a/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +++ b/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl @@ -15,6 +15,17 @@ #define WG_SIZE (BLOCK_M) #define Q1_WG_SIZE 64 +inline float get_alibi_slope( + const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1 +) { + if (max_bias <= 0.0f) { + return 1.0f; + } + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + return pow(base, exph); +} __kernel void flash_attn_f32_f16( const global void * q_void, ulong q_offset, const global void * k_void, ulong k_offset, @@ -83,15 +94,7 @@ __kernel void flash_attn_f32_f16( ACC_TYPE m_i = -INFINITY; ACC_TYPE l_i = 0.0f; - float slope = 1.0f; - if (max_bias > 0.0f) { - int h = head_idx; - if (h < n_head_log2) { - slope = pow(m0, (float)(h + 1)); - } else { - slope = pow(m1, (float)(2 * (h - n_head_log2) + 1)); - } - } + float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1); __local KV_DATA_TYPE4 l_k[BLOCK_N][DK_VEC]; __local KV_DATA_TYPE4 l_v[BLOCK_N][DV_VEC]; @@ -238,15 +241,7 @@ __kernel void flash_attn_f32_f16_q1( q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]); } - float slope = 1.0f; - if (max_bias > 0.0f) { - int h = head_idx; - if (h < n_head_log2) { - slope = pow(m0, (float)(h + 1)); - } else { - slope = pow(m1, (float)(2 * (h - n_head_log2) + 1)); - } - } + float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1); ACC_TYPE m_i = -INFINITY; for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) { From db7c56422044871e0f68a91140322da5a94b4517 Mon Sep 17 00:00:00 2001 From: rmatif Date: Fri, 15 Aug 2025 08:37:19 +0000 Subject: [PATCH 5/6] mark FA with sinks as unsupported for now --- ggml/src/ggml-opencl/ggml-opencl.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index b5b0a37e31b5e..a018f2d00c701 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2631,6 +2631,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); case GGML_OP_FLASH_ATTN_EXT: { + if (op->src[4]) { + return false; + } + const ggml_tensor * q = op->src[0]; const ggml_tensor * k = op->src[1]; const ggml_tensor * v = op->src[2]; From 8c4025fd50061dd081c48da5900f03c650289bb1 Mon Sep 17 00:00:00 2001 From: rmatif Date: Fri, 15 Aug 2025 09:46:13 +0000 Subject: [PATCH 6/6] add pragma unroll to loops --- ggml/src/ggml-opencl/kernels/flash_attn_f16.cl | 15 +++++++++++++++ ggml/src/ggml-opencl/kernels/flash_attn_f32.cl | 15 +++++++++++++++ .../src/ggml-opencl/kernels/flash_attn_f32_f16.cl | 15 +++++++++++++++ 3 files changed, 45 insertions(+) diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl index d015cf742ee90..fea06867e1020 100644 --- a/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +++ b/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl @@ -79,12 +79,14 @@ __kernel void flash_attn_f16( if (my_query_row < n_q) { const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1; const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset); + #pragma unroll for (int i = 0; i < DK_VEC; ++i) { q_priv[i] = CONVERT_ACC4(q_ptr[i]); } } ACC_TYPE4 o_acc[DV_VEC]; + #pragma unroll for (int i = 0; i < DV_VEC; ++i) { o_acc[i] = (ACC_TYPE4)(0.0f); } @@ -127,6 +129,7 @@ __kernel void flash_attn_f16( ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f); ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f); + #pragma unroll for (int k = 0; k < DK_VEC; k++) { dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0); dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1); @@ -158,6 +161,7 @@ __kernel void flash_attn_f16( const ACC_TYPE p1 = exp(score1 - m_new); const ACC_TYPE scale_prev = exp(m_i - m_new); + #pragma unroll for (int i = 0; i < DV_VEC; ++i) { o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]); } @@ -171,10 +175,12 @@ __kernel void flash_attn_f16( global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset); if (l_i > 0.0f) { const ACC_TYPE l_inv = 1.0f / l_i; + #pragma unroll for (int i = 0; i < DV_VEC; ++i) { o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv); } } else { + #pragma unroll for (int i = 0; i < DV_VEC; ++i) { o_row[i] = (DATA_TYPE4)(0.0f); } @@ -234,6 +240,7 @@ __kernel void flash_attn_f16_q1( ACC_TYPE4 q_priv[DK_VEC]; const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2; const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset); + #pragma unroll for (int i = 0; i < DK_VEC; ++i) { q_priv[i] = CONVERT_ACC4(q_ptr[i]); } @@ -245,6 +252,7 @@ __kernel void flash_attn_f16_q1( const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset); ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); + #pragma unroll for (int k = 0; k < DK_VEC; k++) { dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc); } @@ -262,6 +270,7 @@ __kernel void flash_attn_f16_q1( __local ACC_TYPE local_m[Q1_WG_SIZE]; local_m[tid] = m_i; barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]); barrier(CLK_LOCAL_MEM_FENCE); @@ -269,6 +278,7 @@ __kernel void flash_attn_f16_q1( const ACC_TYPE m_final = local_m[0]; ACC_TYPE4 o_acc[DV_VEC]; + #pragma unroll for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f); ACC_TYPE l_i = 0.0f; @@ -278,6 +288,7 @@ __kernel void flash_attn_f16_q1( const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset); const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset); ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); + #pragma unroll for (int k = 0; k < DK_VEC; k++) { dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc); } @@ -291,6 +302,7 @@ __kernel void flash_attn_f16_q1( } const ACC_TYPE p = exp(score - m_final); l_i += p; + #pragma unroll for (int i = 0; i < DV_VEC; i++) { o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]); } @@ -300,6 +312,7 @@ __kernel void flash_attn_f16_q1( __local ACC_TYPE4 local_o_comp[Q1_WG_SIZE]; local_l[tid] = l_i; barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { if (tid < s) local_l[tid] += local_l[tid + s]; barrier(CLK_LOCAL_MEM_FENCE); @@ -314,6 +327,7 @@ __kernel void flash_attn_f16_q1( for (int i = 0; i < DV_VEC; i++) { local_o_comp[tid] = o_acc[i]; barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { if (tid < s) local_o_comp[tid] += local_o_comp[tid + s]; barrier(CLK_LOCAL_MEM_FENCE); @@ -323,6 +337,7 @@ __kernel void flash_attn_f16_q1( } } } else if (tid == 0) { + #pragma unroll for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f); } } diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl index 4a585b2aea130..2d657327d6460 100644 --- a/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +++ b/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl @@ -79,12 +79,14 @@ __kernel void flash_attn_f32( if (my_query_row < n_q) { const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1; const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset); + #pragma unroll for (int i = 0; i < DK_VEC; ++i) { q_priv[i] = CONVERT_ACC4(q_ptr[i]); } } ACC_TYPE4 o_acc[DV_VEC]; + #pragma unroll for (int i = 0; i < DV_VEC; ++i) { o_acc[i] = (ACC_TYPE4)(0.0f); } @@ -127,6 +129,7 @@ __kernel void flash_attn_f32( ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f); ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f); + #pragma unroll for (int k = 0; k < DK_VEC; k++) { dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0); dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1); @@ -158,6 +161,7 @@ __kernel void flash_attn_f32( const ACC_TYPE p1 = exp(score1 - m_new); const ACC_TYPE scale_prev = exp(m_i - m_new); + #pragma unroll for (int i = 0; i < DV_VEC; ++i) { o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]); } @@ -171,10 +175,12 @@ __kernel void flash_attn_f32( global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset); if (l_i > 0.0f) { const ACC_TYPE l_inv = 1.0f / l_i; + #pragma unroll for (int i = 0; i < DV_VEC; ++i) { o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv); } } else { + #pragma unroll for (int i = 0; i < DV_VEC; ++i) { o_row[i] = (DATA_TYPE4)(0.0f); } @@ -234,6 +240,7 @@ __kernel void flash_attn_f32_q1( ACC_TYPE4 q_priv[DK_VEC]; const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2; const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset); + #pragma unroll for (int i = 0; i < DK_VEC; ++i) { q_priv[i] = CONVERT_ACC4(q_ptr[i]); } @@ -245,6 +252,7 @@ __kernel void flash_attn_f32_q1( const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset); ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); + #pragma unroll for (int k = 0; k < DK_VEC; k++) { dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc); } @@ -262,6 +270,7 @@ __kernel void flash_attn_f32_q1( __local ACC_TYPE local_m[Q1_WG_SIZE]; local_m[tid] = m_i; barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]); barrier(CLK_LOCAL_MEM_FENCE); @@ -269,6 +278,7 @@ __kernel void flash_attn_f32_q1( const ACC_TYPE m_final = local_m[0]; ACC_TYPE4 o_acc[DV_VEC]; + #pragma unroll for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f); ACC_TYPE l_i = 0.0f; @@ -278,6 +288,7 @@ __kernel void flash_attn_f32_q1( const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset); const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset); ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); + #pragma unroll for (int k = 0; k < DK_VEC; k++) { dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc); } @@ -291,6 +302,7 @@ __kernel void flash_attn_f32_q1( } const ACC_TYPE p = exp(score - m_final); l_i += p; + #pragma unroll for (int i = 0; i < DV_VEC; i++) { o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]); } @@ -300,6 +312,7 @@ __kernel void flash_attn_f32_q1( __local ACC_TYPE4 local_o_comp[Q1_WG_SIZE]; local_l[tid] = l_i; barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { if (tid < s) local_l[tid] += local_l[tid + s]; barrier(CLK_LOCAL_MEM_FENCE); @@ -314,6 +327,7 @@ __kernel void flash_attn_f32_q1( for (int i = 0; i < DV_VEC; i++) { local_o_comp[tid] = o_acc[i]; barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { if (tid < s) local_o_comp[tid] += local_o_comp[tid + s]; barrier(CLK_LOCAL_MEM_FENCE); @@ -323,6 +337,7 @@ __kernel void flash_attn_f32_q1( } } } else if (tid == 0) { + #pragma unroll for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f); } } diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl index 25297b9b89d31..7067bd2591fa7 100644 --- a/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +++ b/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl @@ -82,12 +82,14 @@ __kernel void flash_attn_f32_f16( if (my_query_row < n_q) { const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1; const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset); + #pragma unroll for (int i = 0; i < DK_VEC; ++i) { q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]); } } ACC_TYPE4 o_acc[DV_VEC]; + #pragma unroll for (int i = 0; i < DV_VEC; ++i) { o_acc[i] = (ACC_TYPE4)(0.0f); } @@ -130,6 +132,7 @@ __kernel void flash_attn_f32_f16( ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f); ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f); + #pragma unroll for (int k = 0; k < DK_VEC; k++) { dot_acc0 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][k]), dot_acc0); dot_acc1 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1); @@ -161,6 +164,7 @@ __kernel void flash_attn_f32_f16( const ACC_TYPE p1 = exp(score1 - m_new); const ACC_TYPE scale_prev = exp(m_i - m_new); + #pragma unroll for (int i = 0; i < DV_VEC; ++i) { o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_KV_ACC4(l_v[j][i]) + p1 * CONVERT_KV_ACC4(l_v[j+1][i]); } @@ -174,10 +178,12 @@ __kernel void flash_attn_f32_f16( global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset); if (l_i > 0.0f) { const ACC_TYPE l_inv = 1.0f / l_i; + #pragma unroll for (int i = 0; i < DV_VEC; ++i) { o_row[i] = CONVERT_O_DATA4(o_acc[i] * l_inv); } } else { + #pragma unroll for (int i = 0; i < DV_VEC; ++i) { o_row[i] = (O_DATA_TYPE4)(0.0f); } @@ -237,6 +243,7 @@ __kernel void flash_attn_f32_f16_q1( ACC_TYPE4 q_priv[DK_VEC]; const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2; const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset); + #pragma unroll for (int i = 0; i < DK_VEC; ++i) { q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]); } @@ -248,6 +255,7 @@ __kernel void flash_attn_f32_f16_q1( const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset); ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); + #pragma unroll for (int k = 0; k < DK_VEC; k++) { dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc); } @@ -265,6 +273,7 @@ __kernel void flash_attn_f32_f16_q1( __local ACC_TYPE local_m[Q1_WG_SIZE]; local_m[tid] = m_i; barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]); barrier(CLK_LOCAL_MEM_FENCE); @@ -272,6 +281,7 @@ __kernel void flash_attn_f32_f16_q1( const ACC_TYPE m_final = local_m[0]; ACC_TYPE4 o_acc[DV_VEC]; + #pragma unroll for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f); ACC_TYPE l_i = 0.0f; @@ -281,6 +291,7 @@ __kernel void flash_attn_f32_f16_q1( const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset); const global KV_DATA_TYPE4* v_ptr = (const global KV_DATA_TYPE4*)(v_base + v_row_offset); ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f); + #pragma unroll for (int k = 0; k < DK_VEC; k++) { dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc); } @@ -294,6 +305,7 @@ __kernel void flash_attn_f32_f16_q1( } const ACC_TYPE p = exp(score - m_final); l_i += p; + #pragma unroll for (int i = 0; i < DV_VEC; i++) { o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]); } @@ -303,6 +315,7 @@ __kernel void flash_attn_f32_f16_q1( __local ACC_TYPE4 local_o_comp[Q1_WG_SIZE]; local_l[tid] = l_i; barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { if (tid < s) local_l[tid] += local_l[tid + s]; barrier(CLK_LOCAL_MEM_FENCE); @@ -317,6 +330,7 @@ __kernel void flash_attn_f32_f16_q1( for (int i = 0; i < DV_VEC; i++) { local_o_comp[tid] = o_acc[i]; barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) { if (tid < s) local_o_comp[tid] += local_o_comp[tid + s]; barrier(CLK_LOCAL_MEM_FENCE); @@ -326,6 +340,7 @@ __kernel void flash_attn_f32_f16_q1( } } } else if (tid == 0) { + #pragma unroll for (int i = 0; i < DV_VEC; ++i) o_row[i] = (O_DATA_TYPE4)(0.0f); } }