Skip to content

Commit b646ffa

Browse files
author
Johnman
committed
Check for reverse prompt by characters instead of tokens (ggml-org#292)
1 parent 074bea2 commit b646ffa

File tree

1 file changed

+19
-20
lines changed

1 file changed

+19
-20
lines changed

main.cpp

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <map>
1313
#include <string>
1414
#include <vector>
15+
#include <sstream>
1516

1617
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
1718
#include <signal.h>
@@ -877,16 +878,9 @@ int main(int argc, char ** argv) {
877878
params.interactive = true;
878879
params.antiprompt.push_back("### Instruction:\n\n");
879880
}
880-
881-
// tokenize the reverse prompt
882-
std::vector<std::vector<gpt_vocab::id>> antipromptv_inp;
883881

884-
for (auto antiprompt : params.antiprompt) {
885-
antipromptv_inp.push_back(::llama_tokenize(vocab, antiprompt, false));
886-
}
887-
888882
// enable interactive mode if reverse prompt is specified
889-
if (antipromptv_inp.size() != 0) {
883+
if (params.antiprompt.size() != 0) {
890884
params.interactive = true;
891885
}
892886

@@ -910,15 +904,9 @@ int main(int argc, char ** argv) {
910904

911905
fprintf(stderr, "%s: interactive mode on.\n", __func__);
912906

913-
if(antipromptv_inp.size()) {
914-
for (size_t apindex = 0; apindex < antipromptv_inp.size(); ++apindex) {
915-
auto antiprompt_inp = antipromptv_inp.at(apindex);
916-
fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.at(apindex).c_str());
917-
fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
918-
for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
919-
fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
920-
}
921-
fprintf(stderr, "\n");
907+
if(params.antiprompt.size()) {
908+
for (auto antiprompt : params.antiprompt) {
909+
fprintf(stderr, "Antiprompt: %s\n", antiprompt);
922910
}
923911
}
924912
}
@@ -1035,12 +1023,23 @@ int main(int argc, char ** argv) {
10351023
// check if we should prompt the user for more
10361024
if (params.interactive && embd_inp.size() <= input_consumed) {
10371025
// check for reverse prompt
1038-
for (auto antiprompt_inp : antipromptv_inp) {
1039-
if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
1040-
// reverse prompt found
1026+
1027+
std::stringstream last_output_ss;
1028+
for (auto id : last_n_tokens) {
1029+
last_output_ss << vocab.id_to_token[id];
1030+
}
1031+
std::string last_output = last_output_ss.str();
1032+
1033+
for (std::string antiprompt : params.antiprompt) {
1034+
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
10411035
is_interacting = true;
10421036
break;
10431037
}
1038+
/*if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
1039+
// reverse prompt found
1040+
is_interacting = true;
1041+
break;
1042+
}*/
10441043
}
10451044
if (is_interacting) {
10461045
if (params.instruct) {

0 commit comments

Comments
 (0)