Skip to content

Commit 8772d3e

Browse files
committed
server : take system_tokens into account
1 parent 51bb7f0 commit 8772d3e

File tree

1 file changed: 7 additions (+7), 5 deletions (-5)

examples/server/server.cpp

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1225,7 +1225,7 @@ struct llama_server_context
                 std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
                 for (int i = 0; i < (int) append_tokens.size(); ++i)
                 {
-                    llama_batch_add(batch, append_tokens[i], slot.n_past, { slot.id }, true);
+                    llama_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id }, true);
                     slot.n_past += 1;
                 }
             }
@@ -1376,12 +1376,12 @@ struct llama_server_context
             if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.n_ctx)
             {
                 // Shift context
-                const int n_left = slot.n_past - slot.params.n_keep - 1;
+                const int n_left = system_tokens.size() + slot.n_past - slot.params.n_keep - 1;
                 const int n_discard = n_left / 2;

                 LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard);
                 llama_kv_cache_seq_rm (ctx, slot.id, slot.params.n_keep + 1 , slot.params.n_keep + n_discard + 1);
-                llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
+                llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, system_tokens.size() + slot.n_past, -n_discard);

                 for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
                 {
@@ -1426,6 +1426,8 @@ struct llama_server_context

                 slot.i_batch = batch.n_tokens;

+                // TODO: we always have to take into account the "system_tokens"
+                // this is not great and needs to be improved somehow
                 llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, { slot.id }, true);

                 slot.n_past += 1;
@@ -1478,8 +1480,8 @@ struct llama_server_context

                     prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
                     prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
-                    prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
-                    prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
+                    prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
+                    prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
                     prefix_tokens.push_back(llama_token_middle(model));
                     prompt_tokens = prefix_tokens;
                 }

(NOTE: the removed and added lines above are textually identical in the scraped page — this hunk was most likely a whitespace/alignment-only change; the exact column shift could not be recovered from the extraction.)

0 commit comments

Comments (0)