12
12
#include < map>
13
13
#include < string>
14
14
#include < vector>
15
+ #include < sstream>
15
16
16
17
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
17
18
#include < signal.h>
@@ -877,16 +878,9 @@ int main(int argc, char ** argv) {
877
878
params.interactive = true ;
878
879
params.antiprompt .push_back (" ### Instruction:\n\n " );
879
880
}
880
-
881
- // tokenize the reverse prompt
882
- std::vector<std::vector<gpt_vocab::id>> antipromptv_inp;
883
881
884
- for (auto antiprompt : params.antiprompt ) {
885
- antipromptv_inp.push_back (::llama_tokenize (vocab, antiprompt, false ));
886
- }
887
-
888
882
// enable interactive mode if reverse prompt is specified
889
- if (antipromptv_inp .size () != 0 ) {
883
+ if (params. antiprompt .size () != 0 ) {
890
884
params.interactive = true ;
891
885
}
892
886
@@ -910,15 +904,9 @@ int main(int argc, char ** argv) {
910
904
911
905
fprintf (stderr, " %s: interactive mode on.\n " , __func__);
912
906
913
- if (antipromptv_inp.size ()) {
914
- for (size_t apindex = 0 ; apindex < antipromptv_inp.size (); ++apindex) {
915
- auto antiprompt_inp = antipromptv_inp.at (apindex);
916
- fprintf (stderr, " %s: reverse prompt: '%s'\n " , __func__, params.antiprompt .at (apindex).c_str ());
917
- fprintf (stderr, " %s: number of tokens in reverse prompt = %zu\n " , __func__, antiprompt_inp.size ());
918
- for (int i = 0 ; i < (int ) antiprompt_inp.size (); i++) {
919
- fprintf (stderr, " %6d -> '%s'\n " , antiprompt_inp[i], vocab.id_to_token .at (antiprompt_inp[i]).c_str ());
920
- }
921
- fprintf (stderr, " \n " );
907
+ if (params.antiprompt .size ()) {
908
+ for (auto antiprompt : params.antiprompt ) {
909
+ fprintf (stderr, " Antiprompt: %s\n " , antiprompt);
922
910
}
923
911
}
924
912
}
@@ -1035,12 +1023,23 @@ int main(int argc, char ** argv) {
1035
1023
// check if we should prompt the user for more
1036
1024
if (params.interactive && embd_inp.size () <= input_consumed) {
1037
1025
// check for reverse prompt
1038
- for (auto antiprompt_inp : antipromptv_inp) {
1039
- if (antiprompt_inp.size () && std::equal (antiprompt_inp.rbegin (), antiprompt_inp.rend (), last_n_tokens.rbegin ())) {
1040
- // reverse prompt found
1026
+
1027
+ std::stringstream last_output_ss;
1028
+ for (auto id : last_n_tokens) {
1029
+ last_output_ss << vocab.id_to_token [id];
1030
+ }
1031
+ std::string last_output = last_output_ss.str ();
1032
+
1033
+ for (std::string antiprompt : params.antiprompt ) {
1034
+ if (last_output.find (antiprompt.c_str (), last_output.length () - antiprompt.length (), antiprompt.length ()) != std::string::npos) {
1041
1035
is_interacting = true ;
1042
1036
break ;
1043
1037
}
1038
+ /* if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
1039
+ // reverse prompt found
1040
+ is_interacting = true;
1041
+ break;
1042
+ }*/
1044
1043
}
1045
1044
if (is_interacting) {
1046
1045
if (params.instruct ) {
0 commit comments