@@ -757,8 +757,8 @@ struct llama_server_context
             result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
             slot.sent_count += result.text_to_send.size();
             // add the token to slot queue and cache
-            slot.addTokenString(result);
         }
+        slot.addTokenString(result);
         if (slot.multibyte_pending > 0)
         {
             slot.multibyte_pending -= token_str.size();
@@ -925,8 +925,8 @@ struct llama_server_context
         }
 
         // context shift takes effect only when there is a single slot
-        if (slots.size() == 1) {
-            llama_client_slot slot = slots[0];
+        if (params.n_parallel == 1) {
+            llama_client_slot &slot = slots[0];
             if (slot.isProcessing() && slot.cache_tokens.size() >= (size_t)n_ctx)
             {
                 // Shift context
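Note: the body of the "Shift context" branch is outside this hunk. As a rough sketch of what a single-slot context shift typically does in this server (an assumption for illustration, not code from this diff; it ignores the num_tokens_system offset used elsewhere in the file, and it assumes the contemporaneous llama_kv_cache_seq_shift API alongside the llama_kv_cache_seq_rm call that does appear below), it keeps the first n_keep tokens, discards half of the remainder, and slides the surviving cache entries down:

    // Hedged sketch of a context shift for one slot.
    const int n_keep    = params.n_keep;
    const int n_left    = slot.n_past - n_keep - 1;
    const int n_discard = n_left / 2;

    // remove the discarded span from the KV cache, then shift the tail down
    llama_kv_cache_seq_rm   (ctx, slot.id, n_keep + 1, n_keep + 1 + n_discard);
    llama_kv_cache_seq_shift(ctx, slot.id, n_keep + 1 + n_discard, slot.n_past, -n_discard);

    // mirror the removal in the slot's local token cache
    slot.cache_tokens.erase(slot.cache_tokens.begin() + n_keep + 1,
                            slot.cache_tokens.begin() + n_keep + 1 + n_discard);
    slot.n_past -= n_discard;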
@@ -1028,22 +1028,16 @@ struct llama_server_context
 
             slot.num_prompt_tokens = prompt_tokens.size();
 
-            slot.n_past = slot.params.cache_prompt ? common_part(slot.cache_tokens, prompt_tokens) : 0;
-
-            slot.cache_tokens = prompt_tokens;
-
-            if (slot.n_past == slot.num_prompt_tokens) {
-                // we have to evaluate at least 1 token to generate logits.
-                printf("we have to evaluate at least 1 token to generate logits\n");
-                slot.n_past--;
-            }
-
-            slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
-
-            if (!slot.params.cache_prompt) {
+            if (!slot.params.cache_prompt) {
                 std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
+                slot.n_past = 0;
+                slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
             } else {
-                LOG_TEE("slot %i - in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
+                if (params.n_keep < 0 && params.n_parallel == 1)
+                {
+                    params.n_keep = (int)slot.num_prompt_tokens;
+                }
+                params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
                 // if input prompt is too big, truncate like normal
                 if (slot.num_prompt_tokens >= (size_t)n_ctx)
                 {
@@ -1059,14 +1053,26 @@ struct llama_server_context
                     });
                     slot.truncated = true;
                     prompt_tokens = new_tokens;
+                    slot.num_prompt_tokens = prompt_tokens.size();
                 }
                 const size_t ps = slot.num_prompt_tokens;
                 std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end() - ps, 0);
                 std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.last_n_tokens.end() - ps);
+                slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+                slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
+                LOG_TEE("slot %i - in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
             }
 
             llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system + slot.n_past, -1);
 
+            slot.cache_tokens = prompt_tokens;
+
+            if (slot.n_past == slot.num_prompt_tokens) {
+                // we have to evaluate at least 1 token to generate logits.
+                printf("we have to evaluate at least 1 token to generate logits\n");
+                slot.n_past--;
+            }
+
             LOG_VERBOSE("prompt ingested", {
                                                {"n_past", slot.n_past},
                                                {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
@@ -1185,7 +1191,7 @@ struct llama_server_context
             }
         }
 
-        if (kv_cache_free < 0) {
+        if (kv_cache_free < 0 && params.n_parallel > 1) {
             LOG_TEE("\nError: kv cache is full, increase context size.");
             return false;
         }
@@ -1581,6 +1587,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
     }
 }
 
+static void slot_print_timings(struct llama_client_slot * slot) {
+    LOG_TEE("\n");
+    LOG_TEE("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
+            __func__, slot->t_prompt_processing, slot->num_prompt_tokens_processed, slot->t_prompt_processing / slot->num_prompt_tokens_processed, 1e3 / slot->t_prompt_processing * slot->num_prompt_tokens_processed);
+    LOG_TEE("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
+            __func__, slot->t_token_generation, slot->n_decoded, slot->t_token_generation / slot->n_decoded, 1e3 / slot->t_token_generation * slot->n_decoded);
+    LOG_TEE("%s: total time = %10.2f ms\n", __func__, slot->t_prompt_processing + slot->t_token_generation);
+}
+
 static json format_generation_settings(llama_server_context &llama, llama_client_slot* slot)
 {
     const auto eos_bias = slot->sparams.logit_bias.find(llama_token_eos(llama.ctx));
@@ -1606,7 +1621,7 @@ static json format_generation_settings(llama_server_context &llama, llama_client
         {"penalize_nl", slot->sparams.penalize_nl},
         {"stop", slot->params.antiprompt},
         {"n_predict", slot->params.n_predict},
-        // {"n_keep", slot.params.n_keep},
+        {"n_keep", llama.params.n_keep},
         {"ignore_eos", ignore_eos},
         {"stream", slot->params.stream},
         {"logit_bias", slot->sparams.logit_bias},
@@ -1730,7 +1745,7 @@ static void parse_options_completion(const json &body, llama_client_slot* slot,
     slot->sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
     slot->sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
     slot->sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
-    llama.params.n_keep = json_value(body, "n_keep", -1);
+    llama.params.n_keep = json_value(body, "n_keep", 0);
     slot->params.seed = json_value(body, "seed", default_params.seed);
     slot->params.grammar = json_value(body, "grammar", default_params.grammar);
     slot->sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);
@@ -2089,6 +2104,7 @@ int main(int argc, char **argv)
                 }
 
                 const json data = format_final_response(llama, slot, completion_text, probs);
+                slot_print_timings(slot);
                 slot->release();
                 res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
                                 "application/json");
@@ -2131,6 +2147,7 @@ int main(int argc, char **argv)
                             slot->generated_token_probs.begin(),
                             slot->generated_token_probs.begin() + sent_token_probs_index)
                     );
+                    slot_print_timings(slot);
                     const std::string str =
                         "data: " +
                         data.dump(-1, ' ', false, json::error_handler_t::replace) +
@@ -2197,6 +2214,7 @@ int main(int argc, char **argv)
                 }
 
                 const json data = format_final_response(llama, slot, completion_text, probs);
+                slot_print_timings(slot);
                 res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
                                 "application/json");
             } else {
@@ -2238,6 +2256,7 @@ int main(int argc, char **argv)
                             slot->generated_token_probs.begin(),
                             slot->generated_token_probs.begin() + sent_token_probs_index)
                     );
+                    slot_print_timings(slot);
                     const std::string str =
                         "data: " +
                         data.dump(-1, ' ', false, json::error_handler_t::replace) +