
Commit af728ca

Maximilian-Winter authored and ggerganov committed
server : add self-extend support (ggml-org#5104)
* Ported self extension to server example
* Update server.cpp
* Fixed prompt caching without self extend
* Update server.cpp
* Added description to server readme.
* Update server.cpp
* Update server.cpp
* Update server.cpp
* Update server.cpp
* Update README.md
* Changed descriptions
* server : formatting
* Update examples/server/server.cpp
  Co-authored-by: Georgi Gerganov <[email protected]>
* Update examples/server/server.cpp
  Co-authored-by: Georgi Gerganov <[email protected]>
* Update server.cpp
* Update server.cpp

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 0888dc9 · commit af728ca

File tree (2 files changed: +143, -24 lines changed)

- examples/server/README.md
- examples/server/server.cpp

examples/server/README.md

Lines changed: 2 additions & 1 deletion
@@ -30,7 +30,8 @@ Command line options:
 - `-cb`, `--cont-batching`: enable continuous batching (a.k.a dynamic batching) (default: disabled)
 - `-spf FNAME`, `--system-prompt-file FNAME` Set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications. [See more](#change-system-prompt-on-runtime)
 - `--mmproj MMPROJ_FILE`: Path to a multimodal projector file for LLaVA.
-
+- `--grp-attn-n`: Set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`
+- `--grp-attn-w`: Set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`
 
 ## Build
 
 server is build alongside everything else from the root of the project
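For reference, a hypothetical launch command combining the two new options with existing flags could look like `./server -m models/your-model.gguf -c 8192 --grp-attn-n 4 --grp-attn-w 512`; the model path and sizes here are illustrative only, the idea being that a model trained on a shorter context is served with a context window roughly `--grp-attn-n` times larger via self-extend.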

examples/server/server.cpp

Lines changed: 141 additions & 23 deletions
@@ -184,6 +184,12 @@ struct llama_client_slot
     struct llama_sampling_params sparams;
     llama_sampling_context *ctx_sampling = nullptr;
 
+    int32_t ga_i = 0;   // group-attention state
+    int32_t ga_n = 1;   // group-attention factor
+    int32_t ga_w = 512; // group-attention width
+
+    int32_t n_past_se = 0; // self-extend
+
     // multimodal
     std::vector<slot_image> images;
 
@@ -212,7 +218,8 @@ struct llama_client_slot
         sent_count = 0;
         sent_token_probs_index = 0;
         infill = false;
-
+        ga_i = 0;
+        n_past_se = 0;
         generated_token_probs.clear();
 
         for (slot_image & img : images)
@@ -399,9 +406,26 @@ struct llama_server_context
 
             slot.id = i;
             slot.n_ctx = n_ctx_slot;
-            slot.reset();
 
             LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot);
+
+            const int ga_n = params.grp_attn_n;
+            const int ga_w = params.grp_attn_w;
+
+            if (ga_n != 1) {
+                GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
+                GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
+                //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
+                //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
+                LOG_TEE(" -> Slot %i - self-extend: ga_n = %d, ga_w = %d\n", slot.id, ga_n, ga_w);
+            }
+
+            slot.ga_i = 0;
+            slot.ga_n = ga_n;
+            slot.ga_w = ga_w;
+
+            slot.reset();
+
             slots.push_back(slot);
         }
 
@@ -1349,32 +1373,35 @@ struct llama_server_context
 
         for (llama_client_slot &slot : slots)
         {
-            if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx)
+            if (slot.ga_n == 1)
             {
-                // Shift context
-                const int n_left = slot.n_past - slot.params.n_keep - 1;
-                const int n_discard = n_left / 2;
+                if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx)
+                {
+                    // Shift context
+                    const int n_left = slot.n_past - slot.params.n_keep - 1;
+                    const int n_discard = n_left / 2;
 
-                LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard);
-                llama_kv_cache_seq_rm   (ctx, slot.id, slot.params.n_keep + 1            , slot.params.n_keep + n_discard + 1);
-                llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
+                    LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard);
+                    llama_kv_cache_seq_rm   (ctx, slot.id, slot.params.n_keep + 1            , slot.params.n_keep + n_discard + 1);
+                    llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
 
-                for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
-                {
-                    slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
-                }
+                    for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
+                    {
+                        slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
+                    }
 
-                slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
+                    slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
 
-                slot.n_past -= n_discard;
+                    slot.n_past -= n_discard;
 
-                slot.truncated = true;
+                    slot.truncated = true;
 
-                LOG_VERBOSE("context shift", {
-                    {"n_ctx", n_ctx},
-                    {"n_keep", params.n_keep},
-                    {"n_left", n_left},
-                });
+                    LOG_VERBOSE("context shift", {
+                        { "n_ctx", n_ctx },
+                        { "n_keep", params.n_keep },
+                        { "n_left", n_left },
+                    });
+                }
             }
         }
 
@@ -1401,7 +1428,8 @@ struct llama_server_context
 
             slot.i_batch = batch.n_tokens;
 
-            llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, { slot.id }, true);
+            const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
+            llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
 
             slot.n_past += 1;
         }
@@ -1499,6 +1527,8 @@ struct llama_server_context
                     llama_sampling_reset(slot.ctx_sampling);
 
                     slot.n_past = 0;
+                    slot.n_past_se = 0;
+                    slot.ga_i = 0;
                     slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
                 }
                 else
@@ -1512,6 +1542,25 @@ struct llama_server_context
                     slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
                     slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
 
+                    if (slot.ga_n != 1)
+                    {
+                        int ga_i = 0;
+                        int32_t ga_n = slot.ga_n;
+                        int32_t ga_w = slot.ga_w;
+                        int32_t slot_npast = 0;
+                        for (int k = 0; k < slot.n_past; ++k)
+                        {
+                            while (slot_npast >= ga_i + ga_w) {
+                                const int bd = (ga_w/ga_n)*(ga_n - 1);
+                                slot_npast -= bd;
+                                ga_i += ga_w/ga_n;
+                            }
+                            slot_npast++;
+                        }
+                        slot.n_past_se = slot_npast;
+                        slot.ga_i = ga_i;
+                    }
+
                     LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
                 }
 
@@ -1526,6 +1575,10 @@ struct llama_server_context
                     // we have to evaluate at least 1 token to generate logits.
                     LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id);
                     slot.n_past--;
+                    if (slot.ga_i > 0)
+                    {
+                        slot.n_past_se--;
+                    }
                 }
 
                 LOG_VERBOSE("prompt ingested", {
@@ -1538,9 +1591,22 @@ struct llama_server_context
 
                 // process the prefix of first image
                 std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
+                int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
+                int ga_i = slot.ga_i;
+                int32_t ga_n = slot.ga_n;
+                int32_t ga_w = slot.ga_w;
                 for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past)
                 {
-                    llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, { slot.id }, false);
+                    if (slot.ga_n != 1)
+                    {
+                        while (slot_npast >= ga_i + ga_w) {
+                            const int bd = (ga_w/ga_n)*(ga_n - 1);
+                            slot_npast -= bd;
+                            ga_i += ga_w/ga_n;
+                        }
+                    }
+                    llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
+                    slot_npast += 1;
                 }
 
                 if (has_images && !ingest_images(slot, n_batch))
@@ -1570,6 +1636,36 @@ struct llama_server_context
         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
         {
             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+
+            for (auto & slot : slots)
+            {
+                if (slot.ga_n != 1)
+                {
+                    // context extension via Self-Extend
+                    while (slot.n_past_se >= slot.ga_i + slot.ga_w)
+                    {
+                        const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
+                        const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
+                        const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
+
+                        LOG_TEE("\n");
+                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
+                        LOG_TEE("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
+                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
+
+                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
+                        llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
+                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
+
+                        slot.n_past_se -= bd;
+
+                        slot.ga_i += slot.ga_w / slot.ga_n;
+
+                        LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
+                    }
+                    slot.n_past_se += n_tokens;
+                }
+            }
+
             llama_batch batch_view =
             {
                 n_tokens,
@@ -1583,6 +1679,7 @@ struct llama_server_context
             };
 
             const int ret = llama_decode(ctx, batch_view);
+
            if (ret != 0)
            {
                if (n_batch == 1 || ret < 0)
@@ -1728,6 +1825,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("                            advanced option to override model metadata by key. may be specified multiple times.\n");
     printf("                            types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
+    printf("  -gan N, --grp-attn-n N    Set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
+    printf("  -gaw N, --grp-attn-w N    Set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
     printf("\n");
 }
 
@@ -1913,6 +2012,25 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
             params.n_threads = std::stoi(argv[i]);
         }
+        else if (arg == "--grp-attn-n" || arg == "-gan")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+
+            params.grp_attn_n = std::stoi(argv[i]);
+        }
+        else if (arg == "--grp-attn-w" || arg == "-gaw")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+
+            params.grp_attn_w = std::stoi(argv[i]);
+        }
         else if (arg == "--threads-batch" || arg == "-tb")
         {
             if (++i >= argc)
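To make the position bookkeeping in the diff easier to follow, below is a minimal standalone sketch (not part of the commit) of the remapping loop used by both the prompt-caching path and the per-batch self-extend block: it derives the self-extend position `n_past_se` and the group-attention state `ga_i` from a raw token count, using illustrative values for `ga_n`, `ga_w`, and `n_past`. The corresponding `llama_kv_cache_seq_shift` / `llama_kv_cache_seq_div` calls shown in the diff are omitted here since they need a live llama context.

```cpp
// Standalone illustration (not from the commit): recompute the self-extend
// position state for a slot that already has n_past tokens evaluated.
#include <cstdint>
#include <cstdio>

int main() {
    const int32_t ga_n   = 4;    // group-attention factor, e.g. --grp-attn-n 4 (example value)
    const int32_t ga_w   = 512;  // group-attention width,  e.g. --grp-attn-w 512 (default)
    const int32_t n_past = 2048; // tokens already evaluated for this slot (example value)

    int32_t ga_i      = 0; // group-attention state: start of the current window
    int32_t n_past_se = 0; // self-extend position actually used for the KV cache

    for (int32_t k = 0; k < n_past; ++k) {
        // once the position walks past the current window, the window is
        // compressed by ga_n: the effective position falls back by
        // (ga_w/ga_n)*(ga_n - 1) and the window start advances by ga_w/ga_n
        while (n_past_se >= ga_i + ga_w) {
            const int32_t bd = (ga_w/ga_n)*(ga_n - 1);
            n_past_se -= bd;
            ga_i      += ga_w/ga_n;
        }
        n_past_se++;
    }

    printf("n_past = %d -> n_past_se = %d, ga_i = %d\n", n_past, n_past_se, ga_i);
    return 0;
}
```

With these example values, every window of 512 raw positions collapses to 128 effective positions, which is what lets a slot hold roughly `ga_n` times more tokens than the effective position range would otherwise allow.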
