
Commit 2d9f11d

fixed premature end due to stop word
1 parent fd64f04 commit 2d9f11d

File tree: 2 files changed, +21 -16 lines


examples/server/chat.mjs

Lines changed: 1 addition & 1 deletion
```diff
@@ -86,7 +86,7 @@ async function chat_completion(question) {
       n_predict: 256,
       cache_prompt: no_cached_prompt === "false",
       slot_id: slot_id,
-      stop: ["### Human:"], // stop completion after generating this
+      stop: ["\n### Human:"], // stop completion after generating this
       grammar,
       stream: true,
     })
```
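With the leading "\n", the client's stop sequence matches the complete turn delimiter produced by the chat template rather than any bare "### Human:" substring, which lines up with the server-side stop-word fix below.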

examples/server/server.cpp

Lines changed: 20 additions & 15 deletions
```diff
@@ -316,6 +316,7 @@ struct llama_client_slot
     struct slot_params params;
     struct llama_sampling_params sparams;
     llama_sampling_context ctx_sampling;
+    bool has_next_token = true;
 
     // grammar props
     grammar_parser::parse_state parsed_grammar;
@@ -710,9 +711,14 @@ struct llama_server_context
             if (pos != std::string::npos &&
                 (stop_pos == std::string::npos || pos < stop_pos))
             {
+                if (type == STOP_FULL)
+                {
+                    slot.stopped_word = true;
+                    slot.stopping_word = word;
+                    slot.has_next_token = false;
+                }
                 stop_pos = pos;
-                slot.stopped_word = true;
-                slot.stopping_word = word;
+
             }
         }
         return stop_pos;
```
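For context: the server searches for stop words in two modes. STOP_FULL means the whole stop word is present and generation must end; STOP_PARTIAL means only a prefix of the stop word sits at the tail of the generated text, so output is held back until later tokens either complete or rule out the match. A minimal sketch of that distinction, with hypothetical names rather than the actual server code:

```cpp
// Sketch only (hypothetical names, not the server's findStoppingStrings):
// distinguishing a complete stop-word hit from a partial one.
#include <algorithm>
#include <iostream>
#include <string>

enum stop_type { STOP_FULL, STOP_PARTIAL };

// STOP_FULL:    position of the whole stop word in `text`, or npos.
// STOP_PARTIAL: position where a proper prefix of the stop word ends `text`
//               (later tokens might still complete it), or npos.
static size_t find_stop(const std::string &text, const std::string &word, stop_type type) {
    if (type == STOP_FULL) {
        return text.find(word);
    }
    for (size_t len = std::min(word.size() - 1, text.size()); len > 0; --len) {
        if (text.compare(text.size() - len, len, word, 0, len) == 0) {
            return text.size() - len;
        }
    }
    return std::string::npos;
}

int main() {
    const std::string stop = "\n### Human:";
    // "\n### Hu" may still grow into the stop word -> partial hit at 5
    std::cout << find_stop("Sure.\n### Hu", stop, STOP_PARTIAL) << "\n"; // 5
    // the whole stop word is present -> full hit at 5, generation ends
    std::cout << find_stop("Sure.\n### Human: hi", stop, STOP_FULL) << "\n"; // 5
}
```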
```diff
@@ -727,6 +733,8 @@ struct llama_server_context
 
         // search stop word and delete it
         slot.generated_text += token_str;
+        slot.has_next_token = true;
+
         size_t pos = std::min(slot.sent_count, slot.generated_text.size());
         const std::string str_test = slot.generated_text.substr(pos);
         bool is_stop_full = false;
@@ -744,15 +752,13 @@ struct llama_server_context
         }
 
         // check if there is any token to predict
-        bool has_next_token = !is_stop_full && stop_pos > 0;
-        if(stop_pos == std::string::npos) {
+        if(stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) {
             // no send the stop word in the response
             result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
             slot.sent_count += result.text_to_send.size();
-            has_next_token = true;
+            // add the token to slot queue and cache
+            slot.addTokenString(result);
         }
-        // add the token to slot queue and cache
-        slot.addTokenString(result);
         if (slot.multibyte_pending > 0)
         {
             slot.multibyte_pending -= token_str.size();
```
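Put together, the token-processing path now streams only text that can no longer turn into a stop word: a full match ends generation, a partial match withholds the suspect tail, and everything else is sent. A self-contained sketch of that hold-back behaviour (illustrative names like stream_state and next_chunk, not the server implementation):

```cpp
// Sketch only: stream tokens while holding back any suffix that might
// grow into the stop word.
#include <algorithm>
#include <iostream>
#include <string>

struct stream_state {
    std::string generated; // everything produced so far
    size_t      sent = 0;  // how much the client has already received
};

// Fills `out` with the text that is safe to send now.
// Returns false once the full stop word appears (generation should end).
static bool next_chunk(stream_state &s, const std::string &stop, std::string &out) {
    const std::string unsent = s.generated.substr(s.sent);
    const size_t full = unsent.find(stop);
    if (full != std::string::npos) {
        out = unsent.substr(0, full); // never send the stop word itself
        s.sent = s.generated.size();
        return false;
    }
    size_t hold = 0; // longest stop-word prefix ending the unsent text
    for (size_t len = std::min(stop.size() - 1, unsent.size()); len > 0; --len) {
        if (unsent.compare(unsent.size() - len, len, stop, 0, len) == 0) {
            hold = len;
            break;
        }
    }
    out = unsent.substr(0, unsent.size() - hold);
    s.sent += out.size();
    return true;
}

int main() {
    stream_state s;
    const std::string stop = "\n### Human:";
    std::string out;
    for (const char *tok : {"Hi", " there!", "\n###", " Human:", " ignored"}) {
        s.generated += tok;                          // simulate one decoded token
        const bool more = next_chunk(s, stop, out);  // "\n###" is held back
        std::cout << "send: \"" << out << "\" more=" << more << "\n";
        if (!more) break;                            // stop word completed
    }
}
```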
```diff
@@ -781,37 +787,37 @@ struct llama_server_context
             }
         }
 
-        if (slot.multibyte_pending > 0 && !has_next_token)
+        if (slot.multibyte_pending > 0 && !slot.has_next_token)
         {
-            has_next_token = true;
+            slot.has_next_token = true;
         }
 
         // check the limits
         if (
-            slot.n_decoded > 2 && has_next_token && !slot.hasBudget(params))
+            slot.n_decoded > 2 && slot.has_next_token && !slot.hasBudget(params))
         {
             slot.stopped_limit = true;
-            has_next_token = false;
+            slot.has_next_token = false;
         }
 
         if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(ctx)){
             slot.stopped_eos = true;
-            has_next_token = false;
+            slot.has_next_token = false;
             LOG_VERBOSE("eos token found", {});
         }
 
         LOG_VERBOSE("next token", {
                         {"token", result.tok},
                         {"token_text", tokens_to_output_formatted_string(ctx, result.tok)},
-                        {"has_next_token", has_next_token},
+                        {"has_next_token", slot.has_next_token},
                         {"n_remain", slot.n_remaining},
                         {"num_tokens_predicted", slot.num_tokens_predicted},
                         {"stopped_eos", slot.stopped_eos},
                         {"stopped_word", slot.stopped_word},
                         {"stopped_limit", slot.stopped_limit},
                         {"stopping_word", slot.stopping_word},
                     });
-        return has_next_token; // continue
+        return slot.has_next_token; // continue
     }
 
 #ifdef SERVER_MULTIMODAL_SUPPORT
```
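The rest of this hunk is mechanical fallout from promoting has_next_token to a slot member in the first hunk: now that findStoppingStrings can clear it when a full stop word lands, the multibyte, budget, EOS, logging, and return sites all read and write slot.has_next_token instead of a local variable.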
```diff
@@ -2293,7 +2299,6 @@ int main(int argc, char **argv)
         const json body = json::parse(req.body);
         llama_client_slot* slot = llama.getSlot(-1);
         slot->reset();
-        //llama_reset_timings(llama.ctx);
         if (body.count("content") != 0)
         {
             slot->prompt = body["content"];
```
