1 change: 1 addition & 0 deletions examples/common.h
@@ -31,6 +31,7 @@ struct gpt_params {
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
bool low_vram = 0; // if true, reduce VRAM usage at the cost of performance
int32_t n_probs = 0; // if greater than 0, output the probabilities of the top n_probs tokens (max 5)

// sampling parameters
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
126 changes: 99 additions & 27 deletions examples/server/server.cpp
@@ -26,6 +26,28 @@ struct server_params {
int32_t write_timeout = 600;
};

// completion string output with probabilities
struct completion_string_output {
struct token_prob {
std::string tok_str;
float prob;
};

std::vector<token_prob> probs;
std::string tok_str;
};

// completion token output with probabilities
struct completion_token_output {
struct token_prob {
llama_token tok;
float prob;
};

std::vector<token_prob> probs;
llama_token tok;
};

static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
@@ -107,6 +129,7 @@ struct llama_server_context {
bool stream = false;
bool has_next_token = false;
std::string generated_text;
std::vector<completion_string_output> generated_text_probs;

size_t num_tokens_predicted = 0;
size_t n_past = 0;
@@ -137,6 +160,7 @@ struct llama_server_context {
num_tokens_predicted = 0;
generated_text = "";
generated_text.reserve(params.n_ctx);
generated_text_probs.clear();
truncated = false;
stopped_eos = false;
stopped_word = false;
@@ -216,8 +240,9 @@ struct llama_server_context {
llama_set_rng_seed(ctx, params.seed);
}

llama_token nextToken() {
llama_token result = -1;
completion_token_output nextToken() {
completion_token_output result;
result.tok = -1;

if (embd.size() >= (size_t)params.n_ctx) {
// Reset context
@@ -256,7 +281,8 @@ struct llama_server_context {

if (params.n_predict == 0) {
has_next_token = false;
return llama_token_eos();
result.tok = llama_token_eos();
return result;
}

// out of user input, sample next token
@@ -273,7 +299,7 @@ struct llama_server_context {
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;
llama_token id = 0;
const int32_t n_probs = params.n_probs;

{
auto * logits = llama_get_logits(ctx);
@@ -307,35 +333,37 @@ struct llama_server_context {

if (temp <= 0) {
// Greedy sampling
id = llama_sample_token_greedy(ctx, &candidates_p);
result.tok = llama_sample_token_greedy(ctx, &candidates_p);
} else {
if (mirostat == 1) {
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
result.tok = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
} else if (mirostat == 2) {
static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
} else {
// Temperature sampling
llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
llama_sample_typical(ctx, &candidates_p, typical_p, 1);
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token(ctx, &candidates_p);
result.tok = llama_sample_token(ctx, &candidates_p);
}
}
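// record the top candidates (at most n_probs, capped at 5) together with their probabilities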
for (size_t i = 0; i < std::min(candidates_p.size, std::min((size_t) n_probs, size_t(5))); ++i) {
result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
}
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
last_n_tokens.push_back(result.tok);
num_tokens_predicted++;
}

// add it to the context
embd.push_back(id);
result = id;
embd.push_back(result.tok);
// decrement remaining sampling budget
--n_remain;

@@ -377,12 +405,22 @@ struct llama_server_context {
return stop_pos;
}

std::string doCompletion() {
const llama_token token = nextToken();
completion_string_output doCompletion() {
const completion_token_output token_with_probs = nextToken();
completion_string_output result;

const std::string token_text = token == -1 ? "" : llama_token_to_str(ctx, token);
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(ctx, token_with_probs.tok);
result.tok_str = token_text;
generated_text += token_text;

// for each candidate in token_with_probs.probs, convert valid tokens to their string form and add them to result.probs
for (const auto & prob : token_with_probs.probs) {
const std::string prob_text = prob.tok == -1 ? "" : llama_token_to_str(ctx, prob.tok);
result.probs.push_back({prob_text, prob.prob});
}

generated_text_probs.push_back(result);

if (multibyte_pending > 0) {
multibyte_pending -= token_text.size();
} else if (token_text.size() == 1) {
@@ -411,8 +449,8 @@ struct llama_server_context {
}

LOG_VERBOSE("next token", {
{ "token", token },
{ "token_text", llama_token_to_str(ctx, token) },
{ "token", token_with_probs.tok },
{ "token_text", llama_token_to_str(ctx, token_with_probs.tok) },
{ "has_next_token", has_next_token },
{ "n_remain", n_remain },
{ "num_tokens_predicted", num_tokens_predicted },
Expand All @@ -422,7 +460,7 @@ struct llama_server_context {
{ "stopping_word", stopping_word },
});

return token_text;
return result;
}

std::vector<float> getEmbedding() {
@@ -664,6 +702,7 @@ static json format_generation_settings(llama_server_context & llama) {
{ "ignore_eos", ignore_eos },
{ "stream", llama.stream },
{ "logit_bias", llama.params.logit_bias },
{ "n_probs", llama.params.n_probs },
};
}

@@ -673,9 +712,26 @@ static json format_embedding_response(llama_server_context & llama) {
};
}

static json format_final_response(llama_server_context & llama, const std::string & content) {
static json format_final_response(llama_server_context & llama, const std::string & content, const std::vector<completion_string_output> & probs) {

json completion_probabilities_json = json::array();
for (const auto & prob : probs) {
json probs_for_token = json::array();
for (const auto & p : prob.probs) {
probs_for_token.push_back(json {
{ "tok_str", p.tok_str },
{ "prob", p.prob },
});
}
completion_probabilities_json.push_back(json {
{"content", prob.tok_str},
{"probs", probs_for_token},
});
}

return json {
{ "content", content },
{ "completion_probabilities", completion_probabilities_json},
{ "stop", true },
{ "model", llama.params.model_alias },
{ "tokens_predicted", llama.num_tokens_predicted },
@@ -689,11 +745,25 @@ static json format_final_response(llama_server_context & llama, const std::strin
};
}

static json format_partial_response(const std::string & content) {
return json {
static json format_partial_response(const std::string & content, const completion_string_output & probs) {
json res = json {
{ "content", content },
{ "stop", false },
};

// collect the candidate probabilities for this token into a JSON array and attach it to the response
json probs_json = json::array();
for (const auto & prob : probs.probs) {
probs_json.push_back(json {
{ "tok_str", prob.tok_str },
{ "prob", prob.prob },
});
}
if (probs.probs.size() > 0) {
res["probs"] = probs_json;
}

return res;
}

static json format_tokenizer_response(const std::vector<llama_token> & tokens) {
@@ -723,6 +793,7 @@ static void parse_options_completion(const json & body, llama_server_context & l
llama.params.n_keep = body.value("n_keep", default_params.n_keep);
llama.params.seed = body.value("seed", default_params.seed);
llama.params.prompt = body.value("prompt", default_params.prompt);
llama.params.n_probs = body.value("n_probs", default_params.n_probs);

llama.params.logit_bias.clear();
if (body.value("ignore_eos", false)) {
@@ -825,7 +896,8 @@ int main(int argc, char ** argv) {
size_t stop_pos = std::string::npos;

while (llama.has_next_token) {
const std::string token_text = llama.doCompletion();
const completion_string_output token_text_with_probs = llama.doCompletion();
const std::string token_text = token_text_with_probs.tok_str;

stop_pos = llama.findStoppingStrings(llama.generated_text,
token_text.size(), STOP_FULL);
@@ -839,7 +911,7 @@ int main(int argc, char ** argv) {
llama.generated_text.end());
}

const json data = format_final_response(llama, llama.generated_text);
const json data = format_final_response(llama, llama.generated_text, llama.generated_text_probs);

llama_print_timings(llama.ctx);

@@ -850,7 +922,7 @@ int main(int argc, char ** argv) {
size_t sent_count = 0;

while (llama.has_next_token) {
const std::string token_text = llama.doCompletion();
const completion_string_output token_text_with_probs = llama.doCompletion();
if (llama.multibyte_pending > 0) {
continue;
}
@@ -859,24 +931,24 @@ int main(int argc, char ** argv) {

const std::string str_test = llama.generated_text.substr(pos);
size_t stop_pos =
llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
llama.findStoppingStrings(str_test, token_text_with_probs.tok_str.size(), STOP_FULL);
if (stop_pos != std::string::npos) {
llama.generated_text.erase(
llama.generated_text.begin() + pos + stop_pos,
llama.generated_text.end());
pos = std::min(sent_count, llama.generated_text.size());
} else {
stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
stop_pos = llama.findStoppingStrings(str_test, token_text_with_probs.tok_str.size(),
STOP_PARTIAL);
}

const std::string to_send = llama.generated_text.substr(pos, stop_pos);
sent_count += to_send.size();

const json data = llama.has_next_token
? format_partial_response(to_send)
? format_partial_response(to_send, token_text_with_probs)
// Generation is done, send extra information.
: format_final_response(llama, to_send);
: format_final_response(llama, to_send, {token_text_with_probs});

const std::string str =
"data: " +
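For reference, here is a minimal client sketch (not part of the diff above) exercising the new `n_probs` parameter. The host, port, and prompt are assumptions; the `/completion` endpoint and the `completion_probabilities` / `content` / `probs` / `tok_str` / `prob` field names follow `format_final_response()` in server.cpp above, and the sketch reuses the cpp-httplib and nlohmann::json headers that server.cpp already depends on.

```cpp
// Minimal sketch of a non-streaming /completion request with n_probs set.
// Assumes the server is running locally on the default port (8080).
#include "httplib.h"
#include "json.hpp"
#include <cstdio>

using json = nlohmann::json;

int main() {
    httplib::Client cli("localhost", 8080); // assumed host/port

    const json req = {
        { "prompt",    "Building a website can be done in" }, // example prompt
        { "n_predict", 8 },
        { "n_probs",   5 }, // request up to 5 candidate probabilities per generated token
    };

    auto res = cli.Post("/completion", req.dump(), "application/json");
    if (!res || res->status != 200) {
        fprintf(stderr, "request failed\n");
        return 1;
    }

    const json data = json::parse(res->body);
    // format_final_response() attaches one entry per generated token,
    // each with the token text and its candidate probabilities.
    for (const auto & tok : data["completion_probabilities"]) {
        printf("%s\n", tok["content"].get<std::string>().c_str());
        for (const auto & p : tok["probs"]) {
            printf("  %-16s %.4f\n",
                   p["tok_str"].get<std::string>().c_str(),
                   p["prob"].get<float>());
        }
    }
    return 0;
}
```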