@@ -184,6 +184,12 @@ struct llama_client_slot
     struct llama_sampling_params sparams;
     llama_sampling_context *ctx_sampling = nullptr;

+    int32_t ga_i = 0;   // group-attention state
+    int32_t ga_n = 1;   // group-attention factor
+    int32_t ga_w = 512; // group-attention width
+
+    int32_t n_past_se = 0; // self-extend
+
     // multimodal
     std::vector<slot_image> images;

@@ -212,7 +218,8 @@ struct llama_client_slot
         sent_count = 0;
         sent_token_probs_index = 0;
         infill = false;
-
+        ga_i = 0;
+        n_past_se = 0;
         generated_token_probs.clear();

         for (slot_image & img : images)
@@ -399,9 +406,26 @@ struct llama_server_context

             slot.id = i;
             slot.n_ctx = n_ctx_slot;
-            slot.reset();

             LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot);
+
+            const int ga_n = params.grp_attn_n;
+            const int ga_w = params.grp_attn_w;
+
+            if (ga_n != 1) {
+                GGML_ASSERT(ga_n > 0         && "ga_n must be positive");                                  // NOLINT
+                GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n");                        // NOLINT
+                //GGML_ASSERT(n_ctx_train % ga_w == 0     && "n_ctx_train must be a multiple of ga_w");    // NOLINT
+                //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
+                LOG_TEE(" -> Slot %i - self-extend: ga_n = %d, ga_w = %d\n", slot.id, ga_n, ga_w);
+            }
+
+            slot.ga_i = 0;
+            slot.ga_n = ga_n;
+            slot.ga_w = ga_w;
+
+            slot.reset();
+
             slots.push_back(slot);
         }

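
For orientation: with self-extend enabled (ga_n != 1), the patch only hard-requires a positive factor and a width that splits into ga_n equal groups; the commented-out asserts hint at a stricter intent (n_ctx_train a multiple of ga_w, n_ctx at least n_ctx_train * ga_n) but are left disabled. A minimal standalone restatement of the active checks (sketch only, grp_attn_config_ok is a made-up name, not part of the patch):

    // Sketch, not part of the patch: the two checks the GGML_ASSERTs above enforce.
    #include <cstdint>

    static bool grp_attn_config_ok(int32_t ga_n, int32_t ga_w) {
        if (ga_n == 1) {
            return true;                         // self-extend disabled; ga_w is ignored
        }
        return ga_n > 0 && ga_w % ga_n == 0;     // positive factor, width splits into ga_n equal groups
    }
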
@@ -1349,32 +1373,35 @@ struct llama_server_context

         for (llama_client_slot &slot : slots)
         {
-            if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx)
+            if (slot.ga_n == 1)
             {
-                // Shift context
-                const int n_left = slot.n_past - slot.params.n_keep - 1;
-                const int n_discard = n_left / 2;
+                if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx)
+                {
+                    // Shift context
+                    const int n_left = slot.n_past - slot.params.n_keep - 1;
+                    const int n_discard = n_left / 2;

-                LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard);
-                llama_kv_cache_seq_rm(ctx, slot.id, slot.params.n_keep + 1, slot.params.n_keep + n_discard + 1);
-                llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
+                    LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard);
+                    llama_kv_cache_seq_rm(ctx, slot.id, slot.params.n_keep + 1, slot.params.n_keep + n_discard + 1);
+                    llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard);

-                for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
-                {
-                    slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
-                }
+                    for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
+                    {
+                        slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
+                    }

-                slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
+                    slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);

-                slot.n_past -= n_discard;
+                    slot.n_past -= n_discard;

-                slot.truncated = true;
+                    slot.truncated = true;

-                LOG_VERBOSE("context shift", {
-                    {"n_ctx", n_ctx},
-                    {"n_keep", params.n_keep},
-                    {"n_left", n_left},
-                });
+                    LOG_VERBOSE("context shift", {
+                        { "n_ctx", n_ctx },
+                        { "n_keep", params.n_keep },
+                        { "n_left", n_left },
+                    });
+                }
             }
         }

@@ -1401,7 +1428,8 @@ struct llama_server_context

                 slot.i_batch = batch.n_tokens;

-                llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, { slot.id }, true);
+                const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
+                llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);

                 slot.n_past += 1;
             }
@@ -1499,6 +1527,8 @@ struct llama_server_context
                     llama_sampling_reset(slot.ctx_sampling);

                     slot.n_past = 0;
+                    slot.n_past_se = 0;
+                    slot.ga_i = 0;
                     slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
                 }
                 else
@@ -1512,6 +1542,25 @@ struct llama_server_context
                     slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
                     slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;

+                    if (slot.ga_n != 1)
+                    {
+                        int ga_i = 0;
+                        int32_t ga_n = slot.ga_n;
+                        int32_t ga_w = slot.ga_w;
+                        int32_t slot_npast = 0;
+                        for (int k = 0; k < slot.n_past; ++k)
+                        {
+                            while (slot_npast >= ga_i + ga_w) {
+                                const int bd = (ga_w/ga_n)*(ga_n - 1);
+                                slot_npast -= bd;
+                                ga_i += ga_w/ga_n;
+                            }
+                            slot_npast++;
+                        }
+                        slot.n_past_se = slot_npast;
+                        slot.ga_i = ga_i;
+                    }
+
                     LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
                 }
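
For reference, a minimal standalone sketch of the remapping this loop performs (sketch only, not part of the patch; selfext_state is a made-up name): it replays n_past absolute positions, and whenever the grouped position reaches the end of the current window it folds the window, reclaiming bd = (ga_w/ga_n)*(ga_n - 1) positions and advancing the window start ga_i by ga_w/ga_n. The result is the pair (n_past_se, ga_i) that the slot resumes from.

    // Sketch, not part of the patch: recompute the Self-Extend state for a
    // prompt of n_past cached tokens, mirroring the loop above.
    #include <cstdint>
    #include <utility>

    static std::pair<int32_t, int32_t> selfext_state(int32_t n_past, int32_t ga_n, int32_t ga_w) {
        int32_t ga_i       = 0; // start of the current ungrouped window
        int32_t slot_npast = 0; // position in the grouped (compressed) coordinate system
        for (int32_t k = 0; k < n_past; ++k) {
            // once the grouped position runs past the window, fold it down by bd
            // and advance the window start by ga_w / ga_n
            while (slot_npast >= ga_i + ga_w) {
                const int32_t bd = (ga_w / ga_n) * (ga_n - 1);
                slot_npast -= bd;
                ga_i       += ga_w / ga_n;
            }
            slot_npast++;
        }
        return { slot_npast, ga_i }; // e.g. ga_n = 2, ga_w = 4, n_past = 16 -> { 10, 6 }
    }
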
@@ -1526,6 +1575,10 @@ struct llama_server_context
                     // we have to evaluate at least 1 token to generate logits.
                     LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id);
                     slot.n_past--;
+                    if (slot.ga_i > 0)
+                    {
+                        slot.n_past_se--;
+                    }
                 }

                 LOG_VERBOSE("prompt ingested", {
@@ -1538,9 +1591,22 @@ struct llama_server_context

                     // process the prefix of first image
                     std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
+                    int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
+                    int ga_i = slot.ga_i;
+                    int32_t ga_n = slot.ga_n;
+                    int32_t ga_w = slot.ga_w;
                     for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past)
                     {
-                        llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, { slot.id }, false);
+                        if (slot.ga_n != 1)
+                        {
+                            while (slot_npast >= ga_i + ga_w) {
+                                const int bd = (ga_w/ga_n)*(ga_n - 1);
+                                slot_npast -= bd;
+                                ga_i += ga_w/ga_n;
+                            }
+                        }
+                        llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id}, false);
+                        slot_npast += 1;
                     }

                     if (has_images && !ingest_images(slot, n_batch))
@@ -1570,6 +1636,36 @@ struct llama_server_context
         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
         {
             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+
+            for (auto & slot : slots)
+            {
+                if (slot.ga_n != 1)
+                {
+                    // context extension via Self-Extend
+                    while (slot.n_past_se >= slot.ga_i + slot.ga_w)
+                    {
+                        const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
+                        const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
+                        const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
+
+                        LOG_TEE("\n");
+                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
+                        LOG_TEE("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
+                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
+
+                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
+                        llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
+                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
+
+                        slot.n_past_se -= bd;
+
+                        slot.ga_i += slot.ga_w / slot.ga_n;
+
+                        LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
+                    }
+                    slot.n_past_se += n_tokens;
+                }
+            }
             llama_batch batch_view =
             {
                 n_tokens,
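
To make the arithmetic above easier to follow, here is a sketch (not part of the patch; selfext_step and selfext_constants are made-up names) that only computes the three per-iteration constants. With the toy values ga_n = 2, ga_w = 4, ga_i = 0 and n_past_se = 6 it yields ib = 0, bd = 2, dd = -2, so the three cache calls amount to: shift [0, 6) by 0 (a no-op), divide positions [0, 4) by 2 (0,1,2,3 -> 0,0,1,1), and shift the tail [4, 6) by -2 (4,5 -> 2,3); afterwards n_past_se drops to 4 and ga_i advances to 2.

    // Sketch, not part of the patch: the per-iteration constants of the
    // Self-Extend cache transformation above.
    #include <cstdint>

    struct selfext_step {
        int32_t ib; // number of folds already applied (ga_i / (ga_w / ga_n))
        int32_t bd; // cache positions reclaimed by folding one window
        int32_t dd; // shift applied to the tokens that follow the divided window
    };

    static selfext_step selfext_constants(int32_t ga_i, int32_t ga_n, int32_t ga_w) {
        selfext_step s;
        s.ib = (ga_n * ga_i) / ga_w;
        s.bd = (ga_w / ga_n) * (ga_n - 1);
        s.dd = (ga_w / ga_n) - s.ib * s.bd - ga_w;
        return s; // e.g. ga_i = 0, ga_n = 2, ga_w = 4 -> { ib = 0, bd = 2, dd = -2 }
    }
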
@@ -1583,6 +1679,7 @@ struct llama_server_context
             };

             const int ret = llama_decode(ctx, batch_view);
+
             if (ret != 0)
             {
                 if (n_batch == 1 || ret < 0)
@@ -1728,6 +1825,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("                        advanced option to override model metadata by key. may be specified multiple times.\n");
     printf("                        types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
+    printf("  -gan N, --grp-attn-n N  Set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`\n");
+    printf("  -gaw N, --grp-attn-w N  Set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`\n");
     printf("\n");
 }

@@ -1913,6 +2012,25 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
             params.n_threads = std::stoi(argv[i]);
         }
+        else if (arg == "--grp-attn-n" || arg == "-gan")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+
+            params.grp_attn_n = std::stoi(argv[i]);
+        }
+        else if (arg == "--grp-attn-w" || arg == "-gaw")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+
+            params.grp_attn_w = std::stoi(argv[i]);
+        }
         else if (arg == "--threads-batch" || arg == "-tb")
         {
             if (++i >= argc)
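
Taken together, the two new flags flow from the command line into params.grp_attn_n / params.grp_attn_w and from there into every slot at startup. As an illustrative invocation (model path and context size are placeholders), something like `./server -m model.gguf -c 8192 -gan 4 -gaw 512` would enable self-extend with a group factor of 4 and a window of 512 tokens for each slot.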