Skip to content

Commit d48c242

Browse files
dranger003 and ggerganov
authored and committed
main : support special tokens as reverse/anti prompt (ggml-org#5847)
* Support special tokens as reverse/anti prompt. * Tokenize antiprompts only once. * main : minor --------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent a0a1ca0 commit d48c242

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

examples/main/main.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,14 @@ int main(int argc, char ** argv) {
511511
std::vector<llama_token> embd;
512512
std::vector<llama_token> embd_guidance;
513513

514+
// tokenized antiprompts
515+
std::vector<std::vector<llama_token>> antiprompt_ids;
516+
517+
antiprompt_ids.reserve(params.antiprompt.size());
518+
for (const std::string & antiprompt : params.antiprompt) {
519+
antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
520+
}
521+
514522
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
515523

516524
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
@@ -769,6 +777,18 @@ int main(int argc, char ** argv) {
769777
}
770778
}
771779

780+
// check for reverse prompt using special tokens
781+
llama_token last_token = llama_sampling_last(ctx_sampling);
782+
for (std::vector<llama_token> ids : antiprompt_ids) {
783+
if (ids.size() == 1 && last_token == ids[0]) {
784+
if (params.interactive) {
785+
is_interacting = true;
786+
}
787+
is_antiprompt = true;
788+
break;
789+
}
790+
}
791+
772792
if (is_antiprompt) {
773793
LOG("found antiprompt: %s\n", last_output.c_str());
774794
}

0 commit comments

Comments (0)