Commits (21)
efd9ad4  chore: ignore local backup files (Nov 11, 2025)
8db1307  feat(SparseK): integrate dynamic mask build into llama-graph (Nov 11, 2025)
68ab48c  remove accidental .gitignore (Nov 11, 2025)
ce761f8  Without unnecessary spaces (GittyBurstein, Nov 11, 2025)
9d07172  restore .gitignore from upstream/master (Nov 11, 2025)
af711f8  SparseK: apply review feedback (use ggml_scale_bias, single flash_att… (Nov 11, 2025)
3933069  SparseK: apply review feedback (use ggml_scale_bias, single flash_att… (Nov 11, 2025)
0c2dd04  fix(SparseK): use ggml_scale_bias directly on scores (GittyBurstein, Nov 11, 2025)
c6a5db4  restore SparseK kv-cache implementation (recovered from local file) (yael-works, Nov 12, 2025)
a6784f0  SparseK: update graph build — replace src/llama-graph.{h,cpp} (Nov 12, 2025)
f9bd873  sparsek: finalize mask reshape and validation fixes (yael-works, Nov 12, 2025)
de64151  sparsek: replace ggml_scale_bias with standard ops for portability (yael-works, Nov 12, 2025)
08e359d  sparsek: align base mask 4D shape and add topk==0 guard for robustness (yael-works, Nov 12, 2025)
49a8a81  SparseK: clean dynamic mask path, remove legacy reshapes, avoid kv-ca… (yael-works, Nov 13, 2025)
ea21d8f  SparseK: finalize graph pipeline cleanup, remove deprecated path and … (Nov 13, 2025)
161e7cd  SparseK: integrate dynamic attention mask, GGUF metadata, and model l… (yael-works, Nov 13, 2025)
b9a960f  SparseK: less nodes in the graph (Nov 13, 2025)
b7315fc  Restore head_count block and remove incorrect SparseK metadata (per C… (yael-works, Nov 13, 2025)
35180a1  SparseK: fix duplicate get_key<bool> instantiations (Nov 14, 2025)
2fd25a8  SparseK: don't alter KQ mask when disabled (Nov 14, 2025)
5c3c65c  SparseK: do not alter KV mask when disabled (Nov 14, 2025)
convert_hf_to_gguf.py (3 changes: 0 additions & 3 deletions)
@@ -794,15 +794,12 @@ def set_gguf_parameters(self):
if (n_ff := self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"], optional=True)) is not None:
self.gguf_writer.add_feed_forward_length(n_ff)
logger.info(f"gguf: feed forward length = {n_ff}")

if (n_head := self.find_hparam(["num_attention_heads", "n_head", "n_heads"], optional=True)) is not None:
self.gguf_writer.add_head_count(n_head)
logger.info(f"gguf: head count = {n_head}")

if (n_head_kv := self.find_hparam(["num_key_value_heads", "n_kv_heads"], optional=True)) is not None:
self.gguf_writer.add_head_count_kv(n_head_kv)
logger.info(f"gguf: key-value head count = {n_head_kv}")

if (rope_theta := self.hparams.get("rope_theta")) is not None:
self.gguf_writer.add_rope_freq_base(rope_theta)
logger.info(f"gguf: rope theta = {rope_theta}")
src/llama-graph.cpp (176 changes: 171 additions & 5 deletions)
@@ -487,7 +487,13 @@ void llm_graph_result::reset() {

inputs.clear();

buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
// reserve extra meta-buffer headroom for the additional SparseK nodes
int64_t max_nodes_ex = max_nodes + 16384; // safety headroom

buf_compute_meta.resize(
ggml_tensor_overhead() * max_nodes_ex +
ggml_graph_overhead_custom(max_nodes_ex, /*grad*/ false)
);

ggml_init_params params = {
/*.mem_size =*/ buf_compute_meta.size(),
@@ -497,7 +503,9 @@ void llm_graph_result::reset() {

ctx_compute.reset(ggml_init(params));

gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
// build graph object with the expanded node cap as well
gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes_ex, /*grad*/ false);

}

void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
@@ -592,8 +600,20 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
ctx0 (res->get_ctx()),
gf (res->get_gf()) {
res->set_params(params);
// === SparseK: load from model metadata (no env vars) =========================
this->sparsek_enable = hparams.sparsek_enable;
this->sparsek_topk = hparams.sparsek_topk;
this->sparsek_win_local = hparams.sparsek_window;
this->sparsek_stride = hparams.sparsek_stride;

// Default gating (until model metadata defines its own)
this->sparsek_en_local = true;
this->sparsek_en_stride = true;
// ============================================================================

}


void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
if (cb_func) {
cb_func(ubatch, cur, name, il);
@@ -842,6 +862,135 @@ ggml_tensor * llm_graph_context::build_ffn(
return cur;
}

// ===[ SPARSEK: dynamic mask builders ]=======================================
ggml_tensor * llm_graph_context::build_sparsek_mask(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * base_mask,
int il) const {

// If features are disabled, return base mask as-is.
if (!sparsek_enable || sparsek_topk <= 0) {
cb(base_mask, "sparsek_passthrough_base", il);
return base_mask;
}

// ---------------------------------------------------------------------
// 0) Derive layout from base_mask first (cheaper / more robust).
// base_mask is assumed to be [n_kv, n_rows, n_head, n_stream].
// ---------------------------------------------------------------------
const int64_t n_kv = base_mask->ne[0];
const int64_t n_rows = base_mask->ne[1];
const int64_t n_head = std::max<int64_t>(1, base_mask->ne[2]);
const int64_t n_stream= std::max<int64_t>(1, base_mask->ne[3]);
const int64_t hs = n_head * n_stream; // heads * streams

if (n_rows <= 0 || hs <= 0) {
cb(base_mask, "sparsek_invalid_base_layout_passthrough", il);
return base_mask;
}

// ---------------------------------------------------------------------
// 1) Compute content-based scores ~ K * Q on current 4D layout.
// Result is [n_kv, n_rows, n_head, n_stream] or compatible.
// ---------------------------------------------------------------------
ggml_tensor * scores4 = ggml_mul_mat(ctx0, k, q);
cb(scores4, "sparsek_scores4_raw", il);

// Make contiguous only if required by later reshape.
if (!ggml_is_contiguous(scores4)) {
scores4 = ggml_cont(ctx0, scores4);
}

// Flatten head/stream dimensions into column dimension.
// We want scores2d = [n_kv, n_rows * hs].
const int64_t cols_calc = n_rows * hs;
ggml_tensor * scores2d = ggml_reshape_2d(ctx0, scores4, n_kv, cols_calc);
cb(scores2d, "sparsek_scores2d", il);

// ---------------------------------------------------------------------
// 2) Top-K indices along dim-0 (per column).
// ---------------------------------------------------------------------
const int32_t topk_safe =
std::max<int32_t>(0, std::min<int32_t>(sparsek_topk, (int32_t) n_kv));
if (topk_safe == 0) {
cb(base_mask, "sparsek_topk_zero_passthrough", il);
return base_mask;
}

ggml_tensor * topk_idx = ggml_top_k(ctx0, scores2d, topk_safe); // [topk, cols_calc]
cb(topk_idx, "sparsek_topk_idx", il);

// ---------------------------------------------------------------------
// 3) Build SparseK mask:
// Start from all -INF [n_kv, cols_calc] then set selected rows to 0.
// We avoid using "scores2d" as input to scale_bias to reduce
// unnecessary dataflow dependencies.
// ---------------------------------------------------------------------
ggml_tensor * neg2d = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, cols_calc);
ggml_set_f32(neg2d, -INFINITY); // constant -INF

ggml_tensor * rows3d = ggml_reshape_3d(ctx0, neg2d, n_kv, 1, cols_calc); // [n_kv, 1, cols]
ggml_tensor * picked = ggml_get_rows(ctx0, rows3d, topk_idx); // [topk, 1, cols]
ggml_tensor * zeros = ggml_scale(ctx0, picked, 0.0f); // [topk, 1, cols] = 0
ggml_tensor * merged3d = ggml_set_rows(ctx0, rows3d, zeros, topk_idx); // [n_kv, 1, cols]

// ---------------------------------------------------------------------
// 4) Broadcast into [n_kv, n_rows, hs] and combine with base_mask.
// ---------------------------------------------------------------------
ggml_tensor * mask3 = ggml_reshape_3d(ctx0, merged3d, n_kv, n_rows, hs);
cb(mask3, "sparsek_allow_topk_only", il);

// base2d: [n_kv, n_rows]
ggml_tensor * base2d = ggml_reshape_2d(ctx0, base_mask, n_kv, n_rows);

// Safety check: rows must match.
if (base2d->ne[0] != n_kv || base2d->ne[1] != n_rows) {
cb(base_mask, "sparsek_kv_or_rows_mismatch_passthrough", il);
return base_mask;
}

// Broadcast base_mask into [n_kv, n_rows, hs].
ggml_tensor * base3 = ggml_reshape_3d(ctx0, base2d, n_kv, n_rows, 1);
ggml_tensor * base_rep = ggml_repeat(ctx0, base3, mask3); // [n_kv, n_rows, hs]

// Combine SparseK and base (0 / -INF encoding).
ggml_tensor * final3 = ggml_add(ctx0, mask3, base_rep); // [n_kv, n_rows, hs]

// ---------------------------------------------------------------------
// 5) Reshape back to original 4D layout.
// ---------------------------------------------------------------------
ggml_tensor * final_mask = ggml_reshape_4d(
ctx0,
final3,
base_mask->ne[0],
base_mask->ne[1],
base_mask->ne[2],
base_mask->ne[3]);

cb(final_mask, "sparsek_final_mask", il);
return final_mask;
}

ggml_tensor * llm_graph_context::maybe_apply_sparsek_mask(
ggml_tensor * base_mask,
ggml_tensor * q,
ggml_tensor * k,
int64_t n_kv,
int64_t n_rows,
int64_t n_stream,
int il) const {
GGML_UNUSED(n_kv);
GGML_UNUSED(n_rows);
GGML_UNUSED(n_stream);

// Delegate all gating (enable/topk/etc.) to build_sparsek_mask.
// If SparseK is disabled or misconfigured, it will simply return base_mask.
return build_sparsek_mask(q, k, base_mask, il);
}

// ============================================================================

ggml_tensor * llm_graph_context::build_moe_ffn(
ggml_tensor * cur,
ggml_tensor * gate_inp,
@@ -1374,8 +1523,25 @@ ggml_tensor * llm_graph_context::build_attn_mha(
v = ggml_cast(ctx0, v, GGML_TYPE_F16);
}

cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
// SparseK: build the final KQ mask once (additive combine with the base 0/-INF mask)
ggml_tensor * kq_mask_final = maybe_apply_sparsek_mask(
/*base_mask=*/kq_mask,
/*q=*/q,
/*k=*/k,
/*n_kv=*/kq_mask->ne[0],
/*n_rows=*/kq_mask->ne[1],
/*n_stream=*/kq_mask->ne[3],
/*il=*/il);

// Single flash-attn call using the final mask
cur = ggml_flash_attn_ext(
ctx0, q, k, v,
/*kq_mask=*/kq_mask_final,
kq_scale,
hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f
);

cb(cur, LLAMA_TENSOR_NAME_FATTN, il);

ggml_flash_attn_ext_add_sinks(cur, sinks);
@@ -1959,7 +2125,7 @@ void llm_graph_context::build_pooling(

GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");

ggml_tensor * cur;
ggml_tensor * cur = nullptr; // ensure initialized

switch (pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
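Note on the mask algebra in build_sparsek_mask: the ggml pipeline above (mul_mat → top_k → get_rows/set_rows → add) is easier to verify outside the graph. Below is a minimal NumPy sketch of the same semantics, collapsed to a single head/stream; the per-row top-k here stands in for the per-column top-k on the flattened 2D layout, and it is an illustration of the intent, not the ggml implementation.

# Minimal NumPy sketch of the SparseK mask semantics, not the ggml code:
# one head/stream, scores as [n_rows, n_kv], top-k taken per query row.
import numpy as np

def sparsek_mask(q, k, base_mask, topk):
    # q: [n_rows, d], k: [n_kv, d], base_mask: [n_rows, n_kv] with 0 / -inf
    if topk <= 0:
        return base_mask                           # passthrough, as in the guard above
    topk = min(topk, k.shape[0])

    scores = q @ k.T                               # content-based scores ~ K*Q
    idx = np.argsort(-scores, axis=-1)[:, :topk]   # top-k key indices per row

    mask = np.full_like(base_mask, -np.inf)        # start fully masked
    np.put_along_axis(mask, idx, 0.0, axis=-1)     # open only the selected keys

    # additive 0/-inf combine: a key stays visible only if BOTH the SparseK
    # selection and the base (causal/SWA) mask allow it
    return mask + base_mask

n_rows, n_kv, d = 4, 8, 16
rng = np.random.default_rng(0)
q = rng.standard_normal((n_rows, d))
k = rng.standard_normal((n_kv, d))
base = np.triu(np.full((n_rows, n_kv), -np.inf), n_kv - n_rows + 1)  # causal
print(sparsek_mask(q, k, base, topk=3))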
src/llama-graph.h (31 changes: 31 additions & 0 deletions)
@@ -595,6 +595,37 @@ struct llm_graph_context {
// common
//

// ===[ SPARSEK: config & builders ]===========================================
// Runtime config toggles (filled in the .cpp constructor from model metadata, with defaults below)
bool sparsek_enable = false; // enable/disable dynamic Sparse-K
int32_t sparsek_topk = 0; // top-K per row (0 -> disabled unless window/stride applies)
int32_t sparsek_win_local = 0; // local window radius (tokens to each side)
int32_t sparsek_stride = 0; // global stride period
bool sparsek_en_local = true; // enable local window
bool sparsek_en_stride = true; // enable global stride

// Build a dynamic Sparse-K mask inside the compute graph.
// q, k: projected tensors (per-head layout consistent with current layer)
// base_mask: the pre-existing KQ mask (causal/cross/SWA) encoded as 0 / -INF
// il: layer index for cb(...) tracing
ggml_tensor * build_sparsek_mask(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * base_mask,
int il) const;

// Apply Sparse-K on top of an existing base mask when enabled.
// n_kv / n_rows / n_stream are currently unused; layout is derived from base_mask inside build_sparsek_mask.
ggml_tensor * maybe_apply_sparsek_mask(
ggml_tensor * base_mask,
ggml_tensor * q,
ggml_tensor * k,
int64_t n_kv,
int64_t n_rows,
int64_t n_stream,
int il) const;
// ============================================================================

ggml_tensor * build_cvec(
ggml_tensor * cur,
int il) const;
src/llama-hparams.h (6 changes: 6 additions & 0 deletions)
@@ -33,6 +33,12 @@ struct llama_hparams_convnext {
};

struct llama_hparams {
// === SparseK Dynamic Attention ===
bool sparsek_enable = false;
int32_t sparsek_topk = 0;
int32_t sparsek_window = 0;
int32_t sparsek_stride = 0;

bool vocab_only;
bool rope_finetuned;
bool use_par_res;
src/llama-kv-cache.cpp (63 changes: 63 additions & 0 deletions)
@@ -9,6 +9,7 @@
#include <cassert>
#include <cmath>
#include <cstring>
#include <cstdlib>
#include <limits>
#include <map>
#include <stdexcept>
@@ -1300,9 +1301,71 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch) const {

data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
}

}
}
}

{
// --- SparseK env (read once per process) ---
static const bool SPARSEK_ENABLE = [](){
if (const char * s = getenv("LLAMA_SPARSEK_ENABLE")) return atoi(s) != 0;
return false;
}();
static const int SPARSEK_WIN_LOCAL = [](){
if (const char * s = getenv("LLAMA_SPARSEK_WIN")) return std::max(0, atoi(s));
return 64;
}();
static const int SPARSEK_STRIDE = [](){
if (const char * s = getenv("LLAMA_SPARSEK_STRIDE")) return std::max(0, atoi(s));
return 128;
}();
static const bool SPARSEK_EN_LOCAL = [](){
if (const char * s = getenv("LLAMA_SPARSEK_ENABLE_LOCAL")) return atoi(s) != 0;
return true;
}();
static const bool SPARSEK_EN_STRIDE = [](){
if (const char * s = getenv("LLAMA_SPARSEK_ENABLE_STRIDE")) return atoi(s) != 0;
return true;
}();

if (!SPARSEK_ENABLE || (!SPARSEK_EN_LOCAL && !SPARSEK_EN_STRIDE)) {
// do nothing – keep original KQ mask
} else {
for (uint32_t s = 0; s < n_stream; ++s) {
for (uint32_t ii = 0; ii < n_tps; ++ii) {
const uint32_t i = s*n_tps + ii;
const uint64_t idst =
n_kv*(/*h=*/0*n_stream*n_tps_pad + s*n_tps_pad + ii);
float * row = data + idst;
std::vector<uint8_t> allow(n_kv, 0);

if (SPARSEK_EN_LOCAL && SPARSEK_WIN_LOCAL > 0) {
const int j0 = std::max<int>(0, int(i) - SPARSEK_WIN_LOCAL);
const int j1 = std::min<int>(int(n_kv) - 1, int(i) + SPARSEK_WIN_LOCAL);
for (int j = j0; j <= j1; ++j) allow[j] = 1;
}

if (SPARSEK_EN_STRIDE && SPARSEK_STRIDE > 0) {
for (int j = int(i); j >= 0; j -= SPARSEK_STRIDE) allow[j] = 1;
if (!causal_attn) {
for (int j = int(i); j < int(n_kv); j += SPARSEK_STRIDE) allow[j] = 1;
}
}

for (int64_t j = 0; j < n_kv; ++j) {
if (!allow[j]) {
row[j] = -INFINITY;
} else if (std::isinf(row[j]) && row[j] < 0.0f) {
row[j] = 0.0f;
}
}
}
}
}
}
// ===== end SparseK =====

}

void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
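To see what the environment-driven path above keeps visible, here is a small Python sketch of the allow-pattern for one query row in the causal case. It mirrors the local-window and stride loops, but the window and stride values are deliberately tiny for readability (the defaults above are 64 and 128); it is an illustration only.

# Python sketch of the per-row allow pattern (causal case), mirroring the
# local-window and stride loops above; illustration only.
def sparsek_allow(i, n_kv, win_local, stride):
    allow = [False] * n_kv
    # local window: keys within +/- win_local of the query position i
    for j in range(max(0, i - win_local), min(n_kv - 1, i + win_local) + 1):
        allow[j] = True
    # global stride: every stride-th key walking back from the query
    j = i
    while j >= 0:
        allow[j] = True
        j -= stride
    return allow

# e.g. LLAMA_SPARSEK_ENABLE=1 LLAMA_SPARSEK_WIN=2 LLAMA_SPARSEK_STRIDE=4
print([int(a) for a in sparsek_allow(i=9, n_kv=12, win_local=2, stride=4)])
# -> [0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1]  (stride hits 1, 5, 9 + window 7..11)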
src/llama-model-loader.cpp (4 changes: 3 additions & 1 deletion)
@@ -403,7 +403,7 @@ namespace GGUFMeta {
template bool llama_model_loader::get_key<float> (enum llm_kv kid, float & result, bool required);
template bool llama_model_loader::get_key<uint32_t> (enum llm_kv kid, uint32_t & result, bool required);
template bool llama_model_loader::get_key<std::string>(enum llm_kv kid, std::string & result, bool required);

template bool llama_model_loader::get_key<bool> (const std::string & key, bool & result, bool required);
template<>
bool llama_model_loader::get_key(enum llm_kv kid, enum llama_pooling_type & result, bool required) {
uint32_t tmp;
@@ -1165,3 +1165,5 @@ void llama_model_loader::print_info() const {
LLAMA_LOG_INFO("%s: file size = %.2f GiB (%.2f BPW) \n", __func__, n_bytes/1024.0/1024.0/1024.0, n_bytes*8.0/n_elements);
}
}


src/llama-model.cpp (6 changes: 6 additions & 0 deletions)
@@ -548,6 +548,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f);

ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false);
// === SparseK metadata (optional) ===
ml.get_key("llama.sparsek.enable", hparams.sparsek_enable, false);
ml.get_key("llama.sparsek.top_k", hparams.sparsek_topk, false);
ml.get_key("llama.sparsek.window", hparams.sparsek_window, false);
ml.get_key("llama.sparsek.stride", hparams.sparsek_stride, false);

ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false);

// n_head_kv is optional, default to n_head
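For the metadata path above to trigger, these optional keys must be present in the GGUF file. A hypothetical sketch with the gguf-py writer is shown below; the key names match load_hparams above, but the values and where this would be wired into convert_hf_to_gguf.py are assumptions.

# Hypothetical conversion-side sketch using the gguf-py writer; the key names
# match load_hparams above, the values and placement are assumptions.
import gguf

writer = gguf.GGUFWriter("model-sparsek.gguf", arch="llama")
writer.add_bool("llama.sparsek.enable", True)
writer.add_int32("llama.sparsek.top_k", 64)
writer.add_int32("llama.sparsek.window", 64)
writer.add_int32("llama.sparsek.stride", 128)
# ... add the usual hparams and tensors here, then:
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()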