Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
efd9ad4
chore: ignore local backup files
Nov 11, 2025
8db1307
feat(SparseK): integrate dynamic mask build into llama-graph
Nov 11, 2025
68ab48c
remove accidental .gitignore
Nov 11, 2025
ce761f8
Without unnecessary spaces
GittyBurstein Nov 11, 2025
9d07172
restore .gitignore from upstream/master
Nov 11, 2025
af711f8
SparseK: apply review feedback (use ggml_scale_bias, single flash_att…
Nov 11, 2025
3933069
SparseK: apply review feedback (use ggml_scale_bias, single flash_att…
Nov 11, 2025
0c2dd04
fix(SparseK): use ggml_scale_bias directly on scores
GittyBurstein Nov 11, 2025
c6a5db4
restore SparseK kv-cache implementation (recovered from local file)
yael-works Nov 12, 2025
a6784f0
SparseK: update graph build — replace src/llama-graph.{h,cpp}
Nov 12, 2025
f9bd873
sparsek: finalize mask reshape and validation fixes
yael-works Nov 12, 2025
de64151
sparsek: replace ggml_scale_bias with standard ops for portability
yael-works Nov 12, 2025
08e359d
sparsek: align base mask 4D shape and add topk==0 guard for robustness
yael-works Nov 12, 2025
49a8a81
SparseK: clean dynamic mask path, remove legacy reshapes, avoid kv-ca…
yael-works Nov 13, 2025
ea21d8f
SparseK: finalize graph pipeline cleanup, remove deprecated path and …
Nov 13, 2025
161e7cd
SparseK: integrate dynamic attention mask, GGUF metadata, and model l…
yael-works Nov 13, 2025
b9a960f
SparseK: less nodes in the graph
Nov 13, 2025
b7315fc
Restore head_count block and remove incorrect SparseK metadata (per C…
yael-works Nov 13, 2025
35180a1
SparseK: fix duplicate get_key<bool> instantiations
Nov 14, 2025
2fd25a8
SparseK: don't alter KQ mask when disabled
Nov 14, 2025
5c3c65c
SparseK: do not alter KV mask when disabled
Nov 14, 2025
5798c33
Add SparseK KQ mask unit test
yael-works Nov 16, 2025
48ccccd
Clean SparseK KQ mask test and fix warnings
yael-works Nov 16, 2025
a365437
Align SparseK KV mask env gating with unit test
yael-works Nov 16, 2025
db3e875
Sparse-K: integrate graph changes and HF->GGUF metadata fixes
Nov 16, 2025
194f6a3
Merge branch 'feature/sparsek-attn-sycl' of https://github.com/yael-w…
Nov 16, 2025
60c75e7
SparseK: fix meta-buffer expansion and resolve CI failure
Nov 16, 2025
88ac1d9
SparseK: silence unused parameters in unit tests for CI
Nov 16, 2025
e6b0b10
SparseK: update reference test for kq_mask
Nov 16, 2025
46e192f
SparseK: silence release warnings in unit test helpers
Nov 16, 2025
a9d2015
SparseK: fix release warnings in unit test (assert helpers + finite_c…
Nov 16, 2025
060ee50
tests: integrate SparseK KQ mask test
yael-works Nov 17, 2025
729973b
Merge branch 'master' into feature/sparsek-attn-sycl
yael-works Nov 17, 2025
205fded
Fix duplicate get_key<bool> instantiation
Nov 17, 2025
6e36508
Remove tests/test-sparsek_kq_mask.cpp to match remote branch (resolve…
Nov 17, 2025
ed9ed7e
SparseK: Fix KQ mask test shapes to match ggml_get_rows 3D semantics
Nov 18, 2025
212d47f
SparseK: cleanup meta context and rely on graph_max_nodes headroom
Nov 18, 2025
087ecf3
SparseK: fix test-backend-ops overrides + update mask graph implement…
Nov 18, 2025
5c2849d
Remove test-backend-ops.cpp from PR
Nov 20, 2025
3687665
SparseK: fix graph node budget and stable mask construction
Nov 20, 2025
4045566
Fix flake8 E302 in convert_hf_to_gguf
Nov 20, 2025
f7b79ce
fix errors
Nov 20, 2025
d3b6c26
Fix unused variable 'picked' in SparseK mask builder
Nov 20, 2025
18adb6f
without spaces
Nov 20, 2025
57b907e
try to check the SPARSE
Nov 21, 2025
3f1005b
mark SparseK tests as NOT_SUPPORTED on Vulkan
Nov 21, 2025
04d6c83
Merge branch 'master' into feature/sparsek-attn-sycl
yael-works Nov 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 112 additions & 5 deletions src/llama-graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,13 @@ void llm_graph_result::reset() {

inputs.clear();

buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
// increase meta buffer slightly to accommodate extra nodes from SparseK
int64_t max_nodes_ex = max_nodes + 16384; // safety headroom

buf_compute_meta.resize(
ggml_tensor_overhead() * max_nodes_ex +
ggml_graph_overhead_custom(max_nodes_ex, /*grad*/ false)
);

ggml_init_params params = {
/*.mem_size =*/ buf_compute_meta.size(),
Expand All @@ -497,7 +503,9 @@ void llm_graph_result::reset() {

ctx_compute.reset(ggml_init(params));

gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
// build graph object with the expanded node cap as well
gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes_ex, /*grad*/ false);

}

void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
Expand Down Expand Up @@ -592,8 +600,25 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
ctx0 (res->get_ctx()),
gf (res->get_gf()) {
res->set_params(params);
// ===[ SPARSEK: one-time env init ]===========================================
// NOTE: read once per process; used as defaults for this context.
static bool SPARSEK_ENABLE_INIT = [](){ if (const char* s=getenv("LLAMA_SPARSEK_ENABLE")) return atoi(s)!=0; return false; }();
static int32_t SPARSEK_TOPK_INIT = [](){ if (const char* s=getenv("LLAMA_SPARSEK_TOPK")) return std::max(0, atoi(s)); return 0; }();
static int32_t SPARSEK_WIN_INIT = [](){ if (const char* s=getenv("LLAMA_SPARSEK_WIN")) return std::max(0, atoi(s)); return 0; }();
static int32_t SPARSEK_STRIDE_INIT = [](){ if (const char* s=getenv("LLAMA_SPARSEK_STRIDE")) return std::max(0, atoi(s)); return 0; }();
static bool SPARSEK_EN_LOCAL_INIT = [](){ if (const char* s=getenv("LLAMA_SPARSEK_ENABLE_LOCAL")) return atoi(s)!=0; return true; }();
static bool SPARSEK_EN_STRIDE_INIT = [](){ if (const char* s=getenv("LLAMA_SPARSEK_ENABLE_STRIDE")) return atoi(s)!=0; return true; }();

this->sparsek_enable = SPARSEK_ENABLE_INIT;
this->sparsek_topk = SPARSEK_TOPK_INIT;
this->sparsek_win_local = SPARSEK_WIN_INIT;
this->sparsek_stride = SPARSEK_STRIDE_INIT;
this->sparsek_en_local = SPARSEK_EN_LOCAL_INIT;
this->sparsek_en_stride = SPARSEK_EN_STRIDE_INIT;
// ============================================================================
}


void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
if (cb_func) {
cb_func(ubatch, cur, name, il);
Expand Down Expand Up @@ -842,6 +867,71 @@ ggml_tensor * llm_graph_context::build_ffn(
return cur;
}

// ===[ SPARSEK: dynamic mask builders ]=======================================
// ===[ SPARSEK: dynamic mask builders ]=======================================
// Build a dynamic Sparse-K additive mask (0 = allowed, -INF = masked) and
// union it with the pre-existing base mask.
//  q, k      : projected tensors for the current layer (flattened per-head below)
//  base_mask : existing KQ mask encoded as 0 / -INF; its layout drives the shapes
//  il        : layer index, used only for cb(...) tracing
// Returns base_mask unchanged when the feature is disabled.
ggml_tensor * llm_graph_context::build_sparsek_mask(
        ggml_tensor * q,
        ggml_tensor * k,
        ggml_tensor * base_mask,
        int il) const {
    // If features are disabled, return base mask as-is.
    if (!sparsek_enable || sparsek_topk <= 0) {
        cb(base_mask, "sparsek_passthrough_base", il);
        return base_mask;
    }

    // Base dims (follow base_mask layout)
    const int64_t n_kv     = base_mask->ne[0];
    const int64_t n_rows_p = base_mask->ne[1];

    // 1) Compute content-based scores ~ K * Q, reshape to [n_kv, n_rows_p]
    ggml_tensor * qt = ggml_reshape_2d(ctx0, q, q->ne[0], q->ne[1]); // flatten-per-head view
    ggml_tensor * kt = ggml_reshape_2d(ctx0, k, k->ne[0], k->ne[1]);
    ggml_tensor * scores = ggml_mul_mat(ctx0, kt, qt);
    scores = ggml_reshape_2d(ctx0, scores, n_kv, n_rows_p);
    cb(scores, "sparsek_scores", il);

    // 2) Top-K indices along dim-0 (per column)
    ggml_tensor * topk_idx = ggml_top_k(ctx0, scores, sparsek_topk); // [topk, n_rows_p]
    cb(topk_idx, "sparsek_topk_idx", il);

    // 3) Build the allow-mask: -INF everywhere, 0 on the top-K rows.
    //
    // IMPORTANT: both helper tensors are derived from the *finite* 'scores'
    // tensor using pure graph ops. Do NOT use ggml_set_f32() here — the
    // compute context is a no-alloc meta context, so tensor data is not
    // backed at graph-build time. Also do NOT scale a -INF tensor by 0 to
    // obtain zeros: 0 * INF is NaN under IEEE 754 and would poison the mask.
    ggml_tensor * neg2d   = ggml_scale_bias(ctx0, scores, 0.0f, -INFINITY); // 0*x + (-INF) = -INF (x finite)
    ggml_tensor * zeros2d = ggml_scale(ctx0, scores, 0.0f);                 // 0*x = 0       (x finite)

    ggml_tensor * zeros3d = ggml_reshape_3d(ctx0, zeros2d, n_kv, 1, n_rows_p);
    ggml_tensor * zrows   = ggml_get_rows(ctx0, zeros3d, topk_idx);         // zero rows at top-K positions
    ggml_tensor * merged  = ggml_set_rows(ctx0, neg2d, zrows, topk_idx);    // scatter zeros into -INF base
    ggml_tensor * allow   = ggml_reshape_4d(ctx0, merged, n_kv, n_rows_p, 1, 1); // [n_kv,n_rows_p,1,1]
    cb(allow, "sparsek_allow_topk_only", il);

    // 4) Final union with base (0/-INF encoding): adding -INF masks, adding 0 keeps.
    ggml_tensor * final_mask = ggml_add(ctx0, base_mask, allow);
    cb(final_mask, "sparsek_final_mask", il);
    return final_mask;
}

// Apply Sparse-K on top of an existing base mask when enabled.
// The extra geometry parameters are currently unused (base_mask carries its
// own layout) but kept for interface stability with the header declaration.
ggml_tensor * llm_graph_context::maybe_apply_sparsek_mask(
        ggml_tensor * base_mask,
        ggml_tensor * q,
        ggml_tensor * k,
        int64_t n_kv,
        int64_t n_rows,
        int64_t n_stream,
        int il) const {
    GGML_UNUSED(n_kv); GGML_UNUSED(n_rows); GGML_UNUSED(n_stream);

    // Fast path: when SparseK is disabled (or top-K is zero) keep the base
    // mask untouched and add no graph nodes. Note: the local-window/stride
    // toggles are not consulted here — build_sparsek_mask() does not
    // implement them yet, so gating on them would only add callback noise
    // while still returning the base mask unchanged.
    if (!sparsek_enable || sparsek_topk <= 0) {
        return base_mask;
    }

    // Build dynamic Sparse-K mask and union with base:
    return build_sparsek_mask(q, k, base_mask, il);
}
// ============================================================================

ggml_tensor * llm_graph_context::build_moe_ffn(
ggml_tensor * cur,
ggml_tensor * gate_inp,
Expand Down Expand Up @@ -1374,8 +1464,25 @@ ggml_tensor * llm_graph_context::build_attn_mha(
v = ggml_cast(ctx0, v, GGML_TYPE_F16);
}

cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
// SPARSEK: build final KQ mask once (union with base 0/-INF)
ggml_tensor * kq_mask_final = maybe_apply_sparsek_mask(
/*base_mask=*/kq_mask,
/*q=*/q,
/*k=*/k,
/*n_kv=*/kq_mask->ne[0],
/*n_rows=*/kq_mask->ne[1],
/*n_stream=*/kq_mask->ne[3],
/*il=*/il);

// Single flash-attn call using the final mask
cur = ggml_flash_attn_ext(
ctx0, q, k, v,
/*kq_mask=*/kq_mask_final,
kq_scale,
hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f
);

cb(cur, LLAMA_TENSOR_NAME_FATTN, il);

ggml_flash_attn_ext_add_sinks(cur, sinks);
Expand Down Expand Up @@ -1959,7 +2066,7 @@ void llm_graph_context::build_pooling(

GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");

ggml_tensor * cur;
ggml_tensor * cur = nullptr; // ensure initialized

switch (pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
Expand Down
31 changes: 31 additions & 0 deletions src/llama-graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,37 @@ struct llm_graph_context {
// common
//

// ===[ SPARSEK: config & builders ]===========================================
// Runtime toggles for the dynamic Sparse-K attention mask. They are filled in
// by the llm_graph_context constructor (from LLAMA_SPARSEK_* environment
// variables, read once per process) and fall back to the defaults below.
bool    sparsek_enable    = false; // master switch for dynamic Sparse-K
int32_t sparsek_topk      = 0;     // keys kept per query row; 0 disables the top-K path
int32_t sparsek_win_local = 0;     // local window radius (tokens to each side) — NOTE(review): not yet consulted by the mask builder
int32_t sparsek_stride    = 0;     // global stride period — NOTE(review): not yet consulted by the mask builder
bool    sparsek_en_local  = true;  // enable local-window component
bool    sparsek_en_stride = true;  // enable global-stride component

// Build a dynamic Sparse-K mask inside the compute graph.
// q, k      : projected tensors (per-head layout consistent with current layer)
// base_mask : the pre-existing KQ mask (causal/cross/SWA) encoded as 0 / -INF
// il        : layer index for cb(...) tracing
// Returns base_mask unchanged when sparsek_enable is false or topk <= 0.
ggml_tensor * build_sparsek_mask(
        ggml_tensor * q,
        ggml_tensor * k,
        ggml_tensor * base_mask,
        int il) const;

// Apply Sparse-K on top of an existing base mask when enabled; otherwise the
// base mask is returned as-is. n_kv / n_rows / n_stream describe the mask
// layout for validation/reshaping.
ggml_tensor * maybe_apply_sparsek_mask(
        ggml_tensor * base_mask,
        ggml_tensor * q,
        ggml_tensor * k,
        int64_t n_kv,
        int64_t n_rows,
        int64_t n_stream,
        int il) const;
// ============================================================================

ggml_tensor * build_cvec(
ggml_tensor * cur,
int il) const;
Expand Down
Loading