Skip to content

Commit d7d2e6a

Browse files
server: add option to output probabilities for completion (#1962)
* server: add option to output probabilities for completion * server: fix issue when handling probability output for incomplete tokens for multibyte character generation * server: fix llama_sample_top_k order * examples/common.h: put all bool variables in gpt_params together
1 parent 46088f7 commit d7d2e6a

File tree

2 files changed

+122
-31
lines changed

2 files changed

+122
-31
lines changed

examples/common.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ struct gpt_params {
3131
int32_t n_gpu_layers = 0; // number of layers to store in VRAM
3232
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
3333
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
34-
bool low_vram = 0; // if true, reduce VRAM usage at the cost of performance
34+
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
3535

3636
// sampling parameters
3737
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
@@ -59,6 +59,7 @@ struct gpt_params {
5959
std::string lora_adapter = ""; // lora adapter path
6060
std::string lora_base = ""; // base model path for the lora adapter
6161

62+
bool low_vram = false; // if true, reduce VRAM usage at the cost of performance
6263
bool memory_f16 = true; // use f16 instead of f32 for memory kv
6364
bool random_prompt = false; // do not randomize prompt if none provided
6465
bool use_color = false; // use color to distinguish generations and inputs

examples/server/server.cpp

Lines changed: 120 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,17 @@ struct server_params {
2626
int32_t write_timeout = 600;
2727
};
2828

29+
// completion token output with probabilities
30+
struct completion_token_output {
31+
struct token_prob {
32+
llama_token tok;
33+
float prob;
34+
};
35+
36+
std::vector<token_prob> probs;
37+
llama_token tok;
38+
};
39+
2940
static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
3041
size_t i;
3142
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
@@ -86,6 +97,40 @@ static void server_log(const char * level, const char * function, int line,
8697
fflush(stdout);
8798
}
8899

100+
// format incomplete utf-8 multibyte character for output
101+
static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) {
102+
std::string out = token == -1 ? "" : llama_token_to_str(ctx, token);
103+
// if first bit is 1, meaning it's a partial character
104+
if (out.size() > 0 && (out[0] & 0x80) == 0x80) {
105+
std::stringstream ss;
106+
ss<< std::hex << (out[0] & 0xff);
107+
std::string res ( ss.str() );
108+
out = "byte: \\x" + res;
109+
}
110+
return out;
111+
}
112+
113+
// convert a vector of completion_token_output to json
114+
static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> probs) {
115+
json out = json::array();
116+
for (const auto & prob : probs) {
117+
json probs_for_token = json::array();
118+
for (const auto & p : prob.probs) {
119+
std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
120+
probs_for_token.push_back(json {
121+
{ "tok_str", tok_str },
122+
{ "prob", p.prob },
123+
});
124+
}
125+
std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
126+
out.push_back(json {
127+
{"content", tok_str},
128+
{"probs", probs_for_token},
129+
});
130+
}
131+
return out;
132+
}
133+
89134
static bool server_verbose = false;
90135

91136
#if SERVER_VERBOSE != 1
@@ -107,6 +152,7 @@ struct llama_server_context {
107152
bool stream = false;
108153
bool has_next_token = false;
109154
std::string generated_text;
155+
std::vector<completion_token_output> generated_token_probs;
110156

111157
size_t num_tokens_predicted = 0;
112158
size_t n_past = 0;
@@ -142,6 +188,7 @@ struct llama_server_context {
142188
num_tokens_predicted = 0;
143189
generated_text = "";
144190
generated_text.reserve(params.n_ctx);
191+
generated_token_probs.clear();
145192
truncated = false;
146193
stopped_eos = false;
147194
stopped_word = false;
@@ -221,8 +268,9 @@ struct llama_server_context {
221268
llama_set_rng_seed(ctx, params.seed);
222269
}
223270

224-
llama_token nextToken() {
225-
llama_token result = -1;
271+
completion_token_output nextToken() {
272+
completion_token_output result;
273+
result.tok = -1;
226274

227275
if (embd.size() >= (size_t)params.n_ctx) {
228276
// Reset context
@@ -261,7 +309,8 @@ struct llama_server_context {
261309

262310
if (params.n_predict == 0) {
263311
has_next_token = false;
264-
return llama_token_eos();
312+
result.tok = llama_token_eos();
313+
return result;
265314
}
266315

267316
// out of user input, sample next token
@@ -278,7 +327,7 @@ struct llama_server_context {
278327
const float mirostat_tau = params.mirostat_tau;
279328
const float mirostat_eta = params.mirostat_eta;
280329
const bool penalize_nl = params.penalize_nl;
281-
llama_token id = 0;
330+
const int32_t n_probs = params.n_probs;
282331

283332
{
284333
auto * logits = llama_get_logits(ctx);
@@ -312,35 +361,42 @@ struct llama_server_context {
312361

313362
if (temp <= 0) {
314363
// Greedy sampling
315-
id = llama_sample_token_greedy(ctx, &candidates_p);
364+
result.tok = llama_sample_token_greedy(ctx, &candidates_p);
365+
if (n_probs > 0) {
366+
llama_sample_softmax(ctx, &candidates_p);
367+
}
316368
} else {
317369
if (mirostat == 1) {
318370
static float mirostat_mu = 2.0f * mirostat_tau;
319371
const int mirostat_m = 100;
320372
llama_sample_temperature(ctx, &candidates_p, temp);
321-
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
373+
result.tok = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
322374
} else if (mirostat == 2) {
323375
static float mirostat_mu = 2.0f * mirostat_tau;
324376
llama_sample_temperature(ctx, &candidates_p, temp);
325-
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
377+
result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
326378
} else {
327379
// Temperature sampling
328-
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
329-
llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
330-
llama_sample_typical(ctx, &candidates_p, typical_p, 1);
331-
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
380+
size_t min_keep = std::max(1, n_probs);
381+
llama_sample_top_k(ctx, &candidates_p, top_k, min_keep);
382+
llama_sample_tail_free(ctx, &candidates_p, tfs_z, min_keep);
383+
llama_sample_typical(ctx, &candidates_p, typical_p, min_keep);
384+
llama_sample_top_p(ctx, &candidates_p, top_p, min_keep);
332385
llama_sample_temperature(ctx, &candidates_p, temp);
333-
id = llama_sample_token(ctx, &candidates_p);
386+
result.tok = llama_sample_token(ctx, &candidates_p);
334387
}
335388
}
389+
390+
for (size_t i = 0; i < std::min(candidates_p.size, (size_t) n_probs); ++i) {
391+
result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
392+
}
336393
last_n_tokens.erase(last_n_tokens.begin());
337-
last_n_tokens.push_back(id);
394+
last_n_tokens.push_back(result.tok);
338395
num_tokens_predicted++;
339396
}
340397

341398
// add it to the context
342-
embd.push_back(id);
343-
result = id;
399+
embd.push_back(result.tok);
344400
// decrement remaining sampling budget
345401
--n_remain;
346402

@@ -382,12 +438,16 @@ struct llama_server_context {
382438
return stop_pos;
383439
}
384440

385-
std::string doCompletion() {
386-
const llama_token token = nextToken();
441+
completion_token_output doCompletion() {
442+
const completion_token_output token_with_probs = nextToken();
387443

388-
const std::string token_text = token == -1 ? "" : llama_token_to_str(ctx, token);
444+
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(ctx, token_with_probs.tok);
389445
generated_text += token_text;
390446

447+
if (params.n_probs > 0) {
448+
generated_token_probs.push_back(token_with_probs);
449+
}
450+
391451
if (multibyte_pending > 0) {
392452
multibyte_pending -= token_text.size();
393453
} else if (token_text.size() == 1) {
@@ -416,8 +476,8 @@ struct llama_server_context {
416476
}
417477

418478
LOG_VERBOSE("next token", {
419-
{ "token", token },
420-
{ "token_text", llama_token_to_str(ctx, token) },
479+
{ "token", token_with_probs.tok },
480+
{ "token_text", tokens_to_output_formatted_string(ctx, token_with_probs.tok) },
421481
{ "has_next_token", has_next_token },
422482
{ "n_remain", n_remain },
423483
{ "num_tokens_predicted", num_tokens_predicted },
@@ -427,7 +487,7 @@ struct llama_server_context {
427487
{ "stopping_word", stopping_word },
428488
});
429489

430-
return token_text;
490+
return token_with_probs;
431491
}
432492

433493
std::vector<float> getEmbedding() {
@@ -669,6 +729,7 @@ static json format_generation_settings(llama_server_context & llama) {
669729
{ "ignore_eos", ignore_eos },
670730
{ "stream", llama.stream },
671731
{ "logit_bias", llama.params.logit_bias },
732+
{ "n_probs", llama.params.n_probs },
672733
};
673734
}
674735

@@ -678,8 +739,9 @@ static json format_embedding_response(llama_server_context & llama) {
678739
};
679740
}
680741

681-
static json format_final_response(llama_server_context & llama, const std::string & content) {
682-
return json {
742+
static json format_final_response(llama_server_context & llama, const std::string & content, const std::vector<completion_token_output> & probs) {
743+
744+
json res = json {
683745
{ "content", content },
684746
{ "stop", true },
685747
{ "model", llama.params.model_alias },
@@ -692,13 +754,25 @@ static json format_final_response(llama_server_context & llama, const std::strin
692754
{ "stopped_limit", llama.stopped_limit },
693755
{ "stopping_word", llama.stopping_word },
694756
};
757+
758+
if (llama.params.n_probs > 0) {
759+
res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
760+
}
761+
762+
return res;
695763
}
696764

697-
static json format_partial_response(const std::string & content) {
698-
return json {
765+
static json format_partial_response(llama_server_context & llama, const std::string & content, const std::vector<completion_token_output> & probs) {
766+
json res = json {
699767
{ "content", content },
700768
{ "stop", false },
701769
};
770+
771+
if (llama.params.n_probs > 0) {
772+
res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
773+
}
774+
775+
return res;
702776
}
703777

704778
static json format_tokenizer_response(const std::vector<llama_token> & tokens) {
@@ -728,6 +802,7 @@ static void parse_options_completion(const json & body, llama_server_context & l
728802
llama.params.n_keep = body.value("n_keep", default_params.n_keep);
729803
llama.params.seed = body.value("seed", default_params.seed);
730804
llama.params.prompt = body.value("prompt", default_params.prompt);
805+
llama.params.n_probs = body.value("n_probs", default_params.n_probs);
731806

732807
llama.params.logit_bias.clear();
733808
if (body.value("ignore_eos", false)) {
@@ -830,7 +905,8 @@ int main(int argc, char ** argv) {
830905
size_t stop_pos = std::string::npos;
831906

832907
while (llama.has_next_token) {
833-
const std::string token_text = llama.doCompletion();
908+
const completion_token_output token_with_probs = llama.doCompletion();
909+
const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok);
834910

835911
stop_pos = llama.findStoppingStrings(llama.generated_text,
836912
token_text.size(), STOP_FULL);
@@ -844,7 +920,7 @@ int main(int argc, char ** argv) {
844920
llama.generated_text.end());
845921
}
846922

847-
const json data = format_final_response(llama, llama.generated_text);
923+
const json data = format_final_response(llama, llama.generated_text, llama.generated_token_probs);
848924

849925
llama_print_timings(llama.ctx);
850926

@@ -853,9 +929,11 @@ int main(int argc, char ** argv) {
853929
} else {
854930
const auto chunked_content_provider = [&](size_t, DataSink & sink) {
855931
size_t sent_count = 0;
932+
size_t sent_token_probs_index = 0;
856933

857934
while (llama.has_next_token) {
858-
const std::string token_text = llama.doCompletion();
935+
const completion_token_output token_with_probs = llama.doCompletion();
936+
const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok);
859937
if (llama.multibyte_pending > 0) {
860938
continue;
861939
}
@@ -878,10 +956,22 @@ int main(int argc, char ** argv) {
878956
const std::string to_send = llama.generated_text.substr(pos, stop_pos);
879957
sent_count += to_send.size();
880958

959+
std::vector<completion_token_output> probs_output = {};
960+
961+
if (llama.params.n_probs > 0) {
962+
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
963+
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
964+
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
965+
if (probs_pos < probs_stop_pos) {
966+
probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
967+
}
968+
sent_token_probs_index = probs_stop_pos;
969+
}
970+
881971
const json data = llama.has_next_token
882-
? format_partial_response(to_send)
972+
? format_partial_response(llama, to_send, probs_output)
883973
// Generation is done, send extra information.
884-
: format_final_response(llama, to_send);
974+
: format_final_response(llama, to_send, llama.generated_token_probs);
885975

886976
const std::string str =
887977
"data: " +

0 commit comments

Comments
 (0)