| 25 | 25 | #include <vector> | 
| 26 | 26 | #include <string> | 
| 27 | 27 | #include <cmath> | 
|  | 28 | +#include <map> | 
| 28 | 29 | #include <memory> | 
| 29 | 30 | #include <charconv> | 
| 30 | 31 | #include <mutex> | 
| @@ -424,6 +425,14 @@ struct ggml_backend_opencl_context { | 
| 424 | 425 |     cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8; | 
| 425 | 426 |     cl_kernel kernel_soft_max, kernel_soft_max_4; | 
| 426 | 427 |     cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16; | 
|  | 428 | +    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16; | 
|  | 429 | +    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16_q1; | 
|  | 430 | +    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32; | 
|  | 431 | +    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q1; | 
|  | 432 | +    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16; | 
|  | 433 | +    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16_q1; | 
|  | 434 | +    std::map<std::pair<int, int>, int>       kernels_flash_attn_bm; | 
|  | 435 | +    std::map<std::pair<int, int>, int>       kernels_flash_attn_bn; | 
| 427 | 436 |     cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0; | 
| 428 | 437 |     cl_kernel kernel_set_rows_f32, kernel_set_rows_f16; | 
| 429 | 438 |     cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16; | 
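The new members key every flash-attention kernel variant by its `{DK, DV}` head-dimension pair, so dispatch later reduces to a single map lookup. A minimal standalone sketch of that access pattern (placeholder values; the real lookup is in `ggml_cl_flash_attn` further down in this patch):

```cpp
#include <map>
#include <utility>

// Stand-in for cl_kernel so this sketch compiles without OpenCL headers.
using kernel_handle = void *;

int main() {
    // Mirrors kernels_flash_attn_f16 et al.: one specialized kernel per
    // supported {DK, DV} pair, populated by load_cl_kernels.
    std::map<std::pair<int, int>, kernel_handle> kernels;
    kernels[{128, 128}] = nullptr;

    // .at() throws on an unknown pair; supports_op is expected to have
    // rejected unsupported head dimensions before execution reaches here.
    kernel_handle k = kernels.at({128, 128});
    (void)k;
    return 0;
}
```

The companion `kernels_flash_attn_bm`/`kernels_flash_attn_bn` maps record the tile sizes each variant was compiled with, which the host needs again at enqueue time to size work-groups.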
| @@ -1308,6 +1317,73 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve | 
| 1308 | 1317 |         GGML_LOG_CONT("."); | 
| 1309 | 1318 |     } | 
| 1310 | 1319 |  | 
|  | 1320 | +    // flash_attn | 
|  | 1321 | +    { | 
|  | 1322 | +#ifdef GGML_OPENCL_EMBED_KERNELS | 
|  | 1323 | +        const std::string kernel_src_f16 { | 
|  | 1324 | +            #include "flash_attn_f16.cl.h" | 
|  | 1325 | +        }; | 
|  | 1326 | +        const std::string kernel_src_f32 { | 
|  | 1327 | +            #include "flash_attn_f32.cl.h" | 
|  | 1328 | +        }; | 
|  | 1329 | +        const std::string kernel_src_f32_f16 { | 
|  | 1330 | +            #include "flash_attn_f32_f16.cl.h" | 
|  | 1331 | +        }; | 
|  | 1332 | +#else | 
|  | 1333 | +        const std::string kernel_src_f16 = read_file("flash_attn_f16.cl"); | 
|  | 1334 | +        const std::string kernel_src_f32 = read_file("flash_attn_f32.cl"); | 
|  | 1335 | +        const std::string kernel_src_f32_f16 = read_file("flash_attn_f32_f16.cl"); | 
|  | 1336 | +#endif | 
|  | 1337 | + | 
|  | 1338 | +        if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) { | 
|  | 1339 | +            const struct { int dk; int dv; int bm; int bn; } fa_dims[] = { | 
|  | 1340 | +                { 64,  64, 64, 64}, { 80,  80, 64, 32}, { 96,  96, 64, 32}, | 
|  | 1341 | +                {112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16}, | 
|  | 1342 | +                {192, 192, 16, 16}, {256, 256, 16, 16}, | 
|  | 1343 | +            }; | 
|  | 1344 | + | 
|  | 1345 | +            for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) { | 
|  | 1346 | +                const int dk = fa_dims[i].dk; | 
|  | 1347 | +                const int dv = fa_dims[i].dv; | 
|  | 1348 | +                const int bm = fa_dims[i].bm; | 
|  | 1349 | +                const int bn = fa_dims[i].bn; | 
|  | 1350 | +                std::string OPTS = compile_opts + | 
|  | 1351 | +                    " -D DK=" + std::to_string(dk) + | 
|  | 1352 | +                    " -D DV=" + std::to_string(dv) + | 
|  | 1353 | +                    " -D BLOCK_M=" + std::to_string(bm) + | 
|  | 1354 | +                    " -D BLOCK_N=" + std::to_string(bn); | 
|  | 1355 | + | 
|  | 1356 | +                cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS); | 
|  | 1357 | +                cl_kernel k_f16, k_f16_q1; | 
|  | 1358 | +                CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err)); | 
|  | 1359 | +                CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err)); | 
|  | 1360 | +                backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16; | 
|  | 1361 | +                backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1; | 
|  | 1362 | +                CL_CHECK(clReleaseProgram(prog_f16)); | 
|  | 1363 | + | 
|  | 1364 | +                cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS); | 
|  | 1365 | +                cl_kernel k_f32, k_f32_q1; | 
|  | 1366 | +                CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err)); | 
|  | 1367 | +                CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err)); | 
|  | 1368 | +                backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32; | 
|  | 1369 | +                backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1; | 
|  | 1370 | +                CL_CHECK(clReleaseProgram(prog_f32)); | 
|  | 1371 | + | 
|  | 1372 | +                cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS); | 
|  | 1373 | +                cl_kernel k_f32_f16, k_f32_f16_q1; | 
|  | 1374 | +                CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err)); | 
|  | 1375 | +                CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err)); | 
|  | 1376 | +                backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16; | 
|  | 1377 | +                backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1; | 
|  | 1378 | +                CL_CHECK(clReleaseProgram(prog_f32_f16)); | 
|  | 1379 | + | 
|  | 1380 | +                backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm; | 
|  | 1381 | +                backend_ctx->kernels_flash_attn_bn[{dk, dv}] = bn; | 
|  | 1382 | +            } | 
|  | 1383 | +            GGML_LOG_CONT("."); | 
|  | 1384 | +        } | 
|  | 1385 | +    } | 
|  | 1386 | + | 
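Each `fa_dims` entry is compiled into its own program: the `-D` defines turn the head dimensions and tile sizes into compile-time constants inside the `.cl` source. As a rough illustration (the contents of `compile_opts` are an assumption here), the `{128, 128, 32, 32}` entry produces a build-options string like this:

```cpp
#include <iostream>
#include <string>

int main() {
    // Assumed baseline options; the real compile_opts comes from backend setup.
    const std::string compile_opts = "-cl-std=CL2.0";
    const int dk = 128, dv = 128, bm = 32, bn = 32;

    const std::string OPTS = compile_opts +
        " -D DK="      + std::to_string(dk) +
        " -D DV="      + std::to_string(dv) +
        " -D BLOCK_M=" + std::to_string(bm) +
        " -D BLOCK_N=" + std::to_string(bn);

    // Prints: -cl-std=CL2.0 -D DK=128 -D DV=128 -D BLOCK_M=32 -D BLOCK_N=32
    std::cout << OPTS << '\n';
    return 0;
}
```

Releasing each `cl_program` immediately after `clCreateKernel` is safe: a kernel retains a reference to its program, so the compiled binaries stay alive for as long as the kernels held in the maps do.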
| 1311 | 1387 |     // argsort | 
| 1312 | 1388 |     { | 
| 1313 | 1389 | #ifdef GGML_OPENCL_EMBED_KERNELS | 
| @@ -2636,6 +2712,45 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te | 
| 2636 | 2712 |             return op->src[0]->type == GGML_TYPE_F32; | 
| 2637 | 2713 |         case GGML_OP_SUM_ROWS: | 
| 2638 | 2714 |             return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); | 
|  | 2715 | +        case GGML_OP_FLASH_ATTN_EXT: | 
|  | 2716 | +            { | 
|  | 2717 | +                if (op->src[4]) { | 
|  | 2718 | +                    return false; | 
|  | 2719 | +                } | 
|  | 2720 | + | 
|  | 2721 | +                const ggml_tensor * q = op->src[0]; | 
|  | 2722 | +                const ggml_tensor * k = op->src[1]; | 
|  | 2723 | +                const ggml_tensor * v = op->src[2]; | 
|  | 2724 | + | 
|  | 2725 | +                const int dk = q->ne[0]; | 
|  | 2726 | +                const int dv = v->ne[0]; | 
|  | 2727 | + | 
|  | 2728 | +                const struct { int dk; int dv; } supported_dims[] = { | 
|  | 2729 | +                    { 64,  64}, { 80,  80}, { 96,  96}, | 
|  | 2730 | +                    {112, 112}, {128, 128}, {192, 128}, | 
|  | 2731 | +                    {192, 192}, {256, 256}, | 
|  | 2732 | +                }; | 
|  | 2733 | + | 
|  | 2734 | +                bool dims_supported = false; | 
|  | 2735 | +                for (size_t i = 0; i < sizeof(supported_dims)/sizeof(supported_dims[0]); ++i) { | 
|  | 2736 | +                    if (supported_dims[i].dk == dk && supported_dims[i].dv == dv) { | 
|  | 2737 | +                        dims_supported = true; | 
|  | 2738 | +                        break; | 
|  | 2739 | +                    } | 
|  | 2740 | +                } | 
|  | 2741 | +                if (!dims_supported) { | 
|  | 2742 | +                    return false; | 
|  | 2743 | +                } | 
|  | 2744 | + | 
|  | 2745 | +                const bool is_f32_f32 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F32 && | 
|  | 2746 | +                                        v->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; | 
|  | 2747 | +                const bool is_f16_f16 = q->type == GGML_TYPE_F16 && k->type == GGML_TYPE_F16 && | 
|  | 2748 | +                                        v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16; | 
|  | 2749 | +                const bool is_f32_f16 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && | 
|  | 2750 | +                                        v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F32; | 
|  | 2751 | + | 
|  | 2752 | +                return is_f32_f32 || is_f16_f16 || is_f32_f16; | 
|  | 2753 | +            } | 
| 2639 | 2754 |         default: | 
| 2640 | 2755 |             return false; | 
| 2641 | 2756 |     } | 
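The `supports_op` case mirrors the `fa_dims` table used at kernel-load time, then accepts exactly three type layouts: all-f32, all-f16, and f32 queries against f16 K/V. The asymmetric `{192, 128}` entry presumably targets models whose K and V head sizes differ, and the early `op->src[4]` rejection bails out when a fifth input (presumably attention sinks) is present, which these kernels do not handle. A hedged sketch of probing this through ggml's public API, with illustrative shapes and assumed setup boilerplate:

```cpp
#include <cmath>
#include "ggml.h"
#include "ggml-backend.h"

// Builds a FLASH_ATTN_EXT node in a no_alloc context and asks the backend
// whether it would accept it. {dk=128, dv=128} is in supported_dims, so
// only the type combination decides the answer.
static bool probe_flash_attn(ggml_backend_t backend) {
    struct ggml_init_params ip = { 16u * 1024 * 1024, NULL, /*no_alloc*/ true };
    struct ggml_context * ctx = ggml_init(ip);

    ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128,  32, 8, 1);
    ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 128, 256, 8, 1);
    ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 128, 256, 8, 1);
    ggml_tensor * op = ggml_flash_attn_ext(ctx, q, k, v, /*mask*/ NULL,
                                           1.0f / sqrtf(128.0f),
                                           /*max_bias*/ 0.0f,
                                           /*logit_softcap*/ 0.0f);

    // f32 Q with f16 K/V and f32 output matches the is_f32_f16 case above.
    const bool ok = ggml_backend_supports_op(backend, op);
    ggml_free(ctx);
    return ok;
}
```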
| @@ -5451,6 +5566,133 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor | 
| 5451 | 5566 |     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst); | 
| 5452 | 5567 | } | 
| 5453 | 5568 |  | 
|  | 5569 | +static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) { | 
|  | 5570 | +    const ggml_tensor * v = dst->src[2]; | 
|  | 5571 | +    const ggml_tensor * mask = dst->src[3]; | 
|  | 5572 | +    GGML_ASSERT(q->extra); | 
|  | 5573 | +    GGML_ASSERT(k->extra); | 
|  | 5574 | +    GGML_ASSERT(v->extra); | 
|  | 5575 | +    GGML_ASSERT(dst->extra); | 
|  | 5576 | +    if (mask) { | 
|  | 5577 | +        GGML_ASSERT(mask->extra); | 
|  | 5578 | +    } | 
|  | 5579 | + | 
|  | 5580 | +    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; | 
|  | 5581 | + | 
|  | 5582 | +    const int n_q = q->ne[1]; | 
|  | 5583 | +    const int n_kv = k->ne[1]; | 
|  | 5584 | +    const int d_head_q = q->ne[0]; | 
|  | 5585 | +    const int d_head_v = v->ne[0]; | 
|  | 5586 | +    const int n_head = q->ne[2]; | 
|  | 5587 | +    const int n_head_kv = k->ne[2]; | 
|  | 5588 | +    const int n_batch = q->ne[3]; | 
|  | 5589 | + | 
|  | 5590 | +    cl_kernel kernel = NULL; | 
|  | 5591 | + | 
|  | 5592 | +    const bool is_f16 = q->type == GGML_TYPE_F16; | 
|  | 5593 | +    const bool is_mixed = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16; | 
|  | 5594 | +    const std::pair<int, int> dk_dv = {d_head_q, d_head_v}; | 
|  | 5595 | + | 
|  | 5596 | +    if (n_q == 1) { | 
|  | 5597 | +        if (is_mixed) { | 
|  | 5598 | +            kernel = backend_ctx->kernels_flash_attn_f32_f16_q1.at(dk_dv); | 
|  | 5599 | +        } else if (is_f16) { | 
|  | 5600 | +            kernel = backend_ctx->kernels_flash_attn_f16_q1.at(dk_dv); | 
|  | 5601 | +        } else { | 
|  | 5602 | +            kernel = backend_ctx->kernels_flash_attn_f32_q1.at(dk_dv); | 
|  | 5603 | +        } | 
|  | 5604 | +    } else { | 
|  | 5605 | +        if (is_mixed) { | 
|  | 5606 | +            kernel = backend_ctx->kernels_flash_attn_f32_f16.at(dk_dv); | 
|  | 5607 | +        } else if (is_f16) { | 
|  | 5608 | +            kernel = backend_ctx->kernels_flash_attn_f16.at(dk_dv); | 
|  | 5609 | +        } else { | 
|  | 5610 | +            kernel = backend_ctx->kernels_flash_attn_f32.at(dk_dv); | 
|  | 5611 | +        } | 
|  | 5612 | +    } | 
|  | 5613 | +    GGML_ASSERT(kernel != NULL); | 
|  | 5614 | + | 
|  | 5615 | +    ggml_tensor_extra_cl * extra_q = (ggml_tensor_extra_cl *)q->extra; | 
|  | 5616 | +    ggml_tensor_extra_cl * extra_k = (ggml_tensor_extra_cl *)k->extra; | 
|  | 5617 | +    ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra; | 
|  | 5618 | +    ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra; | 
|  | 5619 | +    ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL; | 
|  | 5620 | + | 
|  | 5621 | +    cl_ulong offset_q = extra_q->offset + q->view_offs; | 
|  | 5622 | +    cl_ulong offset_k = extra_k->offset + k->view_offs; | 
|  | 5623 | +    cl_ulong offset_v = extra_v->offset + v->view_offs; | 
|  | 5624 | +    cl_ulong offset_o = extra_o->offset + dst->view_offs; | 
|  | 5625 | +    cl_mem   mask_buffer = extra_mask ? extra_mask->data_device : NULL; | 
|  | 5626 | +    cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0; | 
|  | 5627 | + | 
|  | 5628 | +    const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3]; | 
|  | 5629 | +    const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3]; | 
|  | 5630 | +    const cl_ulong v_nb1 = v->nb[1], v_nb2 = v->nb[2], v_nb3 = v->nb[3]; | 
|  | 5631 | +    const cl_ulong o_nb1 = dst->nb[1], o_nb2 = dst->nb[2], o_nb3 = dst->nb[3]; | 
|  | 5632 | +    const cl_ulong mask_nb1 = mask ? mask->nb[1] : 0; | 
|  | 5633 | +    const cl_ulong mask_nb2 = mask ? mask->nb[2] : 0; | 
|  | 5634 | +    const cl_ulong mask_nb3 = mask ? mask->nb[3] : 0; | 
|  | 5635 | +    const int mask_ne2 = mask ? mask->ne[2] : 0; | 
|  | 5636 | +    const int mask_ne3 = mask ? mask->ne[3] : 0; | 
|  | 5637 | + | 
|  | 5638 | +    float scale, max_bias, logit_softcap; | 
|  | 5639 | +    const float * params = (const float *)dst->op_params; | 
|  | 5640 | +    scale         = params[0]; | 
|  | 5641 | +    max_bias      = params[1]; | 
|  | 5642 | +    logit_softcap = params[2]; | 
|  | 5643 | + | 
|  | 5644 | +    const int is_causal = (mask == NULL && n_q > 1 && n_q == n_kv); | 
|  | 5645 | + | 
|  | 5646 | +    const int n_head_log2_val = n_head > 0 ? 1u << (int)floorf(log2f((float)n_head)) : 0; | 
|  | 5647 | +    const float n_head_log2_f = n_head_log2_val > 0 ? (float)n_head_log2_val : 1.0f; | 
|  | 5648 | +    const float m0 = powf(2.0f, -(max_bias) / n_head_log2_f); | 
|  | 5649 | +    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2_f); | 
|  | 5650 | + | 
|  | 5651 | +    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra_q->data_device)); | 
|  | 5652 | +    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset_q)); | 
|  | 5653 | +    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra_k->data_device)); | 
|  | 5654 | +    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset_k)); | 
|  | 5655 | +    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extra_v->data_device)); | 
|  | 5656 | +    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset_v)); | 
|  | 5657 | +    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem),   &extra_o->data_device)); | 
|  | 5658 | +    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offset_o)); | 
|  | 5659 | +    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(float),    &scale)); | 
|  | 5660 | +    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int),      &n_q)); | 
|  | 5661 | +    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),     &n_kv)); | 
|  | 5662 | +    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),     &is_causal)); | 
|  | 5663 | +    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),     &n_head)); | 
|  | 5664 | +    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &q_nb1)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &q_nb2)); CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &q_nb3)); | 
|  | 5665 | +    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &k_nb1)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &k_nb2)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &k_nb3)); | 
|  | 5666 | +    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &v_nb1)); CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &v_nb2)); CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &v_nb3)); | 
|  | 5667 | +    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &o_nb1)); CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &o_nb2)); CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &o_nb3)); | 
|  | 5668 | +    CL_CHECK(clSetKernelArg(kernel, 25, sizeof(float),    &max_bias)); | 
|  | 5669 | +    CL_CHECK(clSetKernelArg(kernel, 26, sizeof(float),    &m0)); | 
|  | 5670 | +    CL_CHECK(clSetKernelArg(kernel, 27, sizeof(float),    &m1)); | 
|  | 5671 | +    CL_CHECK(clSetKernelArg(kernel, 28, sizeof(int),      &n_head_log2_val)); | 
|  | 5672 | +    CL_CHECK(clSetKernelArg(kernel, 29, sizeof(float),    &logit_softcap)); | 
|  | 5673 | +    CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int),      &n_head_kv)); | 
|  | 5674 | +    CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_mem),   &mask_buffer)); | 
|  | 5675 | +    CL_CHECK(clSetKernelArg(kernel, 32, sizeof(cl_ulong), &offset_mask)); | 
|  | 5676 | +    CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_ulong), &mask_nb1)); | 
|  | 5677 | +    CL_CHECK(clSetKernelArg(kernel, 34, sizeof(cl_ulong), &mask_nb2)); | 
|  | 5678 | +    CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3)); | 
|  | 5679 | +    CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int),      &mask_ne2)); | 
|  | 5680 | +    CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int),      &mask_ne3)); | 
|  | 5681 | + | 
|  | 5682 | +    if (n_q == 1) { | 
|  | 5683 | +        const size_t wg_size = 64; | 
|  | 5684 | +        size_t local_work_size[] = { wg_size, 1 }; | 
|  | 5685 | +        size_t global_work_size[] = { wg_size, (size_t)(n_head * n_batch) }; | 
|  | 5686 | +        backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); | 
|  | 5687 | +    } else { | 
|  | 5688 | +        const int block_m = backend_ctx->kernels_flash_attn_bm.at(dk_dv); | 
|  | 5689 | +        const size_t wg_size = block_m; | 
|  | 5690 | +        size_t local_work_size[] = { wg_size, 1 }; | 
|  | 5691 | +        size_t global_work_size[] = { (size_t)((n_q + block_m - 1) / block_m) * wg_size, (size_t)(n_head * n_batch) }; | 
|  | 5692 | +        backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); | 
|  | 5693 | +    } | 
|  | 5694 | +} | 
|  | 5695 | + | 
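`m0`, `m1`, and `n_head_log2_val` are the host-side half of the ALiBi bias. Assuming the kernels follow ggml's usual ALiBi convention, with $H$ heads and $n = 2^{\lfloor \log_2 H \rfloor}$, the per-head slope they reconstruct from these constants is:

$$
m_0 = 2^{-\text{max\_bias}/n}, \qquad
m_1 = 2^{-(\text{max\_bias}/2)/n}, \qquad
\text{slope}(h) =
\begin{cases}
m_0^{\,h+1} & h < n \\
m_1^{\,2(h-n)+1} & h \ge n
\end{cases}
$$

With `max_bias = 0` both bases are 1 and the bias is effectively disabled. For work sizing, the single-query (`n_q == 1`) decode path uses a fixed work-group of 64 threads per (head, batch) pair, while the multi-token path assigns one work-group of `BLOCK_M` threads to each `BLOCK_M`-row tile of queries, hence the ceil-divide when computing the first global dimension.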
| 5454 | 5696 | static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | 
| 5455 | 5697 |     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; | 
| 5456 | 5698 |  | 
| @@ -7607,6 +7849,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor | 
| 7607 | 7849 |             } | 
| 7608 | 7850 |             func = ggml_cl_sum_rows; | 
| 7609 | 7851 |             break; | 
|  | 7852 | +        case GGML_OP_FLASH_ATTN_EXT: | 
|  | 7853 | +            if (!any_on_device) { | 
|  | 7854 | +                return false; | 
|  | 7855 | +            } | 
|  | 7856 | +            ggml_cl_flash_attn(backend, tensor->src[0], tensor->src[1], tensor); | 
|  | 7857 | +            return true; | 
| 7610 | 7858 |         default: | 
| 7611 | 7859 |             return false; | 
| 7612 | 7860 |     } | 
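Unlike the other cases in this switch, FLASH_ATTN_EXT calls its handler and returns immediately rather than setting `func`. A sketch of the dispatch convention assumed above: the shared path hands every op `(backend, src0, src1, dst)`, so an op with more than two inputs recovers the rest from `dst->src[]` itself, exactly as `ggml_cl_flash_attn` does with `v` and `mask`:

```cpp
#include "ggml.h"
#include "ggml-backend.h"

// Illustrative three-input op following the two-source dispatch signature;
// extra operands come from the destination node's source list.
static void example_three_input_op(ggml_backend_t backend,
                                   const ggml_tensor * src0,
                                   const ggml_tensor * src1,
                                   ggml_tensor * dst) {
    const ggml_tensor * v    = dst->src[2]; // third input (V tensor)
    const ggml_tensor * mask = dst->src[3]; // optional fourth input, may be NULL
    (void)backend; (void)src0; (void)src1; (void)v; (void)mask;
}
```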