@@ -7930,194 +7930,6 @@ void ggml_compute_forward_argsort(
     }
 }

-// ------------------------------------------------------------------------------
-// SparseK Attention (CPU, final optimized version)
-// ------------------------------------------------------------------------------
-//
-// Implements SparseK Attention as a GGML operator for the CPU backend.
-// Features:
-//   • Top-K filtering using nth_element (average-case O(N))
-//   • Optional local window (win_local)
-//   • Optional global stride (stride_glb)
-//   • Numerically stable softmax
-//   • Preallocated buffers for performance
-//
-// Author: Yael Shuker & Gitty Burstein
-// ------------------------------------------------------------------------------
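-//
-// When at least one mask is enabled, the candidate set for query position i is
-//   cand(i) = { j : |i - j| <= win_local } ∪ { j : j mod stride_glb == 0 } ∪ { i },
-// i.e. every query always attends at least to itself; with no masks enabled,
-// every position is a candidate.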
-
-#include <algorithm>
-#include <vector>
-#include <cmath>
-#include <limits>
-
-static void ggml_compute_forward_sparsek_attn_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    // Single-threaded baseline version
-    if (params->ith != 0) return;
-
-    const struct ggml_tensor * Q = dst->src[0];
-    const struct ggml_tensor * K = dst->src[1];
-    const struct ggml_tensor * V = dst->src[2];
-
-    GGML_ASSERT(Q && K && V);
-    GGML_ASSERT(Q->type   == GGML_TYPE_F32);
-    GGML_ASSERT(K->type   == GGML_TYPE_F32);
-    GGML_ASSERT(V->type   == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-    // Operator parameters
-    const int32_t k_top      = ggml_get_op_params_i32(dst, 0);
-    const int32_t win_local  = ggml_get_op_params_i32(dst, 1); // -1 ⇒ no local window
-    const int32_t stride_glb = ggml_get_op_params_i32(dst, 2); // ≤1 ⇒ no global stride
-
-    const bool use_local  = (win_local >= 0);
-    const bool use_stride = (stride_glb > 1);
-
-    // GGML tensor dimensions: ne[0]=D, ne[1]=T, ne[2]=H, ne[3]=B
-    const int64_t D = Q->ne[0];
-    const int64_t T = Q->ne[1];
-    const int64_t H = Q->ne[2];
-    const int64_t B = Q->ne[3];
-
-    // Dimension validation
-    GGML_ASSERT(K->ne[0] == D && V->ne[0] == D);
-    GGML_ASSERT(K->ne[1] == T && V->ne[1] == T);
-    GGML_ASSERT(K->ne[2] == H && V->ne[2] == H);
-    GGML_ASSERT(K->ne[3] == B && V->ne[3] == B);
-
-    // Parameter sanity checks
-    GGML_ASSERT(k_top >= 0 && k_top <= (int32_t) T);
-    GGML_ASSERT(win_local >= -1);
-    GGML_ASSERT(stride_glb >= 0);
-
-    const float scale = 1.0f / sqrtf((float) D);
-    const float NINF  = -std::numeric_limits<float>::infinity();
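-    // Each output row is y_i = Σ_j a_ij · v_j, where a_ij = softmax_j((q_i · k_j) * scale)
-    // is taken over the candidate positions that survive the top-k filter below.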
-
-    // Preallocated buffers to avoid heap churn
-    std::vector<float>   attn_row((size_t) T, NINF);
-    std::vector<int32_t> cand_idx; cand_idx.reserve((size_t) T);
-    std::vector<float>   scores;   scores.reserve((size_t) T);
-
-    for (int64_t b = 0; b < B; ++b) {
-        for (int64_t h = 0; h < H; ++h) {
-            for (int64_t iq = 0; iq < T; ++iq) {
-
-                // (0) Build candidate index list (always include self)
-                cand_idx.clear();
-                scores.clear();
-
-                if (!use_local && !use_stride) {
-                    // No sparsity: attend to all tokens
-                    for (int64_t j = 0; j < T; ++j)
-                        cand_idx.push_back((int32_t) j);
-                } else {
-                    // Apply local window and/or global stride
-                    for (int64_t j = 0; j < T; ++j) {
-                        const int64_t dist = iq >= j ? iq - j : j - iq;
-                        const bool pass_local  = use_local  && (dist <= (int64_t) win_local);
-                        const bool pass_stride = use_stride && (stride_glb > 0 && j % stride_glb == 0);
-                        if (pass_local || pass_stride || j == iq)
-                            cand_idx.push_back((int32_t) j);
-                    }
-                }
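-                // e.g. T=8, iq=3, win_local=1, stride_glb=4:
-                //   local → {2,3,4}, stride → {0,4}, self → {3}  ⇒  cand = {0,2,3,4}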
-
-                // Edge case: no candidates or k_top == 0 → output zeros
-                if (k_top == 0 || cand_idx.empty()) {
-                    float * y0 = (float *)((char *) dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]);
-                    std::fill(y0, y0 + D, 0.0f);
-                    continue;
-                }
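-                // (an all-zero row stands in for a softmax over empty support)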
-
-                // (1) Compute scaled dot-product Q·K only for candidates
-                std::fill(attn_row.begin(), attn_row.end(), NINF);
-                const float * qv = (const float *)((const char *) Q->data + b*Q->nb[3] + h*Q->nb[2] + iq*Q->nb[1]);
-
-                for (int32_t j : cand_idx) {
-                    const float * kv = (const float *)((const char *) K->data + b*K->nb[3] + h*K->nb[2] + (int64_t) j*K->nb[1]);
-                    float dot = 0.0f;
-                    for (int64_t d = 0; d < D; ++d)
-                        dot += qv[d] * kv[d];
-                    attn_row[j] = dot * scale;
-                }
-
-                // (2) Determine true Top-K threshold using nth_element
-                const int num_candidates = (int) cand_idx.size();
-                const int kk = std::min<int>(std::max<int>(1, k_top), num_candidates);
-
-                if (kk < num_candidates) {
-                    scores.resize((size_t) num_candidates);
-                    for (size_t i = 0; i < cand_idx.size(); ++i)
-                        scores[i] = attn_row[cand_idx[i]];
-
-                    std::nth_element(scores.begin(), scores.begin() + (kk - 1), scores.end(), std::greater<float>());
-                    const float thr = scores[kk - 1];
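-                    // Entries equal to thr survive the mask below, so ties at the
-                    // threshold can leave slightly more than kk candidates active.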
-
-                    // Mask all values below the threshold
-                    for (int32_t j : cand_idx)
-                        if (attn_row[j] < thr) attn_row[j] = NINF;
-                }
-
-                // (3) Numerically stable softmax
-                float maxv = NINF;
-                for (int32_t j : cand_idx)
-                    maxv = std::max(maxv, attn_row[j]);
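-                // Shifting by maxv keeps every expf() argument ≤ 0, so the
-                // exponentials cannot overflow (the usual max-shift trick).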
-
-                // Handle all-masked rows
-                if (!std::isfinite(maxv)) {
-                    float * y0 = (float *)((char *) dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]);
-                    std::fill(y0, y0 + D, 0.0f);
-                    continue;
-                }
-
-                float sum = 0.0f;
-                for (int32_t j : cand_idx) {
-                    if (attn_row[j] == NINF) continue;
-                    const float e = expf(attn_row[j] - maxv);
-                    attn_row[j] = e;
-                    sum += e;
-                }
-
-                const float inv_sum = (sum > 0.0f) ? (1.0f / sum) : 0.0f;
-                for (int32_t j : cand_idx) {
-                    if (attn_row[j] == NINF) continue;
-                    attn_row[j] *= inv_sum;
-                }
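-                // attn_row now holds a probability distribution over the surviving candidates.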
-
-                // (4) Compute output y = A·V
-                float * y = (float *)((char *) dst->data + b*dst->nb[3] + h*dst->nb[2] + iq*dst->nb[1]);
-                for (int64_t d = 0; d < D; ++d) {
-                    float acc = 0.0f;
-                    for (int32_t j : cand_idx) {
-                        const float aij = attn_row[j];
-                        if (!(aij > 0.0f)) continue; // skip zero or masked
-                        const float * vv = (const float *)((const char *) V->data + b*V->nb[3] + h*V->nb[2] + (int64_t) j*V->nb[1]);
-                        acc += aij * vv[d];
-                    }
-                    y[d] = acc;
-                }
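-                // Per query row this costs O(|cand| · D) for the QK pass above and
-                // for this AV pass; dense attention would pay O(T · D) per row.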
-            }
-        }
-    }
-
-    GGML_PRINT_DEBUG("[SPARSEK CPU] k_top=%d win_local=%d stride=%d\n",
-                     k_top, win_local, stride_glb);
-}
-
-void ggml_compute_forward_sparsek_attn(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    switch (dst->type) {
-        case GGML_TYPE_F32:
-            ggml_compute_forward_sparsek_attn_f32(params, dst);
-            break;
-        default:
-            GGML_ASSERT(false && "sparsek_attn: unsupported dst type");
-    }
-}
-
-
 // ggml_compute_forward_flash_attn_ext

 static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(