Skip to content

Commit 226255b

Browse files
authored
server : fallback to default if client param is null (#2688)
* server : fallback to default if client param is null
* server : do not overwrite 404 if status is 500 from exception_handler
1 parent 930523c commit 226255b

File tree

1 file changed

+33
-24
lines changed

1 file changed

+33
-24
lines changed

examples/server/server.cpp

Lines changed: 33 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1056,33 +1056,42 @@ static json format_tokenizer_response(const std::vector<llama_token> &tokens)
10561056
{"tokens", tokens}};
10571057
}
10581058

1059+
template <typename T>
1060+
static T json_value(const json &body, const std::string &key, const T &default_value)
1061+
{
1062+
// Return default_value when key is absent from body or its value is JSON null; otherwise return body.value(key, default_value)
1063+
return body.contains(key) && !body.at(key).is_null()
1064+
? body.value(key, default_value)
1065+
: default_value;
1066+
}
1067+
10591068
static void parse_options_completion(const json &body, llama_server_context &llama)
10601069
{
10611070
gpt_params default_params;
10621071

1063-
llama.stream = body.value("stream", false);
1064-
llama.params.n_predict = body.value("n_predict", default_params.n_predict);
1065-
llama.params.top_k = body.value("top_k", default_params.top_k);
1066-
llama.params.top_p = body.value("top_p", default_params.top_p);
1067-
llama.params.tfs_z = body.value("tfs_z", default_params.tfs_z);
1068-
llama.params.typical_p = body.value("typical_p", default_params.typical_p);
1069-
llama.params.repeat_last_n = body.value("repeat_last_n", default_params.repeat_last_n);
1070-
llama.params.temp = body.value("temperature", default_params.temp);
1071-
llama.params.repeat_penalty = body.value("repeat_penalty", default_params.repeat_penalty);
1072-
llama.params.presence_penalty = body.value("presence_penalty", default_params.presence_penalty);
1073-
llama.params.frequency_penalty = body.value("frequency_penalty", default_params.frequency_penalty);
1074-
llama.params.mirostat = body.value("mirostat", default_params.mirostat);
1075-
llama.params.mirostat_tau = body.value("mirostat_tau", default_params.mirostat_tau);
1076-
llama.params.mirostat_eta = body.value("mirostat_eta", default_params.mirostat_eta);
1077-
llama.params.penalize_nl = body.value("penalize_nl", default_params.penalize_nl);
1078-
llama.params.n_keep = body.value("n_keep", default_params.n_keep);
1079-
llama.params.seed = body.value("seed", default_params.seed);
1080-
llama.params.prompt = body.value("prompt", default_params.prompt);
1081-
llama.params.grammar = body.value("grammar", default_params.grammar);
1082-
llama.params.n_probs = body.value("n_probs", default_params.n_probs);
1072+
llama.stream = json_value(body, "stream", false);
1073+
llama.params.n_predict = json_value(body, "n_predict", default_params.n_predict);
1074+
llama.params.top_k = json_value(body, "top_k", default_params.top_k);
1075+
llama.params.top_p = json_value(body, "top_p", default_params.top_p);
1076+
llama.params.tfs_z = json_value(body, "tfs_z", default_params.tfs_z);
1077+
llama.params.typical_p = json_value(body, "typical_p", default_params.typical_p);
1078+
llama.params.repeat_last_n = json_value(body, "repeat_last_n", default_params.repeat_last_n);
1079+
llama.params.temp = json_value(body, "temperature", default_params.temp);
1080+
llama.params.repeat_penalty = json_value(body, "repeat_penalty", default_params.repeat_penalty);
1081+
llama.params.presence_penalty = json_value(body, "presence_penalty", default_params.presence_penalty);
1082+
llama.params.frequency_penalty = json_value(body, "frequency_penalty", default_params.frequency_penalty);
1083+
llama.params.mirostat = json_value(body, "mirostat", default_params.mirostat);
1084+
llama.params.mirostat_tau = json_value(body, "mirostat_tau", default_params.mirostat_tau);
1085+
llama.params.mirostat_eta = json_value(body, "mirostat_eta", default_params.mirostat_eta);
1086+
llama.params.penalize_nl = json_value(body, "penalize_nl", default_params.penalize_nl);
1087+
llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep);
1088+
llama.params.seed = json_value(body, "seed", default_params.seed);
1089+
llama.params.prompt = json_value(body, "prompt", default_params.prompt);
1090+
llama.params.grammar = json_value(body, "grammar", default_params.grammar);
1091+
llama.params.n_probs = json_value(body, "n_probs", default_params.n_probs);
10831092

10841093
llama.params.logit_bias.clear();
1085-
if (body.value("ignore_eos", false))
1094+
if (json_value(body, "ignore_eos", false))
10861095
{
10871096
llama.params.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY;
10881097
}
@@ -1337,7 +1346,7 @@ int main(int argc, char **argv)
13371346
auto lock = llama.lock();
13381347

13391348
const json body = json::parse(req.body);
1340-
const std::string content = body.value("content", "");
1349+
const std::string content = json_value<std::string>(body, "content", "");
13411350
const std::vector<llama_token> tokens = llama_tokenize(llama.ctx, content, false);
13421351
const json data = format_tokenizer_response(tokens);
13431352
return res.set_content(data.dump(), "application/json"); });
@@ -1350,7 +1359,7 @@ int main(int argc, char **argv)
13501359

13511360
llama.rewind();
13521361
llama_reset_timings(llama.ctx);
1353-
llama.params.prompt = body.value("content", "");
1362+
llama.params.prompt = json_value<std::string>(body, "content", "");
13541363
llama.params.n_predict = 0;
13551364
llama.loadPrompt();
13561365
llama.beginCompletion();
@@ -1379,7 +1388,7 @@ int main(int argc, char **argv)
13791388
{
13801389
if (res.status == 400) {
13811390
res.set_content("Invalid request", "text/plain");
1382-
} else {
1391+
} else if (res.status != 500) {
13831392
res.set_content("File Not Found", "text/plain");
13841393
res.status = 404;
13851394
} });

0 commit comments

Comments
 (0)