Commit d7eca25

context shift fixed
1 parent 2d9f11d commit d7eca25

File tree

1 file changed (+39, -20 lines)

examples/server/server.cpp

Lines changed: 39 additions & 20 deletions
@@ -757,8 +757,8 @@ struct llama_server_context
             result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
             slot.sent_count += result.text_to_send.size();
             // add the token to slot queue and cache
-            slot.addTokenString(result);
         }
+        slot.addTokenString(result);
         if (slot.multibyte_pending > 0)
         {
             slot.multibyte_pending -= token_str.size();
@@ -925,8 +925,8 @@ struct llama_server_context
         }

         // context shift takes effect only when there is a single slot
-        if(slots.size() == 1) {
-            llama_client_slot slot = slots[0];
+        if(params.n_parallel == 1) {
+            llama_client_slot &slot = slots[0];
             if (slot.isProcessing() && slot.cache_tokens.size() >= (size_t)n_ctx)
             {
                 // Shift context
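
Note on the hunk above: besides gating the context shift on params.n_parallel == 1, the slot is now taken by reference instead of by value, so updates made during the shift actually land in slots[0]. A minimal sketch of the copy-vs-reference difference, using a hypothetical stripped-down Slot struct rather than the real llama_client_slot:

#include <cstdio>
#include <vector>

// Hypothetical stand-in for llama_client_slot, only to illustrate copy vs reference.
struct Slot { int n_past = 0; };

int main() {
    std::vector<Slot> slots(1);

    Slot copy = slots[0];      // by value: a separate object
    copy.n_past = 42;
    printf("after writing through the copy:      slots[0].n_past = %d\n", slots[0].n_past); // still 0

    Slot &ref = slots[0];      // by reference: aliases slots[0]
    ref.n_past = 42;
    printf("after writing through the reference: slots[0].n_past = %d\n", slots[0].n_past); // now 42
    return 0;
}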
@@ -1028,22 +1028,16 @@ struct llama_server_context
 
             slot.num_prompt_tokens = prompt_tokens.size();
 
-            slot.n_past = slot.params.cache_prompt ? common_part(slot.cache_tokens, prompt_tokens) : 0;
-
-            slot.cache_tokens = prompt_tokens;
-
-            if (slot.n_past == slot.num_prompt_tokens) {
-                // we have to evaluate at least 1 token to generate logits.
-                printf("we have to evaluate at least 1 token to generate logits\n");
-                slot.n_past--;
-            }
-
-            slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
-
-            if(!slot.params.cache_prompt) {
+            if(!slot.params.cache_prompt) {
                 std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
+                slot.n_past = 0;
+                slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
             } else {
-                LOG_TEE("slot %i - in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
+                if (params.n_keep < 0 && params.n_parallel == 1)
+                {
+                    params.n_keep = (int)slot.num_prompt_tokens;
+                }
+                params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
                 //if input prompt is too big, truncate like normal
                 if (slot.num_prompt_tokens >= (size_t)n_ctx)
                 {
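
Note on the hunk above: with prompt caching enabled, n_keep is now resolved before truncation; a negative n_keep is treated as "keep the whole prompt" when a single slot is running, and the result is clamped to n_ctx - 4. A rough worked sketch with assumed numbers (not taken from the commit):

#include <algorithm>
#include <cstdio>

int main() {
    // assumed example values, not from the commit
    int n_ctx  = 2048;
    int n_keep = -1;                          // caller asked to keep "everything"
    int num_prompt_tokens = 300;

    if (n_keep < 0 /* && n_parallel == 1 */) {
        n_keep = num_prompt_tokens;           // keep the whole prompt across context shifts
    }
    n_keep = std::min(n_ctx - 4, n_keep);     // never keep more than the context allows

    printf("n_keep = %d\n", n_keep);          // 300 here; a 5000-token prompt would clamp to 2044
    return 0;
}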
@@ -1059,14 +1053,26 @@ struct llama_server_context
                     });
                     slot.truncated = true;
                     prompt_tokens = new_tokens;
+                    slot.num_prompt_tokens = prompt_tokens.size();
                 }
                 const size_t ps = slot.num_prompt_tokens;
                 std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end() - ps, 0);
                 std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.last_n_tokens.end() - ps);
+                slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+                slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
+                LOG_TEE("slot %i - in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
             }
 
             llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system + slot.n_past, -1);
 
+            slot.cache_tokens = prompt_tokens;
+
+            if (slot.n_past == slot.num_prompt_tokens) {
+                // we have to evaluate at least 1 token to generate logits.
+                printf("we have to evaluate at least 1 token to generate logits\n");
+                slot.n_past--;
+            }
+
             LOG_VERBOSE("prompt ingested", {
                                                {"n_past", slot.n_past},
                                                {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
@@ -1185,7 +1191,7 @@ struct llama_server_context
             }
         }
 
-        if(kv_cache_free < 0) {
+        if(kv_cache_free < 0 && params.n_parallel > 1) {
             LOG_TEE("\nError: kv cache is full, increase context size.");
             return false;
         }
@@ -1581,6 +1587,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
     }
 }
 
+static void slot_print_timings(struct llama_client_slot * slot) {
+    LOG_TEE("\n");
+    LOG_TEE("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
+            __func__, slot->t_prompt_processing, slot->num_prompt_tokens_processed, slot->t_prompt_processing / slot->num_prompt_tokens_processed, 1e3 / slot->t_prompt_processing * slot->num_prompt_tokens_processed);
+    LOG_TEE("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
+            __func__, slot->t_token_generation, slot->n_decoded, slot->t_token_generation / slot->n_decoded, 1e3 / slot->t_token_generation * slot->n_decoded);
+    LOG_TEE("%s: total time = %10.2f ms\n", __func__, slot->t_prompt_processing + slot->t_token_generation);
+}
+
 static json format_generation_settings(llama_server_context &llama, llama_client_slot* slot)
 {
     const auto eos_bias = slot->sparams.logit_bias.find(llama_token_eos(llama.ctx));
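
Note on the new slot_print_timings helper above: it mirrors the usual llama.cpp timing printout, where milliseconds divided by count gives ms per token and 1e3 / ms * count gives tokens per second; the later hunks call it once per request, right before the final response is sent. A quick sanity check of the arithmetic with made-up numbers:

#include <cstdio>

int main() {
    // made-up timings, only to illustrate the formulas used in slot_print_timings
    double t_prompt_ms = 500.0;  int n_prompt = 100;
    double t_gen_ms    = 2000.0; int n_gen    = 40;

    printf("prompt: %.2f ms per token, %.2f tokens per second\n",
           t_prompt_ms / n_prompt, 1e3 / t_prompt_ms * n_prompt);   //  5.00 ms/token, 200 t/s
    printf("gen:    %.2f ms per token, %.2f tokens per second\n",
           t_gen_ms / n_gen, 1e3 / t_gen_ms * n_gen);               // 50.00 ms/token,  20 t/s
    return 0;
}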
@@ -1606,7 +1621,7 @@ static json format_generation_settings(llama_server_context &llama, llama_client
         {"penalize_nl", slot->sparams.penalize_nl},
         {"stop", slot->params.antiprompt},
         {"n_predict", slot->params.n_predict},
-        // {"n_keep", slot.params.n_keep},
+        {"n_keep", llama.params.n_keep},
         {"ignore_eos", ignore_eos},
         {"stream", slot->params.stream},
         {"logit_bias", slot->sparams.logit_bias},
@@ -1730,7 +1745,7 @@ static void parse_options_completion(const json &body, llama_client_slot* slot,
     slot->sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
     slot->sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
     slot->sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
-    llama.params.n_keep = json_value(body, "n_keep", -1);
+    llama.params.n_keep = json_value(body, "n_keep", 0);
     slot->params.seed = json_value(body, "seed", default_params.seed);
     slot->params.grammar = json_value(body, "grammar", default_params.grammar);
     slot->sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);
@@ -2089,6 +2104,7 @@ int main(int argc, char **argv)
                 }
 
                 const json data = format_final_response(llama, slot, completion_text, probs);
+                slot_print_timings(slot);
                 slot->release();
                 res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
                                 "application/json");
@@ -2131,6 +2147,7 @@ int main(int argc, char **argv)
                         slot->generated_token_probs.begin(),
                         slot->generated_token_probs.begin() + sent_token_probs_index)
                     );
+                    slot_print_timings(slot);
                     const std::string str =
                         "data: " +
                         data.dump(-1, ' ', false, json::error_handler_t::replace) +
@@ -2197,6 +2214,7 @@ int main(int argc, char **argv)
                 }
 
                 const json data = format_final_response(llama, slot, completion_text, probs);
+                slot_print_timings(slot);
                 res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
                                 "application/json");
             } else {
@@ -2238,6 +2256,7 @@ int main(int argc, char **argv)
                         slot->generated_token_probs.begin(),
                         slot->generated_token_probs.begin() + sent_token_probs_index)
                     );
+                    slot_print_timings(slot);
                     const std::string str =
                         "data: " +
                         data.dump(-1, ' ', false, json::error_handler_t::replace) +
