@@ -1300,6 +1300,75 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
13001300
13011301 data[idst + j] = hparams.use_alibi ? -std::abs (p0 - p1) : 0 .0f ;
13021302 }
1303+ // ===== SparseK (minimal diff): local window + stride =====
1304+ // Control parameters via environment variables (no CLI / API changes):
// Reads a non-negative integer setting from the environment variable `name`.
//
// Returns `def` when the variable is not set. Otherwise the leading decimal
// number of its value is parsed and clamped to [0, 2^30]; text with no leading
// number yields 0.
auto env_i = [](const char *name, int def)->int {
    const char *s = std::getenv(name);
    if (s == nullptr) {
        return def;
    }

    // Upper bound for accepted values: far beyond any practical window/stride,
    // yet small enough that adding a row index to it cannot overflow a 32-bit int.
    constexpr long k_cap = 1L << 30;

    // std::strtol saturates on out-of-range input, whereas std::atoi has
    // undefined behaviour there, so the clamp below is always well defined.
    const long v = std::strtol(s, nullptr, 10);
    if (v < 0) {
        return 0;
    }
    if (v > k_cap) {
        return int(k_cap);
    }
    return int(v);
};
// Reads a boolean switch from the environment variable `name`.
//
// Returns `def` when the variable is not set. Otherwise the value is parsed as
// a decimal integer and the switch is on iff that integer is non-zero; text
// with no leading number (including "true"/"false") therefore reads as off.
auto env_b = [](const char *name, bool def)->bool {
    const char *s = std::getenv(name);
    if (s == nullptr) {
        return def;
    }
    // std::strtol saturates on out-of-range input instead of invoking the
    // undefined behaviour std::atoi has there.
    return std::strtol(s, nullptr, 10) != 0;
};
1313+
1314+ // Enable SparseK (0=off, 1=on) + parameters
1315+ const bool enable_sparsek = env_b (" LLAMA_SPARSEK_ENABLE" , false );
1316+ const int win_local = env_i (" LLAMA_SPARSEK_WIN" , 64 ); // half-window around i
1317+ const int stride_g = env_i (" LLAMA_SPARSEK_STRIDE" , 128 ); // global stride step
1318+ const bool en_local = env_b (" LLAMA_SPARSEK_ENABLE_LOCAL" , true );
1319+ const bool en_stride = env_b (" LLAMA_SPARSEK_ENABLE_STRIDE" , true );
1320+
1321+ // Apply SparseK sparsity to the already-built mask.
1322+ // Everything outside the SparseK policy will be forced to -INF.
1323+ if (enable_sparsek && (en_local || en_stride)) {
1324+ for (uint32_t s = 0 ; s < n_stream; ++s) {
1325+ for (uint32_t ii = 0 ; ii < n_tps; ++ii) {
1326+ const uint32_t i = s*n_tps + ii;
1327+
1328+ // Row base index in the flat mask tensor
1329+ const uint64_t idst = n_kv*(/* h=*/ 0 *n_stream*n_tps_pad + s*n_tps_pad + ii);
1330+ float * row = data + idst;
1331+
1332+ // Build "allow" mask: 1 = allowed, 0 = pruned
1333+ std::vector<uint8_t > allow (n_kv, 0 );
1334+
1335+ // 1) Local window
1336+ if (en_local && win_local > 0 ) {
1337+ const int j0 = std::max<int >(0 , int (i) - win_local);
1338+ const int j1 = std::min<int >(int (n_kv)-1 ,int (i) + win_local);
1339+ for (int j = j0; j <= j1; ++j) allow[j] = 1 ;
1340+ }
1341+
1342+ // 2) Global stride: backward only for causal; both directions if non-causal
1343+ if (en_stride && stride_g > 0 ) {
1344+ for (int j = int (i); j >= 0 ; j -= stride_g) allow[j] = 1 ;
1345+ if (!causal_attn) {
1346+ for (int j = int (i); j < int (n_kv); j += stride_g) allow[j] = 1 ;
1347+ }
1348+ }
1349+
1350+ // 3) Apply pruning: outside "allow" → -INF; inside → keep existing or set 0.0f
1351+ bool any_allowed = false ;
1352+ for (int64_t j = 0 ; j < n_kv; ++j) {
1353+ if (allow[j]) {
1354+ if (std::isinf (row[j]) && row[j] < 0 .0f ) {
1355+ row[j] = 0 .0f ; // release from -INF if previously forbidden
1356+ }
1357+ any_allowed = true ;
1358+ } else {
1359+ row[j] = -INFINITY; // enforce sparsity
1360+ }
1361+ }
1362+
1363+ // Safety: make sure the row is not completely empty (avoid NaN in Softmax)
1364+ if (!any_allowed) {
1365+ const int64_t jj = std::min<int64_t >(i, n_kv - 1 );
1366+ row[jj] = 0 .0f ;
1367+ }
1368+ }
1369+ }
1370+ }
1371+ // ===== end SparseK minimal =====
13031372 }
13041373 }
13051374 }
0 commit comments