From 3eed8c0914ffbadd029f12ecb422dd098e62e5be Mon Sep 17 00:00:00 2001 From: Joshua Williams Date: Tue, 21 Mar 2023 12:26:36 -0500 Subject: [PATCH 1/2] Initial implementation of stop keywords --- main.cpp | 104 +++++++++++++++++++++++++++++++++--------------------- utils.cpp | 2 ++ utils.h | 1 + 3 files changed, 67 insertions(+), 40 deletions(-) diff --git a/main.cpp b/main.cpp index 4b220c8cfcc99..89708e69c9562 100644 --- a/main.cpp +++ b/main.cpp @@ -1016,6 +1016,13 @@ int main(int argc, char ** argv) { } } } + + if(params.stop_keyword.size()) { + for (auto stop_keyword : params.stop_keyword) { + fprintf(stderr, "Stop keyword: '%s'\n", stop_keyword.c_str()); + } + } + fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); fprintf(stderr, "\n\n"); @@ -1129,58 +1136,75 @@ int main(int argc, char ** argv) { printf(ANSI_COLOR_RESET); } - // in interactive mode, and not currently processing queued inputs; - // check if we should prompt the user for more - if (params.interactive && (int) embd_inp.size() <= input_consumed) { - // check for reverse prompt + // If we are not processing queued inputs, check for reverse prompt and stop keywords + if((int) embd_inp.size() <= input_consumed) { + // Build the output string + // TODO - Recomputing this whole string every iteration is not efficient std::string last_output; for (auto id : last_n_tokens) { last_output += vocab.id_to_token[id]; } - // Check if each of the reverse prompts appears at the end of the output. - for (std::string antiprompt : params.antiprompt) { - if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) { - is_interacting = true; + // Check for stop keywords + bool stop = false; + for (std::string stop_keyword : params.stop_keyword) { + if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) { + stop = true; break; } } - if (is_interacting) { - if (params.instruct) { - input_consumed = embd_inp.size(); - embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); - - printf("\n> "); + if(stop) { + break; + } + + // in interactive mode, and not currently processing queued inputs; + // check if we should prompt the user for more + if (params.interactive) { + + // Check if each of the reverse prompts appears at the end of the output. + for (std::string antiprompt : params.antiprompt) { + if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) { + is_interacting = true; + break; + } } - - // currently being interactive - if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN); - std::string buffer; - std::string line; - bool another_line = true; - do { - std::getline(std::cin, line); - if (line.empty() || line.back() != '\\') { - another_line = false; - } else { - line.pop_back(); // Remove the continue character + if (is_interacting) { + if (params.instruct) { + input_consumed = embd_inp.size(); + embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); + + printf("\n> "); } - buffer += line + '\n'; // Append the line to the result - } while (another_line); - if (params.use_color) printf(ANSI_COLOR_RESET); - - std::vector line_inp = ::llama_tokenize(vocab, buffer, false); - embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); - - if (params.instruct) { - embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); + + // currently being interactive + if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN); + std::string buffer; + std::string line; + bool another_line = true; + do { + std::getline(std::cin, line); + if (line.empty() || line.back() != '\\') { + another_line = false; + } else { + line.pop_back(); // Remove the continue character + } + buffer += line + '\n'; // Append the line to the result + } while (another_line); + if (params.use_color) printf(ANSI_COLOR_RESET); + + std::vector line_inp = ::llama_tokenize(vocab, buffer, false); + embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); + + if (params.instruct) { + embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); + } + + remaining_tokens -= line_inp.size(); + + input_noecho = true; // do not echo this again } - - remaining_tokens -= line_inp.size(); - - input_noecho = true; // do not echo this again + is_interacting = false; } - is_interacting = false; } // end of text token diff --git a/utils.cpp b/utils.cpp index 7c6864c8f4b86..a51e8c8e5eace 100644 --- a/utils.cpp +++ b/utils.cpp @@ -72,6 +72,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.use_color = true; } else if (arg == "-r" || arg == "--reverse-prompt") { params.antiprompt.push_back(argv[++i]); + } else if (arg == "--stop") { + params.stop_keyword.push_back(argv[++i]); } else if (arg == "--perplexity") { params.perplexity = true; } else if (arg == "--ignore-eos") { diff --git a/utils.h b/utils.h index 6693775c57d79..712d47dbe6d27 100644 --- a/utils.h +++ b/utils.h @@ -32,6 +32,7 @@ struct gpt_params { std::string prompt = ""; std::vector antiprompt; // string upon seeing which more user input is prompted + std::vector stop_keyword; // string upon seeing which the model will stop bool memory_f16 = false; // use f16 instead of f32 for memory kv bool random_prompt = false; // do not randomize prompt if none provided From ea367074f8e1e88df170dc7747febed43f487aa3 Mon Sep 17 00:00:00 2001 From: Joshua Williams Date: Tue, 21 Mar 2023 12:39:15 -0500 Subject: [PATCH 2/2] Help text for stop keywords --- utils.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/utils.cpp b/utils.cpp index a51e8c8e5eace..598ef74c13c69 100644 --- a/utils.cpp +++ b/utils.cpp @@ -105,6 +105,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n"); fprintf(stderr, " in interactive mode, poll user input upon seeing PROMPT (can be\n"); fprintf(stderr, " specified more than once for multiple prompts).\n"); + fprintf(stderr, " --stop KEYWORD a string that, when output by the model, will stop generation\n"); + fprintf(stderr, " (can be specified more than once for multiple keywords).\n"); fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n"); fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n"); fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);