Commit df59fa2

Gitty Burstein (GittyBurstein) and yael-works committed
SparseK: static mask integration in graph build (non-dynamic proof-of-concept)
Co-authored-by: Gitty Burstein <[email protected]>
Co-authored-by: Yael Shuker <[email protected]>
1 parent a2f79cc commit df59fa2

File tree: 2 files changed (+238 −7 lines)


src/llama-graph.cpp

Lines changed: 181 additions & 7 deletions
@@ -13,6 +13,15 @@
 #include <cmath>
 #include <cstring>
 
+#include <algorithm> // std::fill, std::partial_sort, std::max
+#include <vector>
+#include <utility>
+#include <cstdlib>   // getenv, atoi
+
+// forward declaration for debug printing of KQ masks
+static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv,
+                       int64_t n_swa, llama_swa_type swa_type);
+
 void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     if (ubatch->token) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -261,7 +270,8 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
+static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv,
+                       int64_t n_swa, llama_swa_type swa_type) {
     LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
     const char * swa_type_str = "unknown";
 
@@ -296,6 +306,111 @@ static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64
     }
 }
 
+// --- Implementations for llm_graph_input_sparsek_mask (declared in header) ---
+
+llm_graph_input_sparsek_mask::llm_graph_input_sparsek_mask(
+        const llama_hparams & hp,
+        const llama_cparams & cp,
+        const llama_ubatch & ub,
+        const llama_kv_cache_context * mctx_)
+    : hparams(hp), cparams(cp), ubatch(ub), mctx(mctx_) {
+    enabled     = getenv("LLAMA_SPARSEK") != nullptr;
+    win_local   = std::max(0, getenv("LLAMA_SPARSEK_WIN")    ? atoi(getenv("LLAMA_SPARSEK_WIN"))    : 0);
+    stride_glob = std::max(0, getenv("LLAMA_SPARSEK_STRIDE") ? atoi(getenv("LLAMA_SPARSEK_STRIDE")) : 0);
+    topk_static = std::max(0, getenv("LLAMA_SPARSEK_TOPK")   ? atoi(getenv("LLAMA_SPARSEK_TOPK"))   : 0);
+
+    env_enable_snap = enabled ? 1 : 0;
+    env_win_snap    = win_local;
+    env_stride_snap = stride_glob;
+    env_topk_snap   = topk_static;
+}
+
+void llm_graph_input_sparsek_mask::set_input(const llama_ubatch * ) {
+    if (!enabled || !allow) return;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(allow->buffer));
+    float * data = (float *) allow->data;
+
+    const int64_t n_stream = allow->ne[3];
+    const int64_t n_rows   = allow->ne[1];
+    const int64_t n_kv     = allow->ne[0];
+
+    std::fill(data, data + ggml_nelements(allow), -INFINITY);
+    GGML_ASSERT(ubatch.pos);
+    GGML_ASSERT(ubatch.n_tokens % n_stream == 0);
+
+    const int64_t n_tps = ubatch.n_tokens / n_stream;
+    for (int64_t s = 0; s < n_stream; ++s) {
+        for (int64_t ii = 0; ii < n_tps; ++ii) {
+            const int64_t i   = s*n_tps + ii;
+            const int64_t row = ii;
+            const int64_t p1  = ubatch.pos[i];
+            float * row_ptr = data + (s*n_rows + row)*n_kv;
+
+            if (win_local > 0) {
+                const int64_t lo = std::max<int64_t>(0, p1 - win_local);
+                const int64_t hi = std::min<int64_t>(n_kv - 1, p1 + win_local);
+                for (int64_t j = lo; j <= hi; ++j) row_ptr[j] = 0.0f;
+            }
+
+            if (stride_glob > 0) {
+                for (int64_t j = 0; j < n_kv; j += stride_glob) row_ptr[j] = 0.0f;
+            }
+
+            if (topk_static > 0) {
+                const int64_t R   = std::min<int64_t>(n_kv - 1, win_local > 0 ? win_local*4 : 1024);
+                const int64_t lo2 = std::max<int64_t>(0, p1 - R);
+                const int64_t hi2 = std::min<int64_t>(n_kv - 1, p1 + R);
+
+                std::vector<std::pair<int64_t,int64_t>> cand;
+                cand.reserve(hi2 - lo2 + 1);
+                for (int64_t j = lo2; j <= hi2; ++j)
+                    cand.emplace_back(j, std::llabs((long long)p1 - (long long)j));
+
+                const size_t K = std::min<size_t>(topk_static, cand.size());
+                std::partial_sort(cand.begin(), cand.begin() + K, cand.end(),
+                                  [](auto &a, auto &b){ return a.second < b.second; });
+                for (size_t k = 0; k < K; ++k) row_ptr[cand[k].first] = 0.0f;
+                if (allow) {
+                    last_ne0 = allow->ne[0]; // n_kv
+                    last_ne1 = allow->ne[1]; // n_rows
+                    last_ne3 = allow->ne[3]; // n_stream
+                }
+                env_enable_snap = enabled ? 1 : 0;
+                env_win_snap    = win_local;
+                env_stride_snap = stride_glob;
+                env_topk_snap   = topk_static;
+            }
+        }
+    }
+}
+
+bool llm_graph_input_sparsek_mask::can_reuse(const llm_graph_params & params) {
+    GGML_UNUSED(params);
+
+    if (!allow) return false;
+
+    int cur_enable = getenv("LLAMA_SPARSEK")        ? 1 : 0;
+    int cur_win    = getenv("LLAMA_SPARSEK_WIN")    ? atoi(getenv("LLAMA_SPARSEK_WIN"))    : 0;
+    int cur_stride = getenv("LLAMA_SPARSEK_STRIDE") ? atoi(getenv("LLAMA_SPARSEK_STRIDE")) : 0;
+    int cur_topk   = getenv("LLAMA_SPARSEK_TOPK")   ? atoi(getenv("LLAMA_SPARSEK_TOPK"))   : 0;
+
+    if (cur_enable != env_enable_snap ||
+        cur_win    != env_win_snap    ||
+        cur_stride != env_stride_snap ||
+        cur_topk   != env_topk_snap) {
+        return false;
+    }
+
+    if (allow->ne[0] != last_ne0 ||
+        allow->ne[1] != last_ne1 ||
+        allow->ne[3] != last_ne3) {
+        return false;
+    }
+
+    return true;
+}
+
 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_kv     = ubatch->n_tokens;
     const int64_t n_tokens = ubatch->n_tokens;
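For readers who want to poke at the policy outside a full graph build, here is a minimal standalone sketch of the same per-row logic that set_input() writes into the allow tensor above: a local window around the token position, a global stride pattern, and a distance-based top-K fallback, with 0.0f meaning allowed and -INFINITY meaning blocked. It is not part of this commit; the helper name sparsek_fill_row is made up for the example.

```cpp
// Standalone sketch of the static SparseK row policy (mirrors set_input() above).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <utility>
#include <vector>

static void sparsek_fill_row(float * row, int64_t n_kv, int64_t pos,
                             int win_local, int stride_glob, int topk_static) {
    std::fill(row, row + n_kv, -INFINITY);              // start fully blocked

    if (win_local > 0) {                                // local window: |j - pos| <= win_local
        const int64_t lo = std::max<int64_t>(0, pos - win_local);
        const int64_t hi = std::min<int64_t>(n_kv - 1, pos + win_local);
        for (int64_t j = lo; j <= hi; ++j) row[j] = 0.0f;
    }
    if (stride_glob > 0) {                              // global strided positions
        for (int64_t j = 0; j < n_kv; j += stride_glob) row[j] = 0.0f;
    }
    if (topk_static > 0) {                              // K nearest positions within a search radius
        const int64_t R  = std::min<int64_t>(n_kv - 1, win_local > 0 ? win_local*4 : 1024);
        const int64_t lo = std::max<int64_t>(0, pos - R);
        const int64_t hi = std::min<int64_t>(n_kv - 1, pos + R);
        std::vector<std::pair<int64_t,int64_t>> cand;   // (index, distance to pos)
        cand.reserve(hi - lo + 1);
        for (int64_t j = lo; j <= hi; ++j)
            cand.emplace_back(j, std::llabs((long long)(pos - j)));
        const size_t K = std::min<size_t>(topk_static, cand.size());
        std::partial_sort(cand.begin(), cand.begin() + K, cand.end(),
                          [](const auto & a, const auto & b) { return a.second < b.second; });
        for (size_t k = 0; k < K; ++k) row[cand[k].first] = 0.0f;
    }
}
```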
@@ -600,6 +715,37 @@ void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
     }
 }
 
+// ===[ SPARSEK: graph-time mask fusion ]=======================================
+// If disabled, returns base_mask. Otherwise builds an "allow" mask input node
+// and returns base_mask + allow (logical union with 0.0f / -INFINITY encoding)
+// so that blocked (-INF) entries remain blocked and allowed (0.0f) keep base.
+ggml_tensor * llm_graph_context::maybe_apply_sparsek_mask(ggml_tensor * base_mask,
+                                                          int64_t n_kv,
+                                                          int64_t n_rows,
+                                                          int64_t n_stream) const {
+    const bool enabled = getenv("LLAMA_SPARSEK") != nullptr;
+    if (!enabled) return base_mask;
+
+    auto inp = std::make_unique<llm_graph_input_sparsek_mask>(hparams, cparams, ubatch,
+                                                              static_cast<const llama_kv_cache_context *>(mctx));
+
+    auto & allow = inp->allow;
+    allow = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_rows, 1, n_stream);
+    ggml_set_input(allow);
+    res->add_input(std::move(inp));
+
+    ggml_build_forward_expand(gf, allow);
+
+    ggml_tensor * allow_aligned = allow;
+    if (base_mask->type != GGML_TYPE_F32) {
+        allow_aligned = ggml_cast(ctx0, allow, base_mask->type);
+    }
+
+    // Merge by logical union: allowed=0.0f, blocked=-INF
+    ggml_tensor * merged = ggml_add(ctx0, base_mask, allow_aligned);
+    return merged;
+}
+
 ggml_tensor * llm_graph_context::build_cvec(
          ggml_tensor * cur,
          int il) const {
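The merge at the end of maybe_apply_sparsek_mask() is a plain element-wise add, which works because of the 0.0f / -INFINITY encoding: a KV position stays attendable only when both the base KQ mask and the SparseK allow mask are 0.0f, and anything blocked by either mask ends up at -INF and is removed by the softmax. A toy float example of the same arithmetic (not ggml code, purely illustrative):

```cpp
// Illustration of the additive mask merge with the 0.0f / -INFINITY encoding.
#include <cmath>
#include <cstdio>

int main() {
    const float base[4]  = {0.0f, -INFINITY, 0.0f,      -INFINITY}; // e.g. causal KQ mask
    const float allow[4] = {0.0f, 0.0f,      -INFINITY, -INFINITY}; // SparseK allow-mask
    for (int j = 0; j < 4; ++j) {
        const float merged = base[j] + allow[j];  // same op as ggml_add on the mask tensors
        std::printf("j=%d base=%6.1f allow=%6.1f merged=%6.1f\n", j, base[j], allow[j], merged);
    }
    // Only j=0 stays at 0.0 (attended); every position blocked by either mask stays -inf.
}
```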
@@ -1513,7 +1659,14 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const bool is_swa = hparams.is_swa(il);
 
-    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
+    const auto & kq_mask_base = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
+
+    // no-cache: n_kv == n_tokens, stream = 1
+    const int64_t n_kv     = ubatch.n_tokens;
+    const int64_t n_stream = 1;
+    const int64_t n_rows   = GGML_PAD(ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+    ggml_tensor * kq_mask = maybe_apply_sparsek_mask(kq_mask_base, n_kv, n_rows, n_stream);
 
     // [TAG_NO_CACHE_PAD]
     // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
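The mask height n_rows in these call sites is padded with GGML_PAD, which rounds its first argument up to the next multiple of the second, so the allow tensor has the same padded row count as the base KQ mask. A toy calculation of that arithmetic (the pad value 64 and the token count are made-up example numbers, not taken from the commit):

```cpp
// Toy illustration of the row padding used for n_rows above.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t n_tokens = 100;  // assumed example ubatch size
    const int64_t pad      = 64;   // assumed pad value for illustration
    const int64_t n_rows   = (n_tokens + pad - 1) / pad * pad;  // round up to a multiple of pad
    std::printf("n_rows = %lld\n", (long long) n_rows);         // prints 128
}
```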
@@ -1590,14 +1743,13 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
+    const llama_kv_cache_context * mctx_cur = inp->mctx; // define once at top
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
     ggml_build_forward_expand(gf, q_cur);
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * mctx_cur = inp->mctx;
-
     // store to KV cache
     {
         const auto & k_idxs = inp->get_k_idxs();
@@ -1607,7 +1759,13 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
-    const auto & kq_mask = inp->get_kq_mask();
+    const auto & kq_mask_base = inp->get_kq_mask();
+
+    const int64_t n_kv     = mctx_cur->get_n_kv();
+    const int64_t n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+    const int64_t n_rows   = GGML_PAD(ubatch.n_tokens / n_stream, GGML_KQ_MASK_PAD);
+
+    ggml_tensor * kq_mask = maybe_apply_sparsek_mask(kq_mask_base, n_kv, n_rows, n_stream);
 
     ggml_tensor * q = q_cur;
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
@@ -1675,12 +1833,20 @@ ggml_tensor * llm_graph_context::build_attn(
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
+
+    // --- SparseK: graph-time mask fusion for KV_ISWA ---
+    const int64_t n_kv     = mctx_cur->get_n_kv();
+    const int64_t n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+    const int64_t n_rows   = GGML_PAD(ubatch.n_tokens / n_stream, GGML_KQ_MASK_PAD);
+
+    ggml_tensor * kq_mask_aug = maybe_apply_sparsek_mask((ggml_tensor *)kq_mask,
+                                                         n_kv, n_rows, n_stream);
 
     ggml_tensor * q = q_cur;
     ggml_tensor * k = mctx_cur->get_k(ctx0, il);
     ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask_aug, sinks, v_mla, kq_scale, il);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1731,11 +1897,19 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask_cross();
 
+    // --- SparseK: graph-time mask fusion for Cross-Attention ---
+    const int64_t n_kv     = k_cur->ne[0]; // or cross->n_enc,
+    const int64_t n_stream = 1;
+    const int64_t n_rows   = GGML_PAD(ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+    ggml_tensor * kq_mask_aug = maybe_apply_sparsek_mask((ggml_tensor *)kq_mask,
+                                                         n_kv, n_rows, n_stream);
+
     ggml_tensor * q = q_cur;
     ggml_tensor * k = k_cur;
     ggml_tensor * v = v_cur;
 
-    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask_aug, sinks, v_mla, kq_scale, il);
     cb(cur, "kqv_out", il);
 
     if (wo) {

src/llama-graph.h

Lines changed: 57 additions & 0 deletions
@@ -247,6 +247,52 @@ class llm_graph_input_cross_embd : public llm_graph_input_i {
     const llama_cross * cross;
 };
 
+// ===[ SPARSEK INPUT NODE - DECLARATION ]======================================
+// Provides an "allow-mask" tensor that encodes the SparseK policy
+// (0.0f = allowed, -INFINITY = blocked). The shape must match the KQ mask:
+// [ne0 = n_kv, ne1 = pad(n_tokens_per_stream), ne2 = 1, ne3 = n_stream].
+class llm_graph_input_sparsek_mask : public llm_graph_input_i {
+public:
+    llm_graph_input_sparsek_mask(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            const llama_ubatch & ubatch,
+            const llama_kv_cache_context * mctx);
+
+    ~llm_graph_input_sparsek_mask() override = default;
+
+    // Populates the "allow" tensor from ubatch positions based on ENV-driven SparseK policy.
+    // Note: definition is in the .cpp (set_input allocates/fills host-side values).
+    void set_input(const llama_ubatch * ubatch) override;
+
+    // SparseK mask can be reused while the shape/involved streams are unchanged.
+    bool can_reuse(const llm_graph_params & params) override;
+
+    // F32 [n_kv, pad(n_tokens_per_stream), 1, n_stream]
+    ggml_tensor * allow = nullptr;
+
+    // References used to compute the mask
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+    const llama_ubatch  & ubatch;
+    const llama_kv_cache_context * mctx;
+
+    // ENV-driven controls (read in the .cpp)
+    int win_local   = 0;
+    int stride_glob = 0;
+    int topk_static = 0;
+    bool enabled    = false;
+
+    int64_t last_ne0 = -1; // n_kv
+    int64_t last_ne1 = -1; // n_rows (pad(n_tokens_per_stream))
+    int64_t last_ne3 = -1; // n_stream
+
+    int env_enable_snap = 0;
+    int env_win_snap    = 0;
+    int env_stride_snap = 0;
+    int env_topk_snap   = 0;
+};
+
 class llm_graph_input_attn_no_cache : public llm_graph_input_i {
 public:
     llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
@@ -595,6 +641,17 @@ struct llm_graph_context {
     // common
     //
 
+    // Merges a graph-built SparseK "allow" mask into the base KQ mask.
+    // If SparseK is disabled (by ENV), this returns base_mask as-is.
+    // Shapes:
+    //   base_mask : [n_kv, n_rows, 1, n_stream] or [n_tokens, n_rows, 1, n_stream] (no-cache)
+    //   allow     : [n_kv, n_rows, 1, n_stream]
+    // Returned tensor has the same shape as base_mask.
+    ggml_tensor * maybe_apply_sparsek_mask(ggml_tensor * base_mask,
+                                           int64_t n_kv,
+                                           int64_t n_rows,
+                                           int64_t n_stream) const;
+
     ggml_tensor * build_cvec(
             ggml_tensor * cur,
             int il) const;
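Since the whole proof-of-concept is gated on environment variables, read both in the llm_graph_input_sparsek_mask constructor and in maybe_apply_sparsek_mask(), the quickest way to experiment with it is to set those variables before the graph is built. A rough sketch, assuming a POSIX setenv(); the variable names are the ones this commit reads, while the numeric values are arbitrary examples:

```cpp
// Illustration only: enable the SparseK proof-of-concept via its environment knobs.
#include <cstdlib>

static void enable_sparsek_poc() {
    setenv("LLAMA_SPARSEK",        "1",   1); // presence of the variable enables mask fusion
    setenv("LLAMA_SPARSEK_WIN",    "64",  1); // local window half-width (0 disables this rule)
    setenv("LLAMA_SPARSEK_STRIDE", "128", 1); // keep every 128th KV position globally
    setenv("LLAMA_SPARSEK_TOPK",   "32",  1); // additionally keep the 32 nearest positions
}
```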
