 #include <cmath>
 #include <cstring>
 
+#include <algorithm> // std::fill, std::partial_sort, std::max
+#include <vector>
+#include <utility>
+#include <cstdlib>   // getenv, atoi
+
+// forward declaration for debug printing of KQ masks
+static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv,
+                       int64_t n_swa, llama_swa_type swa_type);
+
 void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     if (ubatch->token) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -261,7 +270,8 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv,
+                       int64_t n_swa, llama_swa_type swa_type) {
     LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
     const char * swa_type_str = "unknown";
 
@@ -296,6 +306,111 @@ static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64
     }
 }
 
+// --- Implementations for llm_graph_input_sparsek_mask (declared in header) ---
+
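+// Configuration comes from environment variables; the values shown here are
+// illustrative only:
+//   LLAMA_SPARSEK        - enable the sparse mask when set
+//   LLAMA_SPARSEK_WIN    - radius of the always-allowed local window (e.g. 64)
+//   LLAMA_SPARSEK_STRIDE - allow every stride-th KV position globally (e.g. 128)
+//   LLAMA_SPARSEK_TOPK   - additionally allow the K positions closest to the
+//                          query position within a bounded search range (e.g. 32)
+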
+llm_graph_input_sparsek_mask::llm_graph_input_sparsek_mask(
+        const llama_hparams & hp,
+        const llama_cparams & cp,
+        const llama_ubatch  & ub,
+        const llama_kv_cache_context * mctx_)
+    : hparams(hp), cparams(cp), ubatch(ub), mctx(mctx_) {
+    enabled     = getenv("LLAMA_SPARSEK") != nullptr;
+    win_local   = std::max(0, getenv("LLAMA_SPARSEK_WIN")    ? atoi(getenv("LLAMA_SPARSEK_WIN"))    : 0);
+    stride_glob = std::max(0, getenv("LLAMA_SPARSEK_STRIDE") ? atoi(getenv("LLAMA_SPARSEK_STRIDE")) : 0);
+    topk_static = std::max(0, getenv("LLAMA_SPARSEK_TOPK")   ? atoi(getenv("LLAMA_SPARSEK_TOPK"))   : 0);
+
+    env_enable_snap = enabled ? 1 : 0;
+    env_win_snap    = win_local;
+    env_stride_snap = stride_glob;
+    env_topk_snap   = topk_static;
+}
+
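+// Fills the allow mask on the host: every entry starts at -INFINITY and is set
+// to 0.0f when the KV position falls inside the local window, on the global
+// stride, or among the top-K positions closest to the query position.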
+void llm_graph_input_sparsek_mask::set_input(const llama_ubatch *) {
+    if (!enabled || !allow) return;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(allow->buffer));
+    float * data = (float *) allow->data;
+
+    const int64_t n_stream = allow->ne[3];
+    const int64_t n_rows   = allow->ne[1];
+    const int64_t n_kv     = allow->ne[0];
+
+    std::fill(data, data + ggml_nelements(allow), -INFINITY);
+    GGML_ASSERT(ubatch.pos);
+    GGML_ASSERT(ubatch.n_tokens % n_stream == 0);
+
+    const int64_t n_tps = ubatch.n_tokens / n_stream;
+    for (int64_t s = 0; s < n_stream; ++s) {
+        for (int64_t ii = 0; ii < n_tps; ++ii) {
+            const int64_t i   = s*n_tps + ii;
+            const int64_t row = ii;
+            const int64_t p1  = ubatch.pos[i];
+            float * row_ptr = data + (s*n_rows + row)*n_kv;
+
+            // local window around the query position
+            if (win_local > 0) {
+                const int64_t lo = std::max<int64_t>(0,        p1 - win_local);
+                const int64_t hi = std::min<int64_t>(n_kv - 1, p1 + win_local);
+                for (int64_t j = lo; j <= hi; ++j) row_ptr[j] = 0.0f;
+            }
+
+            // strided global positions
+            if (stride_glob > 0) {
+                for (int64_t j = 0; j < n_kv; j += stride_glob) row_ptr[j] = 0.0f;
+            }
+
+            // top-K positions closest to p1 within a bounded search range
+            if (topk_static > 0) {
+                const int64_t R   = std::min<int64_t>(n_kv - 1, win_local > 0 ? win_local*4 : 1024);
+                const int64_t lo2 = std::max<int64_t>(0,        p1 - R);
+                const int64_t hi2 = std::min<int64_t>(n_kv - 1, p1 + R);
+
+                std::vector<std::pair<int64_t, int64_t>> cand;
+                cand.reserve(hi2 - lo2 + 1);
+                for (int64_t j = lo2; j <= hi2; ++j) {
+                    cand.emplace_back(j, std::llabs((long long) p1 - (long long) j));
+                }
+
+                const size_t K = std::min<size_t>(topk_static, cand.size());
+                std::partial_sort(cand.begin(), cand.begin() + K, cand.end(),
+                        [](const auto & a, const auto & b) { return a.second < b.second; });
+                for (size_t k = 0; k < K; ++k) row_ptr[cand[k].first] = 0.0f;
+            }
+        }
+    }
+
+    // snapshot the mask shape and env configuration for can_reuse()
+    last_ne0 = allow->ne[0]; // n_kv
+    last_ne1 = allow->ne[1]; // n_rows
+    last_ne3 = allow->ne[3]; // n_stream
+
+    env_enable_snap = enabled ? 1 : 0;
+    env_win_snap    = win_local;
+    env_stride_snap = stride_glob;
+    env_topk_snap   = topk_static;
+}
+
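+// A cached graph can be reused only if neither the SPARSEK environment
+// configuration nor the allow-mask shape changed since this input was built.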
+bool llm_graph_input_sparsek_mask::can_reuse(const llm_graph_params & params) {
+    GGML_UNUSED(params);
+
+    if (!allow) return false;
+
+    const int cur_enable = getenv("LLAMA_SPARSEK")        ? 1 : 0;
+    const int cur_win    = getenv("LLAMA_SPARSEK_WIN")    ? atoi(getenv("LLAMA_SPARSEK_WIN"))    : 0;
+    const int cur_stride = getenv("LLAMA_SPARSEK_STRIDE") ? atoi(getenv("LLAMA_SPARSEK_STRIDE")) : 0;
+    const int cur_topk   = getenv("LLAMA_SPARSEK_TOPK")   ? atoi(getenv("LLAMA_SPARSEK_TOPK"))   : 0;
+
+    if (cur_enable != env_enable_snap ||
+        cur_win    != env_win_snap    ||
+        cur_stride != env_stride_snap ||
+        cur_topk   != env_topk_snap) {
+        return false;
+    }
+
+    if (allow->ne[0] != last_ne0 ||
+        allow->ne[1] != last_ne1 ||
+        allow->ne[3] != last_ne3) {
+        return false;
+    }
+
+    return true;
+}
+
 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv     = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;
@@ -600,6 +715,37 @@ void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
     }
 }
 
+// ===[ SPARSEK: graph-time mask fusion ]=======================================
+// If disabled, returns base_mask unchanged. Otherwise builds an "allow" mask
+// input node and returns base_mask + allow (0.0f / -INFINITY encoding), so
+// entries blocked (-INF) by either mask stay blocked and entries allowed
+// (0.0f) by both keep the base value.
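+//
+// Effect of the ggml_add merge on individual entries (illustrative):
+//   base = 0.0f (visible), allow = 0.0f -> 0.0f  (stays visible)
+//   base = -INF (masked),  allow = 0.0f -> -INF  (stays masked)
+//   base = 0.0f (visible), allow = -INF -> -INF  (pruned by SparseK)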
+ggml_tensor * llm_graph_context::maybe_apply_sparsek_mask(ggml_tensor * base_mask,
+                                                           int64_t n_kv,
+                                                           int64_t n_rows,
+                                                           int64_t n_stream) const {
+    const bool enabled = getenv("LLAMA_SPARSEK") != nullptr;
+    if (!enabled) return base_mask;
+
+    auto inp = std::make_unique<llm_graph_input_sparsek_mask>(hparams, cparams, ubatch,
+            static_cast<const llama_kv_cache_context *>(mctx));
+
+    auto & allow = inp->allow;
+    allow = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_rows, 1, n_stream);
+    ggml_set_input(allow);
+    res->add_input(std::move(inp));
+
+    ggml_build_forward_expand(gf, allow);
+
+    // align the allow mask with the base mask type if it is not F32
+    ggml_tensor * allow_aligned = allow;
+    if (base_mask->type != GGML_TYPE_F32) {
+        allow_aligned = ggml_cast(ctx0, allow, base_mask->type);
+    }
+
+    // merge by logical union of the blocked sets: allowed = 0.0f, blocked = -INF
+    ggml_tensor * merged = ggml_add(ctx0, base_mask, allow_aligned);
+    return merged;
+}
+
 ggml_tensor * llm_graph_context::build_cvec(
         ggml_tensor * cur,
         int il) const {
@@ -1513,7 +1659,14 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const bool is_swa = hparams.is_swa(il);
 
-    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
+    const auto & kq_mask_base = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
+
+    // no-cache: n_kv == n_tokens, stream = 1
+    const int64_t n_kv     = ubatch.n_tokens;
+    const int64_t n_stream = 1;
+    const int64_t n_rows   = GGML_PAD(ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+    ggml_tensor * kq_mask = maybe_apply_sparsek_mask(kq_mask_base, n_kv, n_rows, n_stream);
 
     // [TAG_NO_CACHE_PAD]
     // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
@@ -1590,14 +1743,13 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
+    const llama_kv_cache_context * mctx_cur = inp->mctx; // define once at top
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
     ggml_build_forward_expand(gf, q_cur);
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * mctx_cur = inp->mctx;
-
     // store to KV cache
     {
         const auto & k_idxs = inp->get_k_idxs();
@@ -1607,7 +1759,13 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
-    const auto & kq_mask = inp->get_kq_mask();
+    const auto & kq_mask_base = inp->get_kq_mask();
+
+    const int64_t n_kv     = mctx_cur->get_n_kv();
+    const int64_t n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+    const int64_t n_rows   = GGML_PAD(ubatch.n_tokens / n_stream, GGML_KQ_MASK_PAD);
+
+    ggml_tensor * kq_mask = maybe_apply_sparsek_mask(kq_mask_base, n_kv, n_rows, n_stream);
 
     ggml_tensor * q = q_cur;
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
@@ -1675,12 +1833,20 @@ ggml_tensor * llm_graph_context::build_attn(
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
+
+    // --- SparseK: graph-time mask fusion for KV_ISWA ---
+    const int64_t n_kv     = mctx_cur->get_n_kv();
+    const int64_t n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+    const int64_t n_rows   = GGML_PAD(ubatch.n_tokens / n_stream, GGML_KQ_MASK_PAD);
+
+    ggml_tensor * kq_mask_aug = maybe_apply_sparsek_mask((ggml_tensor *) kq_mask,
+                                                          n_kv, n_rows, n_stream);
 
     ggml_tensor * q = q_cur;
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask,     sinks, v_mla, kq_scale, il);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask_aug, sinks, v_mla, kq_scale, il);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1731,11 +1897,19 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask_cross();
 
+    // --- SparseK: graph-time mask fusion for Cross-Attention ---
+    const int64_t n_kv     = k_cur->ne[0]; // TODO: or cross->n_enc?
+    const int64_t n_stream = 1;
+    const int64_t n_rows   = GGML_PAD(ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+    ggml_tensor * kq_mask_aug = maybe_apply_sparsek_mask((ggml_tensor *) kq_mask,
+                                                          n_kv, n_rows, n_stream);
+
     ggml_tensor * q = q_cur;
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask,     sinks, v_mla, kq_scale, il);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask_aug, sinks, v_mla, kq_scale, il);
     cb(cur, "kqv_out", il);
 
     if (wo) {