Skip to content

Commit 88cc7bb

Browse files
authored
Stuff with logits
1 parent 0bc0477 commit 88cc7bb

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

examples/server/server.cpp

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -538,7 +538,10 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
538538
}
539539

540540
json format_generation_settings(llama_server_context &llama) {
541-
const bool ignore_eos = -INFINITY == llama.params.logit_bias[llama_token_eos()];
541+
const auto eos_bias = llama.params.logit_bias.find(llama_token_eos());
542+
const bool ignore_eos = eos_bias != llama.params.logit_bias.end() &&
543+
eos_bias->second < 0.0f && std::isinf(eos_bias->second);
544+
542545
return json {
543546
{ "seed", llama.params.seed },
544547
{ "temp", llama.params.temp },
@@ -659,10 +662,15 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
659662
if (body["logit_bias"].is_array()) {
660663
int n_vocab = llama_n_vocab(llama.ctx);
661664
for (const auto &el : body["logit_bias"]) {
662-
if (el.is_array() && el.size() == 2 && el[0].is_number_integer() && el[1].is_number_float()) {
665+
if (el.is_array() && el.size() == 2 && el[0].is_number_integer()) {
663666
llama_token tok = el[0].get<llama_token>();
664-
if (tok < 0 || tok >= n_vocab) continue;
665-
llama.params.logit_bias[tok] = el[1].get<float>();
667+
if (tok >= 0 && tok < n_vocab) {
668+
if (el[1].is_number_float()) {
669+
llama.params.logit_bias[tok] = el[1].get<float>();
670+
} else if (el[1].is_boolean() && !el[1].get<bool>()) {
671+
llama.params.logit_bias[tok] = -INFINITY;
672+
}
673+
}
666674
}
667675
}
668676
}

0 commit comments

Comments
 (0)