Skip to content

Commit 7c5f85a

Browse files
author
Gitty Burstein
committed
Integrate SparseK attention mask support and cleanup Co-authored-by: Yael <[email protected]>
Co-authored-by: Gitty <[email protected]>
1 parent 3d252a1 commit 7c5f85a

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed

src/llama-kv-cache.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,6 +1300,75 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
13001300

13011301
data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
13021302
}
1303+
// ===== SparseK (minimal diff): local window + stride =====
1304+
// Control parameters via environment variables (no CLI / API changes):
1305+
auto env_i = [](const char *name, int def)->int {
1306+
if (const char *s = std::getenv(name)) { return std::max(0, std::atoi(s)); }
1307+
return def;
1308+
};
1309+
auto env_b = [](const char *name, bool def)->bool {
1310+
if (const char *s = std::getenv(name)) { return std::atoi(s) != 0; }
1311+
return def;
1312+
};
1313+
1314+
// Enable SparseK (0=off, 1=on) + parameters
1315+
const bool enable_sparsek = env_b("LLAMA_SPARSEK_ENABLE", false);
1316+
const int win_local = env_i("LLAMA_SPARSEK_WIN", 64); // half-window around i
1317+
const int stride_g = env_i("LLAMA_SPARSEK_STRIDE", 128); // global stride step
1318+
const bool en_local = env_b("LLAMA_SPARSEK_ENABLE_LOCAL", true);
1319+
const bool en_stride = env_b("LLAMA_SPARSEK_ENABLE_STRIDE", true);
1320+
1321+
// Apply SparseK sparsity to the already-built mask.
1322+
// Everything outside the SparseK policy will be forced to -INF.
1323+
if (enable_sparsek && (en_local || en_stride)) {
1324+
for (uint32_t s = 0; s < n_stream; ++s) {
1325+
for (uint32_t ii = 0; ii < n_tps; ++ii) {
1326+
const uint32_t i = s*n_tps + ii;
1327+
1328+
// Row base index in the flat mask tensor
1329+
const uint64_t idst = n_kv*(/*h=*/0*n_stream*n_tps_pad + s*n_tps_pad + ii);
1330+
float * row = data + idst;
1331+
1332+
// Build "allow" mask: 1 = allowed, 0 = pruned
1333+
std::vector<uint8_t> allow(n_kv, 0);
1334+
1335+
// 1) Local window
1336+
if (en_local && win_local > 0) {
1337+
const int j0 = std::max<int>(0, int(i) - win_local);
1338+
const int j1 = std::min<int>(int(n_kv)-1,int(i) + win_local);
1339+
for (int j = j0; j <= j1; ++j) allow[j] = 1;
1340+
}
1341+
1342+
// 2) Global stride: backward only for causal; both directions if non-causal
1343+
if (en_stride && stride_g > 0) {
1344+
for (int j = int(i); j >= 0; j -= stride_g) allow[j] = 1;
1345+
if (!causal_attn) {
1346+
for (int j = int(i); j < int(n_kv); j += stride_g) allow[j] = 1;
1347+
}
1348+
}
1349+
1350+
// 3) Apply pruning: outside "allow" → -INF; inside → keep existing or set 0.0f
1351+
bool any_allowed = false;
1352+
for (int64_t j = 0; j < n_kv; ++j) {
1353+
if (allow[j]) {
1354+
if (std::isinf(row[j]) && row[j] < 0.0f) {
1355+
row[j] = 0.0f; // release from -INF if previously forbidden
1356+
}
1357+
any_allowed = true;
1358+
} else {
1359+
row[j] = -INFINITY; // enforce sparsity
1360+
}
1361+
}
1362+
1363+
// Safety: make sure the row is not completely empty (avoid NaN in Softmax)
1364+
if (!any_allowed) {
1365+
const int64_t jj = std::min<int64_t>(i, n_kv - 1);
1366+
row[jj] = 0.0f;
1367+
}
1368+
}
1369+
}
1370+
}
1371+
// ===== end SparseK minimal =====
13031372
}
13041373
}
13051374
}

0 commit comments

Comments
 (0)