@@ -26,6 +26,17 @@ struct server_params {
     int32_t write_timeout = 600;
 };
 
+// completion token output with probabilities
+struct completion_token_output {
+    struct token_prob {
+        llama_token tok;
+        float prob;
+    };
+
+    std::vector<token_prob> probs;
+    llama_token tok;
+};
+
 static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
     size_t i;
     for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
@@ -86,6 +97,40 @@ static void server_log(const char * level, const char * function, int line,
     fflush(stdout);
 }
 
+// format incomplete utf-8 multibyte character for output
+static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) {
+    std::string out = token == -1 ? "" : llama_token_to_str(ctx, token);
+    // if first bit is 1, meaning it's a partial character
+    if (out.size() > 0 && (out[0] & 0x80) == 0x80) {
+        std::stringstream ss;
+        ss << std::hex << (out[0] & 0xff);
+        std::string res(ss.str());
+        out = "byte: \\x" + res;
+    }
+    return out;
+}
+
+// convert a vector of completion_token_output to json
+static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> probs) {
+    json out = json::array();
+    for (const auto & prob : probs) {
+        json probs_for_token = json::array();
+        for (const auto & p : prob.probs) {
+            std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
+            probs_for_token.push_back(json {
+                { "tok_str", tok_str },
+                { "prob", p.prob },
+            });
+        }
+        std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
+        out.push_back(json {
+            { "content", tok_str },
+            { "probs", probs_for_token },
+        });
+    }
+    return out;
+}
+
 static bool server_verbose = false;
 
 #if SERVER_VERBOSE != 1
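For illustration only (not part of the diff): probs_vector_to_json() above emits one JSON object per sampled token, with the token text under "content" and its top-n candidates under "probs". The standalone sketch below just prints an example of that shape using nlohmann::json (the same `json` alias server.cpp uses); the concrete strings and numbers are invented, only the key names come from the code.

// Illustration only -- shape of the array produced by probs_vector_to_json().
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // Invented values; only the key names mirror the code above.
    const json example = json::parse(R"([
        {
            "content": "Hello",
            "probs": [
                { "tok_str": "Hello", "prob": 0.91 },
                { "tok_str": "Hi",    "prob": 0.06 }
            ]
        }
    ])");
    std::cout << example.dump(2) << std::endl;
    return 0;
}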
@@ -107,6 +152,7 @@ struct llama_server_context {
     bool stream = false;
     bool has_next_token = false;
     std::string generated_text;
+    std::vector<completion_token_output> generated_token_probs;
 
     size_t num_tokens_predicted = 0;
     size_t n_past = 0;
@@ -142,6 +188,7 @@ struct llama_server_context {
         num_tokens_predicted = 0;
         generated_text = "";
         generated_text.reserve(params.n_ctx);
+        generated_token_probs.clear();
         truncated = false;
         stopped_eos = false;
         stopped_word = false;
@@ -221,8 +268,9 @@ struct llama_server_context {
         llama_set_rng_seed(ctx, params.seed);
     }
 
-    llama_token nextToken() {
-        llama_token result = -1;
+    completion_token_output nextToken() {
+        completion_token_output result;
+        result.tok = -1;
 
         if (embd.size() >= (size_t)params.n_ctx) {
             // Reset context
@@ -261,7 +309,8 @@ struct llama_server_context {
 
         if (params.n_predict == 0) {
             has_next_token = false;
-            return llama_token_eos();
+            result.tok = llama_token_eos();
+            return result;
         }
 
         // out of user input, sample next token
@@ -278,7 +327,7 @@ struct llama_server_context {
         const float mirostat_tau = params.mirostat_tau;
         const float mirostat_eta = params.mirostat_eta;
         const bool penalize_nl = params.penalize_nl;
-        llama_token id = 0;
+        const int32_t n_probs = params.n_probs;
 
         {
             auto * logits = llama_get_logits(ctx);
@@ -312,35 +361,42 @@ struct llama_server_context {
 
             if (temp <= 0) {
                 // Greedy sampling
-                id = llama_sample_token_greedy(ctx, &candidates_p);
+                result.tok = llama_sample_token_greedy(ctx, &candidates_p);
+                if (n_probs > 0) {
+                    llama_sample_softmax(ctx, &candidates_p);
+                }
             } else {
                 if (mirostat == 1) {
                     static float mirostat_mu = 2.0f * mirostat_tau;
                     const int mirostat_m = 100;
                     llama_sample_temperature(ctx, &candidates_p, temp);
-                    id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+                    result.tok = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
                 } else if (mirostat == 2) {
                     static float mirostat_mu = 2.0f * mirostat_tau;
                     llama_sample_temperature(ctx, &candidates_p, temp);
-                    id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+                    result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
                 } else {
                     // Temperature sampling
-                    llama_sample_top_k(ctx, &candidates_p, top_k, 1);
-                    llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
-                    llama_sample_typical(ctx, &candidates_p, typical_p, 1);
-                    llama_sample_top_p(ctx, &candidates_p, top_p, 1);
+                    size_t min_keep = std::max(1, n_probs);
+                    llama_sample_top_k(ctx, &candidates_p, top_k, min_keep);
+                    llama_sample_tail_free(ctx, &candidates_p, tfs_z, min_keep);
+                    llama_sample_typical(ctx, &candidates_p, typical_p, min_keep);
+                    llama_sample_top_p(ctx, &candidates_p, top_p, min_keep);
                     llama_sample_temperature(ctx, &candidates_p, temp);
-                    id = llama_sample_token(ctx, &candidates_p);
+                    result.tok = llama_sample_token(ctx, &candidates_p);
                 }
             }
+
+            for (size_t i = 0; i < std::min(candidates_p.size, (size_t) n_probs); ++i) {
+                result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
+            }
             last_n_tokens.erase(last_n_tokens.begin());
-            last_n_tokens.push_back(id);
+            last_n_tokens.push_back(result.tok);
             num_tokens_predicted++;
         }
 
         // add it to the context
-        embd.push_back(id);
-        result = id;
+        embd.push_back(result.tok);
         // decrement remaining sampling budget
         --n_remain;
 
@@ -382,12 +438,16 @@ struct llama_server_context {
         return stop_pos;
     }
 
-    std::string doCompletion() {
-        const llama_token token = nextToken();
+    completion_token_output doCompletion() {
+        const completion_token_output token_with_probs = nextToken();
 
-        const std::string token_text = token == -1 ? "" : llama_token_to_str(ctx, token);
+        const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(ctx, token_with_probs.tok);
         generated_text += token_text;
 
+        if (params.n_probs > 0) {
+            generated_token_probs.push_back(token_with_probs);
+        }
+
         if (multibyte_pending > 0) {
             multibyte_pending -= token_text.size();
         } else if (token_text.size() == 1) {
@@ -416,8 +476,8 @@ struct llama_server_context {
         }
 
         LOG_VERBOSE("next token", {
-            { "token", token },
-            { "token_text", llama_token_to_str(ctx, token) },
+            { "token", token_with_probs.tok },
+            { "token_text", tokens_to_output_formatted_string(ctx, token_with_probs.tok) },
             { "has_next_token", has_next_token },
             { "n_remain", n_remain },
             { "num_tokens_predicted", num_tokens_predicted },
@@ -427,7 +487,7 @@ struct llama_server_context {
             { "stopping_word", stopping_word },
         });
 
-        return token_text;
+        return token_with_probs;
     }
 
     std::vector<float> getEmbedding() {
@@ -669,6 +729,7 @@ static json format_generation_settings(llama_server_context & llama) {
         { "ignore_eos", ignore_eos },
         { "stream", llama.stream },
         { "logit_bias", llama.params.logit_bias },
+        { "n_probs", llama.params.n_probs },
     };
 }
 
@@ -678,8 +739,9 @@ static json format_embedding_response(llama_server_context & llama) {
     };
 }
 
-static json format_final_response(llama_server_context & llama, const std::string & content) {
-    return json {
+static json format_final_response(llama_server_context & llama, const std::string & content, const std::vector<completion_token_output> & probs) {
+
+    json res = json {
         { "content", content },
         { "stop", true },
         { "model", llama.params.model_alias },
@@ -692,13 +754,25 @@ static json format_final_response(llama_server_context & llama, const std::strin
         { "stopped_limit", llama.stopped_limit },
         { "stopping_word", llama.stopping_word },
     };
+
+    if (llama.params.n_probs > 0) {
+        res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
+    }
+
+    return res;
 }
 
-static json format_partial_response(const std::string & content) {
-    return json {
+static json format_partial_response(llama_server_context & llama, const std::string & content, const std::vector<completion_token_output> & probs) {
+    json res = json {
         { "content", content },
         { "stop", false },
     };
+
+    if (llama.params.n_probs > 0) {
+        res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
+    }
+
+    return res;
 }
 
 static json format_tokenizer_response(const std::vector<llama_token> & tokens) {
@@ -728,6 +802,7 @@ static void parse_options_completion(const json & body, llama_server_context & l
     llama.params.n_keep = body.value("n_keep", default_params.n_keep);
     llama.params.seed = body.value("seed", default_params.seed);
     llama.params.prompt = body.value("prompt", default_params.prompt);
+    llama.params.n_probs = body.value("n_probs", default_params.n_probs);
 
     llama.params.logit_bias.clear();
    if (body.value("ignore_eos", false)) {
@@ -830,7 +905,8 @@ int main(int argc, char ** argv) {
             size_t stop_pos = std::string::npos;
 
             while (llama.has_next_token) {
-                const std::string token_text = llama.doCompletion();
+                const completion_token_output token_with_probs = llama.doCompletion();
+                const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok);
 
                 stop_pos = llama.findStoppingStrings(llama.generated_text,
                     token_text.size(), STOP_FULL);
@@ -844,7 +920,7 @@ int main(int argc, char ** argv) {
                     llama.generated_text.end());
             }
 
-            const json data = format_final_response(llama, llama.generated_text);
+            const json data = format_final_response(llama, llama.generated_text, llama.generated_token_probs);
 
             llama_print_timings(llama.ctx);
 
@@ -853,9 +929,11 @@ int main(int argc, char ** argv) {
        } else {
            const auto chunked_content_provider = [&](size_t, DataSink & sink) {
                size_t sent_count = 0;
+                size_t sent_token_probs_index = 0;
 
                while (llama.has_next_token) {
-                    const std::string token_text = llama.doCompletion();
+                    const completion_token_output token_with_probs = llama.doCompletion();
+                    const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok);
                    if (llama.multibyte_pending > 0) {
                        continue;
                    }
@@ -878,10 +956,22 @@ int main(int argc, char ** argv) {
                    const std::string to_send = llama.generated_text.substr(pos, stop_pos);
                    sent_count += to_send.size();
 
+                    std::vector<completion_token_output> probs_output = {};
+
+                    if (llama.params.n_probs > 0) {
+                        const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
+                        size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
+                        size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
+                        if (probs_pos < probs_stop_pos) {
+                            probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
+                        }
+                        sent_token_probs_index = probs_stop_pos;
+                    }
+
                    const json data = llama.has_next_token
-                        ? format_partial_response(to_send)
+                        ? format_partial_response(llama, to_send, probs_output)
                        // Generation is done, send extra information.
-                        : format_final_response(llama, to_send);
+                        : format_final_response(llama, to_send, llama.generated_token_probs);
 
                    const std::string str =
                        "data: " +
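For context, and again only as an illustration: with "n_probs" wired through parse_options_completion() and the two response formatters, a client sets the new field in its request body and treats "completion_probabilities" in the response as optional. A minimal sketch of that client side, assuming nlohmann::json; the HTTP transport and endpoint are outside this diff and are deliberately left out, and only the field names "prompt", "n_predict", "n_probs", "completion_probabilities", "content", "probs", "tok_str", and "prob" come from the code above.

// Client-side sketch (illustration only; transport not shown).
#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Build a completion request asking for the top-5 candidates per sampled token.
static json make_completion_request(const std::string & prompt) {
    return json {
        { "prompt",    prompt },
        { "n_predict", 32 },
        { "n_probs",   5 },
    };
}

// The server only attaches "completion_probabilities" when n_probs > 0,
// so treat the field as optional when reading a response body.
static void report_probabilities(const json & response) {
    if (!response.contains("completion_probabilities")) {
        std::cout << "no per-token probabilities in this response\n";
        return;
    }
    for (const auto & tok : response["completion_probabilities"]) {
        std::cout << tok["content"].get<std::string>() << '\n';
        for (const auto & p : tok["probs"]) {
            std::cout << "  " << p["tok_str"].get<std::string>()
                      << " : " << p["prob"].get<double>() << '\n';
        }
    }
}

int main() {
    std::cout << make_completion_request("Hello, world").dump(2) << '\n';
    report_probabilities(json::parse(R"({ "content": "hi", "stop": true })"));
    return 0;
}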