@@ -538,7 +538,10 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
538
538
}
539
539
540
540
json format_generation_settings (llama_server_context &llama) {
541
- const bool ignore_eos = -INFINITY == llama.params .logit_bias [llama_token_eos ()];
541
+ const auto eos_bias = llama.params .logit_bias .find (llama_token_eos ());
542
+ const bool ignore_eos = eos_bias != llama.params .logit_bias .end () &&
543
+ eos_bias->second < 0 .0f && std::isinf (eos_bias->second );
544
+
542
545
return json {
543
546
{ " seed" , llama.params .seed },
544
547
{ " temp" , llama.params .temp },
@@ -659,10 +662,15 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
659
662
if (body[" logit_bias" ].is_array ()) {
660
663
int n_vocab = llama_n_vocab (llama.ctx );
661
664
for (const auto &el : body[" logit_bias" ]) {
662
- if (el.is_array () && el.size () == 2 && el[0 ].is_number_integer () && el[ 1 ]. is_number_float () ) {
665
+ if (el.is_array () && el.size () == 2 && el[0 ].is_number_integer ()) {
663
666
llama_token tok = el[0 ].get <llama_token>();
664
- if (tok < 0 || tok >= n_vocab) continue ;
665
- llama.params .logit_bias [tok] = el[1 ].get <float >();
667
+ if (tok >= 0 && tok < n_vocab) {
668
+ if (el[1 ].is_number_float ()) {
669
+ llama.params .logit_bias [tok] = el[1 ].get <float >();
670
+ } else if (el[1 ].is_boolean () && !el[1 ].get <bool >()) {
671
+ llama.params .logit_bias [tok] = -INFINITY;
672
+ }
673
+ }
666
674
}
667
675
}
668
676
}
0 commit comments