From 1841ca0ae0b2a799ca55cb3b64f6d9bb1aac8990 Mon Sep 17 00:00:00 2001
From: Christopher Oezbek
Date: Sun, 15 Oct 2023 22:23:33 +0200
Subject: [PATCH 1/2] Fixed loadPrompt() when prompt length exceeds context.

---
 examples/server/server.cpp | 27 +++++++++++++++------------
 1 file changed, 15 insertions(+), 12 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index ee0ababb1d5ce..be651bddc27c4 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -118,7 +118,7 @@ static void server_log(const char *level, const char *function, int line,
     }
 
     const std::string str = log.dump(-1, ' ', false, json::error_handler_t::replace);
-    printf("%.*s\n", (int)str.size(), str.data());
+    printf("%.*s\n\n", (int)str.size(), str.data());
     fflush(stdout);
 }
 
@@ -436,31 +436,34 @@ struct llama_server_context
         }
         params.n_keep = std::min(n_ctx - 4, params.n_keep);
 
-        // if input prompt is too big, truncate like normal
+        // if input prompt is too big, we will truncate in the same way when the embd becomes too big when generating tokens
         if (num_prompt_tokens >= (size_t)n_ctx)
         {
-            const int n_left = (n_ctx - params.n_keep) / 2;
+            const int n_left = n_ctx - params.n_keep;
+
+            // Keep n_keep tokens of start of prompt (at most n_ctx - 4)
             std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
-            const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
-            new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
-            std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin());
+
+            // Use half the left-over space in the context for the prompt
+            new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left / 2, prompt_tokens.end());
 
             LOG_VERBOSE("input truncated", {
                 {"n_ctx", n_ctx},
                 {"n_keep", params.n_keep},
                 {"n_left", n_left},
                 {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+                {"num_prompt_tokens", new_tokens.size()}
             });
 
             truncated = true;
             prompt_tokens = new_tokens;
+            num_prompt_tokens = prompt_tokens.size();
         }
-        else
-        {
-            const size_t ps = num_prompt_tokens;
-            std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
-            std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
-        }
+
+        // Initialize last_n_tokens
+        const size_t ps = num_prompt_tokens;
+        std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
+        std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
 
         // compare the evaluated prompt with the new prompt
         n_past = common_part(embd, prompt_tokens);

From 5becac802f0caca86e43459a95a510c6a7b8a579 Mon Sep 17 00:00:00 2001
From: coezbek
Date: Sun, 15 Oct 2023 22:46:03 +0200
Subject: [PATCH 2/2] Fix formatting

---
 examples/server/server.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index be651bddc27c4..b1a6aed274d20 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -118,7 +118,7 @@ static void server_log(const char *level, const char *function, int line,
     }
 
     const std::string str = log.dump(-1, ' ', false, json::error_handler_t::replace);
-    printf("%.*s\n\n", (int)str.size(), str.data());
+    printf("%.*s\n", (int)str.size(), str.data());
     fflush(stdout);
 }
 
@@ -460,7 +460,7 @@ struct llama_server_context
             num_prompt_tokens = prompt_tokens.size();
         }
 
-        // Initialize last_n_tokens 
+        // Initialize last_n_tokens
         const size_t ps = num_prompt_tokens;
         std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
         std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
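For reference, here is a minimal standalone sketch of the truncation rule the first commit introduces: keep the first n_keep tokens of the prompt and fill only half of the remaining context (n_left / 2) with the tail of the prompt, leaving the other half free for generation. The values below are hypothetical, int stands in for llama_token, and the last_n_tokens bookkeeping and logging are omitted, so the sketch compiles without llama.cpp.

// Sketch of the truncation rule from PATCH 1/2 (hypothetical values,
// int in place of llama_token).
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    const int n_ctx  = 16;              // hypothetical context size
    int       n_keep = 4;               // tokens to preserve from the start

    std::vector<int> prompt_tokens(40); // an over-long prompt: 0, 1, ..., 39
    for (int i = 0; i < (int)prompt_tokens.size(); i++) {
        prompt_tokens[i] = i;
    }

    n_keep = std::min(n_ctx - 4, n_keep);

    if (prompt_tokens.size() >= (size_t)n_ctx) {
        const int n_left = n_ctx - n_keep;

        // Keep n_keep tokens from the start of the prompt ...
        std::vector<int> new_tokens(prompt_tokens.begin(),
                                    prompt_tokens.begin() + n_keep);

        // ... and append the last n_left / 2 tokens of the prompt, so that
        // half the remaining context stays free for generated tokens.
        new_tokens.insert(new_tokens.end(),
                          prompt_tokens.end() - n_left / 2,
                          prompt_tokens.end());

        prompt_tokens = new_tokens;
    }

    // Prints: 0 1 2 3 34 35 36 37 38 39  (4 head + 6 tail = 10 of 16 slots)
    for (int t : prompt_tokens) {
        printf("%d ", t);
    }
    printf("\n");
    return 0;
}

Because n_keep is capped at n_ctx - 4, the truncated prompt holds at most n_keep + (n_ctx - n_keep) / 2 <= n_ctx - 2 tokens, so it is guaranteed to fit in the context with room left to generate.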