@@ -125,7 +125,7 @@ enum slot_command {
 struct slot_params {
     bool stream = true;
     uint32_t seed = -1;     // RNG seed
-    int n_keep = 0;         // RNG seed
+    int n_keep = 0;         // number of tokens to keep from initial prompt
     int32_t n_predict = -1; // new tokens to predict
     std::string grammar = "";  // optional BNF-like grammar to constrain sampling
     bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
@@ -262,6 +262,34 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector<com
     return out;
 }
 
+struct llama_sampling_context * llama_sampling_init_srv(const struct llama_sampling_params sparams, std::string grammar, int n_ctx) {
+    struct llama_sampling_context * result = new llama_sampling_context();
+
+    result->params  = sparams;
+    result->grammar = nullptr;
+
+    // if there is a grammar, parse it
+    if (!grammar.empty()) {
+        result->parsed_grammar = grammar_parser::parse(grammar.c_str());
+
+        // will be empty (default) if there are parse errors
+        if (result->parsed_grammar.rules.empty()) {
+            fprintf(stderr, "%s: failed to parse grammar\n", __func__);
+            return nullptr;
+        }
+
+        std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
+
+        result->grammar = llama_grammar_init(
+                grammar_rules.data(),
+                grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
+    }
+
+    result->prev.resize(n_ctx);
+
+    return result;
+}
+
 struct slot_image {
     clip_image_u8 img_data;
     bool request_encode_image = false;
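
For reference, the per-slot sampling state created by llama_sampling_init_srv above is used through the rest of this diff roughly as follows. This is a minimal sketch of the assumed common/sampling.h API at this revision (grammar_str and i_batch are placeholder names), not code from the PR:

    // create the context once per slot (parses the optional grammar, sizes `prev`)
    llama_sampling_context * ctx_sampling = llama_sampling_init_srv(sparams, grammar_str, n_ctx);
    if (ctx_sampling == nullptr) {
        // grammar failed to parse
    }

    // per generated token: sample from the logits at index i_batch, then record the
    // token so repetition penalties and the grammar state see it
    const llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
    llama_sampling_accept(ctx_sampling, ctx, id);

    // on slot reset / shutdown
    llama_sampling_free(ctx_sampling);
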
@@ -287,7 +315,6 @@ struct llama_client_slot
     int num_tokens_predicted = 0;
     llama_token sampled;
     std::vector<llama_token> cache_tokens;
-    std::vector<llama_token> last_n_tokens;
     std::vector<completion_token_output> generated_token_probs;
     int sent_tokens = 0;
     slot_state state = IDLE;
@@ -307,13 +334,12 @@ struct llama_client_slot
     double t_token_generation; // ms
 
     struct slot_params params;
+
+    // sampling
     struct llama_sampling_params sparams;
-    llama_sampling_context ctx_sampling;
+    llama_sampling_context *ctx_sampling = nullptr;
     bool has_next_token = true;
-
-    // grammar props
-    grammar_parser::parse_state parsed_grammar;
-    llama_grammar *grammar = nullptr;
+    int max_context_size = 0;
 
     // multimodal
     std::vector<slot_image> images;
@@ -332,47 +358,26 @@ struct llama_client_slot
         infill = false;
         clean_tokens();
 
-        if (grammar != nullptr) {
-            llama_grammar_free(grammar);
-            grammar = nullptr;
-            ctx_sampling.params = sparams;
-            ctx_sampling.grammar = NULL;
+        if (ctx_sampling != nullptr) {
+            llama_sampling_free(ctx_sampling);
         }
 
+        ctx_sampling = llama_sampling_init_srv(sparams, params.grammar, max_context_size);
+
         for (slot_image img : images) {
             free(img.image_embedding);
             delete[] img.img_data.data;
             img.prefix_prompt = "";
         }
+
         images.clear();
         // llama_set_rng_seed(ctx, params.seed); in batched the seed matter???????
     }
 
     bool loadGrammar(llama_token eos)
     {
-        if (!params.grammar.empty()) {
-            parsed_grammar = grammar_parser::parse(params.grammar.c_str());
-            // will be empty (default) if there are parse errors
-            if (parsed_grammar.rules.empty()) {
-                LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
-                return false;
-            }
-            grammar_parser::print_grammar(stderr, parsed_grammar);
-
-            {
-                auto it = sparams.logit_bias.find(eos);
-                if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
-                    LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
-                }
-            }
-
-            std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
-            grammar = llama_grammar_init(
-                grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
-        }
-        ctx_sampling.params = sparams;
-        ctx_sampling.grammar = grammar;
-        return true;
+        ctx_sampling = llama_sampling_init_srv(sparams, params.grammar, max_context_size);
+        return ctx_sampling != nullptr;
     }
 
     bool hasBudget(gpt_params &global_params) {
@@ -448,7 +453,6 @@ struct llama_server_context
     llama_model *model = nullptr;
    llama_context *ctx = nullptr;
    llama_batch batch;
-   std::vector<llama_token_data> candidates;
    bool all_slots_are_idle = false;
    gpt_params params;
    int n_ctx;
@@ -468,11 +472,6 @@ struct llama_server_context
             llama_free_model(model);
             model = nullptr;
         }
-        for (auto &slot : slots) {
-            if (slot.grammar) {
-                llama_grammar_free(slot.grammar);
-            }
-        }
     }
 
     bool loadModel(const gpt_params &params_)
@@ -510,7 +509,6 @@ struct llama_server_context
         }
         n_ctx = llama_n_ctx(ctx);
         n_vocab = llama_n_vocab(model);
-        candidates.reserve(n_vocab);
         return true;
     }
 
@@ -529,13 +527,12 @@ struct llama_server_context
         {
             llama_client_slot slot;
             slot.id = i;
-            slot.last_n_tokens.resize(max_ctx_per_slot);
-            std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
+            slot.max_context_size = max_ctx_per_slot;
             slot.reset();
             LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, max_ctx_per_slot);
             slots.push_back(slot);
         }
-        batch = llama_batch_init(n_ctx, 0);
+        batch = llama_batch_init(n_ctx, 0, 1);
         // empty system prompt
         system_prompt = "";
         num_tokens_system = 0;
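
The extra argument to llama_batch_init reflects the updated llama.h signature, assumed here to be:

    // assumed declaration at this revision of llama.h; the third parameter is the
    // maximum number of sequence ids a single batch entry can carry
    struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max);

Each entry added by the server belongs to exactly one sequence (one slot), so a capacity of 1 is sufficient.
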
@@ -626,10 +623,7 @@ struct llama_server_context
 
         for (int32_t i = 0; i < batch.n_tokens; ++i)
         {
-            batch.token [i] = tokens_system[i];
-            batch.pos   [i] = i;
-            batch.seq_id[i] = 0;
-            batch.logits[i] = false;
+            llama_batch_add(batch, tokens_system[i], i, { 0 }, false);
         }
 
         if (llama_decode(ctx, batch) != 0)
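
Throughout this diff the manual field writes are replaced by the llama_batch_add helper. A sketch of what common/common.cpp is assumed to provide at this revision (not code from the PR):

    void llama_batch_add(struct llama_batch & batch, llama_token id, llama_pos pos,
                         const std::vector<llama_seq_id> & seq_ids, bool logits) {
        batch.token   [batch.n_tokens] = id;
        batch.pos     [batch.n_tokens] = pos;
        batch.n_seq_id[batch.n_tokens] = seq_ids.size();
        for (size_t i = 0; i < seq_ids.size(); ++i) {
            batch.seq_id[batch.n_tokens][i] = seq_ids[i];
        }
        batch.logits  [batch.n_tokens] = logits;
        batch.n_tokens++;
    }

Note that the helper increments batch.n_tokens itself, which is why several call sites below no longer advance the counter explicitly.
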
@@ -726,8 +720,6 @@ struct llama_server_context
 
     bool processToken(completion_token_output & result, llama_client_slot & slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
-        slot.last_n_tokens.erase(slot.last_n_tokens.begin());
-        slot.last_n_tokens.push_back(result.tok);
         const std::string token_str = llama_token_to_piece(ctx, result.tok);
         slot.sampled = result.tok;
 
@@ -859,11 +851,12 @@ struct llama_server_context
             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
             llama_batch batch_view = {
                 n_tokens,
-                batch.token  + i,
+                batch.token    + i,
                 nullptr,
-                batch.pos    + i,
-                batch.seq_id + i,
-                batch.logits + i,
+                batch.pos      + i,
+                batch.n_seq_id + i,
+                batch.seq_id   + i,
+                batch.logits   + i,
                 0, 0, 0, // unused
             };
             if (llama_decode(ctx, batch_view)) {
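
The added n_seq_id pointer in batch_view has to appear in the same position as in the llama_batch declaration, since the view is built with a plain brace-initializer. The layout assumed by this diff (from llama.h at this revision) is roughly:

    typedef struct llama_batch {
        int32_t        n_tokens;
        llama_token  * token;
        float        * embd;
        llama_pos    * pos;
        int32_t      * n_seq_id;  // number of sequence ids per token
        llama_seq_id ** seq_id;   // per-token array of sequence ids
        int8_t       * logits;
        // helper fields used when the arrays above are null
        llama_pos    all_pos_0;
        llama_pos    all_pos_1;
        llama_seq_id all_seq_id;
    } llama_batch;
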
@@ -878,8 +871,8 @@ struct llama_server_context
                 if (n_eval > n_batch) {
                     n_eval = n_batch;
                 }
-                llama_batch batch = {int32_t(n_eval), nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
-                if (llama_decode(ctx, batch)) {
+                llama_batch batch_img = {int32_t(n_eval), nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
+                if (llama_decode(ctx, batch_img)) {
                     LOG_TEE("%s : failed to eval image\n", __func__);
                     return false;
                 }
@@ -894,10 +887,7 @@ struct llama_server_context
                 (json)(slot.images[image_idx].prefix_prompt);
             std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
             for (int i = 0; i < append_tokens.size(); ++i) {
-                batch.token [batch.n_tokens] = append_tokens[i];
-                batch.pos   [batch.n_tokens] = slot.n_past;
-                batch.seq_id[batch.n_tokens] = slot.id;
-                batch.logits[batch.n_tokens] = false;
+                llama_batch_add(batch, append_tokens[i], slot.n_past, { slot.id }, true);
                 slot.n_past += 1;
                 batch.n_tokens += 1;
             }
@@ -922,7 +912,6 @@ struct llama_server_context
             std::this_thread::sleep_for(std::chrono::milliseconds(5));
         }
 
-        // context shift takes effect only when there is a single slot
         for (llama_client_slot &slot : slots) {
             if (slot.isProcessing() && slot.cache_tokens.size() >= (size_t)max_ctx_per_slot)
             {
@@ -976,16 +965,12 @@ struct llama_server_context
                 continue;
             }
 
-            batch.token [batch.n_tokens] = slot.sampled;
-            batch.pos   [batch.n_tokens] = num_tokens_system + slot.n_past;
-            batch.seq_id[batch.n_tokens] = slot.id;
-            batch.logits[batch.n_tokens] = true;
+            slot.i_batch = batch.n_tokens;
+
+            llama_batch_add(batch, slot.sampled, num_tokens_system + slot.n_past, { slot.id }, true);
 
             slot.n_decoded += 1;
-            slot.i_batch = batch.n_tokens;
             slot.n_past += 1;
-
-            batch.n_tokens += 1;
         }
         // process in chunks of params.n_batch
         int32_t n_batch = params.n_batch;
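
One detail in the hunk above: slot.i_batch is now captured before the call, because llama_batch_add (see the sketch earlier, an assumption about common/common.cpp) bumps batch.n_tokens itself, so the pre-call value is the index at which this slot's logits will land:

    slot.i_batch = batch.n_tokens;   // index of this slot's logits in the batch
    llama_batch_add(batch, slot.sampled, num_tokens_system + slot.n_past, { slot.id }, true);

That index is later passed as slot.i_batch - i to llama_sampling_sample when the batch is decoded in chunks.
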
@@ -1026,7 +1011,7 @@ struct llama_server_context
                 slot.num_prompt_tokens = prompt_tokens.size();
 
                 if (!slot.params.cache_prompt) {
-                    std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
+                    std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end(), 0);
                     slot.n_past = 0;
                     slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
                 } else {
@@ -1038,23 +1023,27 @@ struct llama_server_context
                 // if input prompt is too big, truncate like normal
                 if (slot.num_prompt_tokens >= (size_t)max_ctx_per_slot)
                 {
+                    // applied fix of #3661
                     const int n_left = max_ctx_per_slot - slot.params.n_keep;
+                    const int n_block_size = n_left / 2;
+                    const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
                     std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
                     // Use half the left-over space in the context for the prompt
-                    new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left / 2, prompt_tokens.end());
+                    new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
                     LOG_VERBOSE("input truncated", {
-                        {"n_ctx", n_ctx},
-                        {"n_keep", params.n_keep},
+                        {"n_ctx", max_ctx_per_slot},
+                        {"n_keep", slot.params.n_keep},
                         {"n_left", n_left},
                         {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
                     });
                     slot.truncated = true;
                     prompt_tokens = new_tokens;
                     slot.num_prompt_tokens = prompt_tokens.size();
+                    GGML_ASSERT(slot.num_prompt_tokens < (size_t)max_ctx_per_slot);
                 }
                 const size_t ps = slot.num_prompt_tokens;
-                std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end() - ps, 0);
-                std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.last_n_tokens.end() - ps);
+                std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end() - ps, 0);
+                std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.ctx_sampling->prev.end() - ps);
                 slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
                 slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
                 LOG_TEE("slot %i - in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
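
To make the truncation arithmetic above concrete, a worked example with hypothetical numbers (not taken from the PR):

    // max_ctx_per_slot = 512, slot.params.n_keep = 64, prompt of 1000 tokens
    // n_left        = 512 - 64                        = 448
    // n_block_size  = 448 / 2                         = 224
    // erased_blocks = (1000 - 64 - 224) / 224         = 3   (integer division)
    // kept tail     = prompt_tokens[64 + 3*224, 1000) = tokens [736, 1000) = 264 tokens
    // new prompt    = 64 (n_keep) + 264               = 328 tokens, below 512, so the GGML_ASSERT holds
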
@@ -1081,11 +1070,7 @@ struct llama_server_context
                 // process the prefix of first image
                 std::vector<llama_token> prefix_tokens = ingest_images ? tokenize(slot.images[0].prefix_prompt, true) : prompt_tokens;
                 for (; slot.n_past < prefix_tokens.size(); ++slot.n_past) {
-                    batch.token [batch.n_tokens] = prefix_tokens[slot.n_past];
-                    batch.pos   [batch.n_tokens] = slot.n_past + num_tokens_system;
-                    batch.seq_id[batch.n_tokens] = slot.id;
-                    batch.logits[batch.n_tokens] = false;
-                    batch.n_tokens += 1;
+                    llama_batch_add(batch, prefix_tokens[slot.n_past], num_tokens_system + slot.n_past, { slot.id }, false);
                 }
 
                 if (ingest_images && !ingestImages(slot, n_batch)) {
@@ -1113,11 +1098,12 @@ struct llama_server_context
             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
             llama_batch batch_view = {
                 n_tokens,
-                batch.token  + i,
+                batch.token    + i,
                 nullptr,
-                batch.pos    + i,
-                batch.seq_id + i,
-                batch.logits + i,
+                batch.pos      + i,
+                batch.n_seq_id + i,
+                batch.seq_id   + i,
+                batch.logits   + i,
                 0, 0, 0, // unused
             };
 
@@ -1150,25 +1136,27 @@ struct llama_server_context
             }
 
             completion_token_output result;
-            const llama_token id = llama_sampling_sample(ctx, NULL, slot.ctx_sampling, slot.last_n_tokens, candidates, slot.i_batch - i);
+            const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
+
+            llama_sampling_accept(slot.ctx_sampling, ctx, id);
 
             if (slot.n_decoded == 1) {
                 slot.t_start_genereration = ggml_time_us();
                 slot.t_prompt_processing = (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3;
             }
 
-            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+            llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
             result.tok = id;
             const int32_t n_probs = slot.sparams.n_probs;
             if (slot.sparams.temp <= 0 && n_probs > 0)
             {
                 // For llama_sample_token_greedy we need to sort candidates
-                llama_sample_softmax(ctx, &candidates_p);
+                llama_sample_softmax(ctx, &cur_p);
             }
 
-            for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
+            for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i)
             {
-                result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
+                result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
             }
 
             if (!processToken(result, slot)) {
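
The n_probs handling now reads the candidate list from slot.ctx_sampling->cur instead of a server-owned candidates vector. This relies on llama_sampling_sample filling that vector internally; a sketch of the assumed common/sampling.cpp behavior at this revision:

    // inside llama_sampling_sample, before the samplers run:
    auto & cur = ctx_sampling->cur;
    cur.clear();
    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
    }
    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
    // ... penalties, grammar, and temperature-based samplers operate on cur_p ...

After the call, slot.ctx_sampling->cur therefore still holds the (possibly re-ordered) candidates that the probability loop above iterates over.
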