From d6df8ecb9c1286fc945dd2bb2bd4068136ff26df Mon Sep 17 00:00:00 2001 From: ngxson Date: Mon, 22 Apr 2024 06:45:01 +0200 Subject: [PATCH 01/10] refactor chat template api --- llama.cpp | 445 +++++++++++++++++++++++++++++++----------------------- llama.h | 18 +++ 2 files changed, 271 insertions(+), 192 deletions(-) diff --git a/llama.cpp b/llama.cpp index 7440c740fefbc..16786bff3ff6c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -17074,195 +17074,250 @@ static std::string trim(const std::string & str) { return str.substr(start, end - start); } -// Simple version of "llama_apply_chat_template" that only works with strings -// This function uses heuristic checks to determine commonly used template. It is not a jinja parser. -static int32_t llama_chat_apply_template_internal( - const std::string & tmpl, - const std::vector & chat, - std::string & dest, bool add_ass) { - // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527 - std::stringstream ss; - if (tmpl == "chatml" || tmpl.find("<|im_start|>") != std::string::npos) { - // chatml template - for (auto message : chat) { - ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n"; +static int32_t llama_chat_get_model_template( + const struct llama_model * model, + const char * name, + char * buf, + int32_t length) { + GGML_ASSERT(model != nullptr); + auto get_meta = [&model](std::string template_key) { + // load template from model + std::vector model_template(2048, 0); // longest known template is about 1200 bytes + int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size()); + if (res < 0) { + return std::string(); // not found + } else { + return std::string(model_template.data(), model_template.size()); } - if (add_ass) { - ss << "<|im_start|>assistant\n"; + }; + std::string default_meta = "tokenizer.chat_template"; + std::string model_template; + if (name != nullptr) { + // support for named template: https://github.com/ggerganov/llama.cpp/pull/6588 + model_template = get_meta(std::string("tokenizer.chat_template.") + name); + if (model_template.empty()) { + model_template = get_meta(default_meta); } - } else if (tmpl == "llama2" || tmpl.find("[INST]") != std::string::npos) { - // llama2 template and its variants + } else { + // default template + model_template = get_meta(default_meta); + } + if (model_template.empty()) { + return -1; + } else { + snprintf(buf, length, "%s", model_template.c_str()); + return model_template.size() + 1; + } +} + +static llama_chat_template llama_chat_get_template_type(const char * tmpl) { + if (tmpl == nullptr) { + return LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED; + } + std::string stmpl(tmpl); + auto tmpl_contains = [&stmpl](std::string needle) { + return stmpl.find(needle) != std::string::npos; + }; + if (stmpl == "chatml" || tmpl_contains("<|im_start|>")) { + return LLAMA_CHAT_TEMPLATE_CHATML; + } else if (stmpl == "llama2" || tmpl_contains("[INST]")) { // [variant] support system message - bool support_system_message = tmpl.find("<>") != std::string::npos; - // [variant] space before + after response - bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos; + bool support_system_message = tmpl_contains("<>"); // [variant] add BOS inside history - bool add_bos_inside_history = tmpl.find("bos_token + '[INST]") != std::string::npos; - // [variant] trim spaces from the input message - bool strip_message = tmpl.find("content.strip()") != std::string::npos; - // construct the prompt 
- bool is_inside_turn = true; // skip BOS at the beginning - ss << "[INST] "; - for (auto message : chat) { - std::string content = strip_message ? trim(message->content) : message->content; - std::string role(message->role); - if (!is_inside_turn) { - is_inside_turn = true; - ss << (add_bos_inside_history ? "[INST] " : "[INST] "); - } - if (role == "system") { - if (support_system_message) { - ss << "<>\n" << content << "\n<>\n\n"; - } else { - // if the model does not support system message, we still include it in the first message, but without <> - ss << content << "\n"; - } - } else if (role == "user") { - ss << content << " [/INST]"; + bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]"); + if (support_system_message && add_bos_inside_history) { + return LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS; + } else if (support_system_message) { + return LLAMA_CHAT_TEMPLATE_LLAMA2_SYS; + } else { + return LLAMA_CHAT_TEMPLATE_LLAMA2; + } + } else if (stmpl == "zephyr" || tmpl_contains("<|user|>")) { + return LLAMA_CHAT_TEMPLATE_ZEPHYR; + } else if (stmpl == "monarch" || tmpl_contains("bos_token + message['role']")) { + return LLAMA_CHAT_TEMPLATE_MONARCH; + } else if (stmpl == "gemma" || tmpl_contains("")) { + return LLAMA_CHAT_TEMPLATE_GEMMA; + } else if (stmpl == "orion" || tmpl_contains("'\\n\\nAssistant: ' + eos_token")) { + return LLAMA_CHAT_TEMPLATE_ORION; + } else if (stmpl == "openchat" || tmpl_contains("GPT4 Correct ")) { + return LLAMA_CHAT_TEMPLATE_OPENCHAT; + } else if (stmpl == "vicuna" || stmpl == "vicuna-orca" || (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: "))) { + // [variant] support system message + if (stmpl == "vicuna-orca" || tmpl_contains("SYSTEM: ")) { + return LLAMA_CHAT_TEMPLATE_VICUNA_ORCA; + } else { + return LLAMA_CHAT_TEMPLATE_VICUNA; + } + } else if (stmpl == "deepseek" || (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>"))) { + return LLAMA_CHAT_TEMPLATE_DEEPSEEK; + } else if (stmpl == "command-r" || (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>"))) { + return LLAMA_CHAT_TEMPLATE_COMMAND_R; + } else if (stmpl == "llama3" || (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>"))) { + return LLAMA_CHAT_TEMPLATE_LLAMA3; + } else { + // template not supported + return LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED; + } +} + +static int32_t llama_chat_get_prefix( + const llama_chat_template tmpl, + const char * role, + const char * prev_role, + char * buf, + int32_t length) { + std::stringstream ss; + std::string srole(role); + std::string sprev_role(prev_role == nullptr ? "" : prev_role); + auto str_toupper = [](std::string & str) { + std::string output(str); + for (size_t i = 0; i < output.size(); i++) { + output[i] = toupper(output[i]); + } + return output; + }; + switch (tmpl) { + case LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED: + return -1; + case LLAMA_CHAT_TEMPLATE_CHATML: + ss << "<|im_start|>" << srole << "\n"; + break; + case LLAMA_CHAT_TEMPLATE_LLAMA2: + if (srole == "user") { + ss << "[INST] "; + } + break; + case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS: + if (!sprev_role.empty()) { + ss << ""; + } + // do not add "break" + case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS: + if (srole == "system") { + ss << "[INST]<>\n"; + } else if (srole == "user" && sprev_role != "system") { + ss << "[INST] "; + } + break; + case LLAMA_CHAT_TEMPLATE_ZEPHYR: + ss << "<|" << srole << "|>" << "\n"; + break; + case LLAMA_CHAT_TEMPLATE_MONARCH: + { + std::string bos = sprev_role.empty() ? 
"" : ""; // skip BOS for first message + ss << bos << srole << "\n"; + } break; + case LLAMA_CHAT_TEMPLATE_GEMMA: + // for gemma, "assistant" is "model" + srole = srole == "assistant" ? "model" : srole; + ss << "" << srole << "\n"; + break; + case LLAMA_CHAT_TEMPLATE_ORION: + // for orion, "user" is "human" + srole = srole == "user" ? "human" : srole; + srole[0] = toupper(srole[0]); // upper case for first letter + ss << srole << ": "; + break; + case LLAMA_CHAT_TEMPLATE_OPENCHAT: + if (srole == "system") { + ss << ""; } else { - ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << ""; - is_inside_turn = false; - } - } - // llama2 templates seem to not care about "add_generation_prompt" - } else if (tmpl == "zephyr" || tmpl.find("<|user|>") != std::string::npos) { - // zephyr template - for (auto message : chat) { - ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n"; - } - if (add_ass) { - ss << "<|assistant|>\n"; - } - } else if (tmpl == "monarch" || tmpl.find("bos_token + message['role']") != std::string::npos) { - // mlabonne/AlphaMonarch-7B template (the is included inside history) - for (auto message : chat) { - std::string bos = (message == chat.front()) ? "" : ""; // skip BOS for first message - ss << bos << message->role << "\n" << message->content << "\n"; - } - if (add_ass) { - ss << "assistant\n"; - } - } else if (tmpl == "gemma" || tmpl.find("") != std::string::npos) { - // google/gemma-7b-it - std::string system_prompt = ""; - for (auto message : chat) { - std::string role(message->role); - if (role == "system") { - // there is no system message for gemma, but we will merge it with user prompt, so nothing is broken - system_prompt = trim(message->content); - continue; + srole[0] = toupper(srole[0]); // upper case for first letter + ss << "GPT4 Correct " << srole << ": "; } - // in gemma, "assistant" is "model" - role = role == "assistant" ? "model" : message->role; - ss << "" << role << "\n"; - if (!system_prompt.empty() && role != "model") { - ss << system_prompt << "\n\n"; - system_prompt = ""; - } - ss << trim(message->content) << "\n"; - } - if (add_ass) { - ss << "model\n"; - } - } else if (tmpl == "orion" || tmpl.find("'\\n\\nAssistant: ' + eos_token") != std::string::npos) { - // OrionStarAI/Orion-14B-Chat - std::string system_prompt = ""; - for (auto message : chat) { - std::string role(message->role); - if (role == "system") { - // there is no system message support, we will merge it with user prompt - system_prompt = message->content; - continue; - } else if (role == "user") { - ss << "Human: "; - if (!system_prompt.empty()) { - ss << system_prompt << "\n\n"; - system_prompt = ""; - } - ss << message->content << "\n\nAssistant: "; + break; + case LLAMA_CHAT_TEMPLATE_VICUNA: + case LLAMA_CHAT_TEMPLATE_VICUNA_ORCA: + // TODO: original vicuna template does not support system message + ss << str_toupper(srole) << ": "; + break; + case LLAMA_CHAT_TEMPLATE_DEEPSEEK: + if (srole == "user") { + ss << "### Instruction:\n"; } else { - ss << message->content << ""; + ss << "### Response:\n"; } - } - } else if (tmpl == "openchat" || tmpl.find("GPT4 Correct ") != std::string::npos) { - // openchat/openchat-3.5-0106, - for (auto message : chat) { - std::string role(message->role); - if (role == "system") { - ss << message->content << "<|end_of_turn|>"; + break; + case LLAMA_CHAT_TEMPLATE_COMMAND_R: + // for command-r, "assistant" is "chatbot" + srole = srole == "assistant" ? 
"chatbot" : srole; + ss << "<|START_OF_TURN_TOKEN|><|" << str_toupper(srole) << "_TOKEN|>"; + break; + case LLAMA_CHAT_TEMPLATE_LLAMA3: + ss << "<|start_header_id|>" << srole << "<|end_header_id|>\n\n"; + break; + } + std::string output = ss.str(); + snprintf(buf, length, "%s", output.c_str()); + return output.size() + 1; +} + +static int32_t llama_chat_get_postfix( + const llama_chat_template tmpl, + const char * role, + const char * prev_role, + char * buf, + int32_t length) { + std::stringstream ss; + std::string srole(role); + std::string sprev_role(prev_role == nullptr ? "" : prev_role); + switch (tmpl) { + case LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED: + return -1; + case LLAMA_CHAT_TEMPLATE_CHATML: + ss << "<|im_end|>\n"; + break; + case LLAMA_CHAT_TEMPLATE_LLAMA2: + if (srole == "user") { + ss << " [/INST]"; + } + break; + case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS: + case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS: + if (srole == "system") { + ss << "\n<>\n\n"; } else { - role[0] = toupper(role[0]); - ss << "GPT4 Correct " << role << ": " << message->content << "<|end_of_turn|>"; + ss << ""; } - } - if (add_ass) { - ss << "GPT4 Correct Assistant:"; - } - } else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl.find("USER: ") != std::string::npos && tmpl.find("ASSISTANT: ") != std::string::npos)) { - // eachadea/vicuna-13b-1.1 (and Orca variant) - for (auto message : chat) { - std::string role(message->role); - if (role == "system") { - // Orca-Vicuna variant uses a system prefix - if (tmpl == "vicuna-orca" || tmpl.find("SYSTEM: ") != std::string::npos) { - ss << "SYSTEM: " << message->content << "\n"; - } else { - ss << message->content << "\n\n"; - } - } else if (role == "user") { - ss << "USER: " << message->content << "\n"; - } else if (role == "assistant") { - ss << "ASSISTANT: " << message->content << "\n"; - } - } - if (add_ass) { - ss << "ASSISTANT:"; - } - } else if (tmpl == "deepseek" || (tmpl.find("### Instruction:") != std::string::npos && tmpl.find("<|EOT|>") != std::string::npos)) { - // deepseek-ai/deepseek-coder-33b-instruct - for (auto message : chat) { - std::string role(message->role); - if (role == "system") { - ss << message->content; - } else if (role == "user") { - ss << "### Instruction:\n" << message->content << "\n"; - } else if (role == "assistant") { - ss << "### Response:\n" << message->content << "\n<|EOT|>\n"; - } - } - if (add_ass) { - ss << "### Response:\n"; - } - } else if (tmpl == "command-r" || (tmpl.find("<|START_OF_TURN_TOKEN|>") != std::string::npos && tmpl.find("<|USER_TOKEN|>") != std::string::npos)) { - // CohereForAI/c4ai-command-r-plus - for (auto message : chat) { - std::string role(message->role); - if (role == "system") { - ss << "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>"; - } else if (role == "user") { - ss << "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>"; - } else if (role == "assistant") { - ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>"; - } - } - if (add_ass) { - ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"; - } - } else if (tmpl == "llama3" || (tmpl.find("<|start_header_id|>") != std::string::npos && tmpl.find("<|end_header_id|>") != std::string::npos)) { - // Llama 3 - for (auto message : chat) { - std::string role(message->role); - ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n" << trim(message->content) << "<|eot_id|>"; - } - if (add_ass) { - ss << 
"<|start_header_id|>assistant<|end_header_id|>\n\n"; - } - } else { - // template not supported - return -1; + break; + case LLAMA_CHAT_TEMPLATE_ZEPHYR: + ss << "<|endoftext|>" << "\n"; + break; + case LLAMA_CHAT_TEMPLATE_MONARCH: + ss << "\n"; + break; + case LLAMA_CHAT_TEMPLATE_GEMMA: + ss << "\n"; + break; + case LLAMA_CHAT_TEMPLATE_ORION: + ss << ""; + break; + case LLAMA_CHAT_TEMPLATE_OPENCHAT: + srole[0] = toupper(srole[0]); + ss << "<|end_of_turn|>"; + break; + case LLAMA_CHAT_TEMPLATE_VICUNA: + case LLAMA_CHAT_TEMPLATE_VICUNA_ORCA: + ss << "\n"; + break; + case LLAMA_CHAT_TEMPLATE_DEEPSEEK: + if (srole == "user") { + ss << "\n"; + } else { + ss << "\n<|EOT|>\n"; + } + break; + case LLAMA_CHAT_TEMPLATE_COMMAND_R: + ss << "<|END_OF_TURN_TOKEN|>"; + break; + case LLAMA_CHAT_TEMPLATE_LLAMA3: + ss << "<|eot_id|>"; + break; } - dest = ss.str(); - return dest.size(); + std::string output = ss.str(); + snprintf(buf, length, "%s", output.c_str()); + return output.size() + 1; } LLAMA_API int32_t llama_chat_apply_template( @@ -17275,11 +17330,8 @@ LLAMA_API int32_t llama_chat_apply_template( int32_t length) { std::string curr_tmpl(tmpl == nullptr ? "" : tmpl); if (tmpl == nullptr) { - GGML_ASSERT(model != nullptr); - // load template from model std::vector model_template(2048, 0); // longest known template is about 1200 bytes - std::string template_key = "tokenizer.chat_template"; - int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size()); + int32_t res = llama_chat_get_model_template(model, nullptr, model_template.data(), model_template.size()); if (res < 0) { // worst case: there is no information about template, we will use chatml by default curr_tmpl = "chatml"; // see llama_chat_apply_template_internal @@ -17288,22 +17340,31 @@ LLAMA_API int32_t llama_chat_apply_template( } } + // detect template type + llama_chat_template ttmpl = llama_chat_get_template_type(curr_tmpl.c_str()); + if (ttmpl == LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED) { + return -1; + } + // format the chat to string - std::vector chat_vec; - chat_vec.resize(n_msg); + std::stringstream ss; + std::string prev_role; + std::vector prefix(1024, 0); + std::vector postfix(1024, 0); for (size_t i = 0; i < n_msg; i++) { - chat_vec[i] = &chat[i]; + std::string role(chat[i].role); + std::string content(chat[i].content); + llama_chat_get_prefix(ttmpl, role.c_str(), prev_role.c_str(), prefix.data(), prefix.size()); + llama_chat_get_postfix(ttmpl, role.c_str(), prev_role.c_str(), postfix.data(), postfix.size()); + ss << std::string(prefix.data(), prefix.size()) << content << std::string(postfix.data(), postfix.size()); + prev_role = role; } - std::string formatted_chat; - int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass); - if (res < 0) { - return res; - } + std::string output = ss.str(); if (buf && length > 0) { - strncpy(buf, formatted_chat.c_str(), length); + snprintf(buf, length, "%s", output.c_str()); } - return res; + return output.size() + 1; } LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) { diff --git a/llama.h b/llama.h index 4effca42cc65d..603bfe99fbec7 100644 --- a/llama.h +++ b/llama.h @@ -147,6 +147,24 @@ extern "C" { LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs }; + enum llama_chat_template { + LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED = 0, + LLAMA_CHAT_TEMPLATE_CHATML = 1, // Example: teknium/OpenHermes-2.5-Mistral-7B + LLAMA_CHAT_TEMPLATE_LLAMA2 = 2, 
// Original llama2 template (no <> support) + LLAMA_CHAT_TEMPLATE_LLAMA2_SYS = 3, // <> support (example: bofenghuang/vigogne-2-70b-chat) + LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS = 4, // <> support with BOS inside history (example: TomGrc/FusionNet_34Bx2_MoE) + LLAMA_CHAT_TEMPLATE_ZEPHYR = 5, // Example: HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1 + LLAMA_CHAT_TEMPLATE_MONARCH = 6, // Example: mlabonne/AlphaMonarch-7B + LLAMA_CHAT_TEMPLATE_GEMMA = 7, // Example: google/gemma-7b-it + LLAMA_CHAT_TEMPLATE_ORION = 8, // Example: OrionStarAI/Orion-14B-Chat + LLAMA_CHAT_TEMPLATE_OPENCHAT = 9, // Example: openchat/openchat-3.5-0106 + LLAMA_CHAT_TEMPLATE_VICUNA = 10, // Example: NousResearch/Nous-Capybara-34B + LLAMA_CHAT_TEMPLATE_VICUNA_ORCA = 11, // Variant of vicuna that supports system role + LLAMA_CHAT_TEMPLATE_DEEPSEEK = 12, // Example: deepseek-ai/deepseek-coder-33b-instruct + LLAMA_CHAT_TEMPLATE_COMMAND_R = 13, // Example: CohereForAI/c4ai-command-r-plus + LLAMA_CHAT_TEMPLATE_LLAMA3 = 14, // Example: meta-llama/Meta-Llama-3-8B-Instruct + }; + typedef struct llama_token_data { llama_token id; // token id float logit; // log-odds of the token From 98c46cfbfa57483e3788fb9189f3cecc7beb6229 Mon Sep 17 00:00:00 2001 From: ngxson Date: Mon, 22 Apr 2024 08:49:57 +0200 Subject: [PATCH 02/10] fix llama_chat_apply_template --- llama.cpp | 95 +++++++++++++++++++++++++++++------- tests/test-chat-template.cpp | 24 ++++----- 2 files changed, 89 insertions(+), 30 deletions(-) diff --git a/llama.cpp b/llama.cpp index 16786bff3ff6c..436082acabd0b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -17177,6 +17177,11 @@ static int32_t llama_chat_get_prefix( } return output; }; + auto str_tofirstcap = [](std::string & str) { + std::string output(str); + output[0] = toupper(output[0]); + return output; + }; switch (tmpl) { case LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED: return -1; @@ -17189,13 +17194,20 @@ static int32_t llama_chat_get_prefix( } break; case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS: - if (!sprev_role.empty()) { - ss << ""; + if (srole == "system") { + ss << "[INST] <>\n"; + } else if (srole == "user" && sprev_role != "system") { + if (!sprev_role.empty()) { + ss << ""; + } + ss << "[INST] "; + } else if (srole == "assistant") { + ss << " "; } - // do not add "break" + break; case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS: if (srole == "system") { - ss << "[INST]<>\n"; + ss << "[INST] <>\n"; } else if (srole == "user" && sprev_role != "system") { ss << "[INST] "; } @@ -17216,15 +17228,16 @@ static int32_t llama_chat_get_prefix( case LLAMA_CHAT_TEMPLATE_ORION: // for orion, "user" is "human" srole = srole == "user" ? 
"human" : srole; - srole[0] = toupper(srole[0]); // upper case for first letter - ss << srole << ": "; + ss << str_tofirstcap(srole) << ": "; + if (srole == "assistant") { + ss << ""; + } break; case LLAMA_CHAT_TEMPLATE_OPENCHAT: if (srole == "system") { ss << ""; } else { - srole[0] = toupper(srole[0]); // upper case for first letter - ss << "GPT4 Correct " << srole << ": "; + ss << "GPT4 Correct " << str_tofirstcap(srole) << ": "; } break; case LLAMA_CHAT_TEMPLATE_VICUNA: @@ -17250,7 +17263,7 @@ static int32_t llama_chat_get_prefix( } std::string output = ss.str(); snprintf(buf, length, "%s", output.c_str()); - return output.size() + 1; + return output.size(); } static int32_t llama_chat_get_postfix( @@ -17271,14 +17284,18 @@ static int32_t llama_chat_get_postfix( case LLAMA_CHAT_TEMPLATE_LLAMA2: if (srole == "user") { ss << " [/INST]"; + } else if (srole == "assistant") { + ss << ""; } break; case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS: case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS: if (srole == "system") { ss << "\n<>\n\n"; - } else { - ss << ""; + } else if (srole == "user") { + ss << " [/INST]"; + } else if (srole == "assistant") { + ss << " "; } break; case LLAMA_CHAT_TEMPLATE_ZEPHYR: @@ -17291,7 +17308,11 @@ static int32_t llama_chat_get_postfix( ss << "\n"; break; case LLAMA_CHAT_TEMPLATE_ORION: - ss << ""; + if (srole == "assistant") { + ss << ""; + } else { + ss << "\n\n"; + } break; case LLAMA_CHAT_TEMPLATE_OPENCHAT: srole[0] = toupper(srole[0]); @@ -17299,7 +17320,11 @@ static int32_t llama_chat_get_postfix( break; case LLAMA_CHAT_TEMPLATE_VICUNA: case LLAMA_CHAT_TEMPLATE_VICUNA_ORCA: - ss << "\n"; + if (srole == "assistant") { + ss << "\n"; + } else { + ss << "\n"; + } break; case LLAMA_CHAT_TEMPLATE_DEEPSEEK: if (srole == "user") { @@ -17317,7 +17342,25 @@ static int32_t llama_chat_get_postfix( } std::string output = ss.str(); snprintf(buf, length, "%s", output.c_str()); - return output.size() + 1; + return output.size(); +} + +static bool llama_chat_support_system_message(const llama_chat_template tmpl) { + switch (tmpl) { + case LLAMA_CHAT_TEMPLATE_CHATML: + case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS: + case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS: + case LLAMA_CHAT_TEMPLATE_ZEPHYR: + case LLAMA_CHAT_TEMPLATE_MONARCH: + case LLAMA_CHAT_TEMPLATE_ORION: + case LLAMA_CHAT_TEMPLATE_OPENCHAT: + case LLAMA_CHAT_TEMPLATE_VICUNA_ORCA: + case LLAMA_CHAT_TEMPLATE_COMMAND_R: + case LLAMA_CHAT_TEMPLATE_LLAMA3: + return true; + default: + return false; + } } LLAMA_API int32_t llama_chat_apply_template( @@ -17342,6 +17385,7 @@ LLAMA_API int32_t llama_chat_apply_template( // detect template type llama_chat_template ttmpl = llama_chat_get_template_type(curr_tmpl.c_str()); + bool support_system_message = llama_chat_support_system_message(ttmpl); if (ttmpl == LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED) { return -1; } @@ -17349,22 +17393,37 @@ LLAMA_API int32_t llama_chat_apply_template( // format the chat to string std::stringstream ss; std::string prev_role; - std::vector prefix(1024, 0); - std::vector postfix(1024, 0); for (size_t i = 0; i < n_msg; i++) { std::string role(chat[i].role); std::string content(chat[i].content); + if (!support_system_message) { + // if the template does not support system message, we convert it to user message + role = role == "system" ? 
"user" : role; + } + std::vector prefix(1024, 0); + std::vector postfix(1024, 0); llama_chat_get_prefix(ttmpl, role.c_str(), prev_role.c_str(), prefix.data(), prefix.size()); llama_chat_get_postfix(ttmpl, role.c_str(), prev_role.c_str(), postfix.data(), postfix.size()); - ss << std::string(prefix.data(), prefix.size()) << content << std::string(postfix.data(), postfix.size()); + ss << std::string(prefix.data()) << trim(content) << std::string(postfix.data()); prev_role = role; } + if (add_ass) { + std::vector prefix(1024, 0); + llama_chat_get_prefix(ttmpl, "assistant", prev_role.c_str(), prefix.data(), prefix.size()); + std::string assistant_prompt(prefix.data()); + if (assistant_prompt.back() == ' ') { + // Some templates need trailing space to be tokenized with the next word. We should make sure there is no trailing in the output text + assistant_prompt.pop_back(); + } + ss << assistant_prompt; + } + std::string output = ss.str(); if (buf && length > 0) { snprintf(buf, length, "%s", output.c_str()); } - return output.size() + 1; + return output.size(); } LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) { diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index cddf86a4105ea..c2db07257eaf0 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -14,7 +14,7 @@ int main(void) { {"user", "Hello"}, {"assistant", "Hi there"}, {"user", "Who are you"}, - {"assistant", " I am an assistant "}, + {"assistant", "I am an assistant"}, {"user", "Another question"}, }; size_t message_count = 6; @@ -52,27 +52,27 @@ int main(void) { }; std::vector expected_output = { // teknium/OpenHermes-2.5-Mistral-7B - "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n", + "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\nI am an assistant<|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n", // mistralai/Mistral-7B-Instruct-v0.2 - "[INST] You are a helpful assistant\nHello [/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + "[INST] You are a helpful assistant [/INST][INST] Hello [/INST]Hi there[INST] Who are you [/INST]I am an assistant[INST] Another question [/INST]", // TheBloke/FusionNet_34Bx2_MoE-AWQ - "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", // bofenghuang/vigogne-2-70b-chat - "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST]Hi there [INST] Who are you [/INST]I am an assistant [INST] Another question [/INST]", // mlabonne/AlphaMonarch-7B - "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", + "system\nYou are 
a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\nI am an assistant\nuser\nAnother question\nassistant\n", // google/gemma-7b-it - "user\nYou are a helpful assistant\n\nHello\nmodel\nHi there\nuser\nWho are you\nmodel\nI am an assistant\nuser\nAnother question\nmodel\n", + "user\nYou are a helpful assistant\nuser\nHello\nmodel\nHi there\nuser\nWho are you\nmodel\nI am an assistant\nuser\nAnother question\nmodel\n", // OrionStarAI/Orion-14B-Chat - "Human: You are a helpful assistant\n\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", + "System: You are a helpful assistant\n\nHuman: Hello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistantHuman: Another question\n\nAssistant: ", // openchat/openchat-3.5-0106 - "You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:", + "You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant<|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:", // deepseek-ai/deepseek-coder-33b-instruct - "You are a helpful assistant### Instruction:\nHello\n### Response:\nHi there\n<|EOT|>\n### Instruction:\nWho are you\n### Response:\n I am an assistant \n<|EOT|>\n### Instruction:\nAnother question\n### Response:\n", + "### Instruction:\nYou are a helpful assistant\n### Instruction:\nHello\n### Response:\nHi there\n<|EOT|>\n### Instruction:\nWho are you\n### Response:\nI am an assistant\n<|EOT|>\n### Instruction:\nAnother question\n### Response:\n", // eachadea/vicuna-13b-1.1 - "You are a helpful assistant\n\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an assistant \nUSER: Another question\nASSISTANT:", + "USER: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an assistant\nUSER: Another question\nASSISTANT:", // Orca-Vicuna - "SYSTEM: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an assistant \nUSER: Another question\nASSISTANT:", + "SYSTEM: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an assistant\nUSER: Another question\nASSISTANT:", // CohereForAI/c4ai-command-r-plus "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Who are you<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I am an assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Another question<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", // Llama 3 From a202b561270613be74aa42d118db367c5c49bd60 Mon Sep 17 00:00:00 2001 From: ngxson Date: Mon, 22 Apr 2024 09:04:24 +0200 Subject: [PATCH 03/10] add header --- llama.h | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/llama.h b/llama.h index 603bfe99fbec7..9f72834a138f6 100644 --- a/llama.h +++ b/llama.h @@ -854,6 +854,10 @@ extern 
"C" { int32_t length, bool special); + // + // Chat template + // + /// Apply chat template. Inspired by hf apply_chat_template() on python. /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model" /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template @@ -873,6 +877,54 @@ extern "C" { char * buf, int32_t length); + /// Get the Jinja model saved inside given model + /// @param model The pointer to llama_model + /// @param name Template name (can be a nullptr for default template). See: https://github.com/ggerganov/llama.cpp/pull/6588 + /// @param buf The output buffer + /// @param length The size of the allocated buffer + /// @return The total number of bytes of the template. If a named template cannot be found, it will use default template. If no template can be found, it returns -1 + LLAMA_API int32_t llama_chat_get_model_template( + const struct llama_model * model, + const char * name, + char * buf, + int32_t length); + + /// Get the enum llama_chat_template based on Jinja template + /// @param tmpl Jinja template (a string) + /// @return The currect enum llama_chat_template + LLAMA_API llama_chat_template llama_chat_get_template_type(const char * tmpl); + + /// Get the format prefix for a given message + /// @param tmpl Use enum llama_chat_template + /// @param role The role of the current message + /// @param prev_role The role of the previous message, can be nullptr + /// @param buf The output buffer + /// @param length The size of the allocated buffer + /// @return The total number of bytes of the output string + LLAMA_API int32_t llama_chat_get_prefix( + const llama_chat_template tmpl, + const char * role, + const char * prev_role, + char * buf, + int32_t length); + + /// Get the format postfix for a given message + /// @param tmpl Use enum llama_chat_template + /// @param role The role of the current message + /// @param prev_role The role of the previous message, can be nullptr + /// @param buf The output buffer + /// @param length The size of the allocated buffer + /// @return The total number of bytes of the output string + LLAMA_API int32_t llama_chat_get_postfix( + const llama_chat_template tmpl, + const char * role, + const char * prev_role, + char * buf, + int32_t length); + + /// Check if a given template support system message or not + LLAMA_API bool llama_chat_support_system_message(const llama_chat_template tmpl); + // // Grammar // From 588b72d9500a0c29dc5ad73f74e7f2bb24d0b4fe Mon Sep 17 00:00:00 2001 From: ngxson Date: Wed, 24 Apr 2024 15:20:21 +0200 Subject: [PATCH 04/10] fix templates not support system message --- llama.cpp | 24 +++++++++++++++--------- tests/test-chat-template.cpp | 8 ++++---- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/llama.cpp b/llama.cpp index 436082acabd0b..76e0367b87db7 100644 --- a/llama.cpp +++ b/llama.cpp @@ -17074,7 +17074,7 @@ static std::string trim(const std::string & str) { return str.substr(start, end - start); } -static int32_t llama_chat_get_model_template( +LLAMA_API int32_t llama_chat_get_model_template( const struct llama_model * model, const char * name, char * buf, @@ -17110,7 +17110,7 @@ static int32_t llama_chat_get_model_template( } } -static llama_chat_template llama_chat_get_template_type(const char * tmpl) { +LLAMA_API llama_chat_template 
llama_chat_get_template_type(const char * tmpl) { if (tmpl == nullptr) { return LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED; } @@ -17161,7 +17161,7 @@ static llama_chat_template llama_chat_get_template_type(const char * tmpl) { } } -static int32_t llama_chat_get_prefix( +LLAMA_API int32_t llama_chat_get_prefix( const llama_chat_template tmpl, const char * role, const char * prev_role, @@ -17266,7 +17266,7 @@ static int32_t llama_chat_get_prefix( return output.size(); } -static int32_t llama_chat_get_postfix( +LLAMA_API int32_t llama_chat_get_postfix( const llama_chat_template tmpl, const char * role, const char * prev_role, @@ -17345,7 +17345,7 @@ static int32_t llama_chat_get_postfix( return output.size(); } -static bool llama_chat_support_system_message(const llama_chat_template tmpl) { +LLAMA_API bool llama_chat_support_system_message(const llama_chat_template tmpl) { switch (tmpl) { case LLAMA_CHAT_TEMPLATE_CHATML: case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS: @@ -17377,7 +17377,7 @@ LLAMA_API int32_t llama_chat_apply_template( int32_t res = llama_chat_get_model_template(model, nullptr, model_template.data(), model_template.size()); if (res < 0) { // worst case: there is no information about template, we will use chatml by default - curr_tmpl = "chatml"; // see llama_chat_apply_template_internal + curr_tmpl = "chatml"; } else { curr_tmpl = std::string(model_template.data(), model_template.size()); } @@ -17393,12 +17393,18 @@ LLAMA_API int32_t llama_chat_apply_template( // format the chat to string std::stringstream ss; std::string prev_role; + bool merge_system_message = false; for (size_t i = 0; i < n_msg; i++) { std::string role(chat[i].role); std::string content(chat[i].content); - if (!support_system_message) { - // if the template does not support system message, we convert it to user message - role = role == "system" ? 
"user" : role; + // if the template does not support system message, we merge it with the next message + if (role == "system" && !support_system_message) { + merge_system_message = true; + continue; + } + if (merge_system_message && i > 0) { + content = std::string(chat[i - 1].content) + "\n\n" + content; + merge_system_message = false; } std::vector prefix(1024, 0); std::vector postfix(1024, 0); diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index c2db07257eaf0..ebdb84bd5578b 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -54,7 +54,7 @@ int main(void) { // teknium/OpenHermes-2.5-Mistral-7B "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\nI am an assistant<|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n", // mistralai/Mistral-7B-Instruct-v0.2 - "[INST] You are a helpful assistant [/INST][INST] Hello [/INST]Hi there[INST] Who are you [/INST]I am an assistant[INST] Another question [/INST]", + "[INST] You are a helpful assistant\n\nHello [/INST]Hi there[INST] Who are you [/INST]I am an assistant[INST] Another question [/INST]", // TheBloke/FusionNet_34Bx2_MoE-AWQ "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", // bofenghuang/vigogne-2-70b-chat @@ -62,15 +62,15 @@ int main(void) { // mlabonne/AlphaMonarch-7B "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\nI am an assistant\nuser\nAnother question\nassistant\n", // google/gemma-7b-it - "user\nYou are a helpful assistant\nuser\nHello\nmodel\nHi there\nuser\nWho are you\nmodel\nI am an assistant\nuser\nAnother question\nmodel\n", + "user\nYou are a helpful assistant\n\nHello\nmodel\nHi there\nuser\nWho are you\nmodel\nI am an assistant\nuser\nAnother question\nmodel\n", // OrionStarAI/Orion-14B-Chat "System: You are a helpful assistant\n\nHuman: Hello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistantHuman: Another question\n\nAssistant: ", // openchat/openchat-3.5-0106 "You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant<|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:", // deepseek-ai/deepseek-coder-33b-instruct - "### Instruction:\nYou are a helpful assistant\n### Instruction:\nHello\n### Response:\nHi there\n<|EOT|>\n### Instruction:\nWho are you\n### Response:\nI am an assistant\n<|EOT|>\n### Instruction:\nAnother question\n### Response:\n", + "### Instruction:\nYou are a helpful assistant\n\nHello\n### Response:\nHi there\n<|EOT|>\n### Instruction:\nWho are you\n### Response:\nI am an assistant\n<|EOT|>\n### Instruction:\nAnother question\n### Response:\n", // eachadea/vicuna-13b-1.1 - "USER: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an assistant\nUSER: Another question\nASSISTANT:", + "USER: You are a helpful assistant\n\nHello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an assistant\nUSER: Another question\nASSISTANT:", // Orca-Vicuna "SYSTEM: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an 
assistant\nUSER: Another question\nASSISTANT:", // CohereForAI/c4ai-command-r-plus From f1a93548aac30dfe17af467ed08d1a64928627d3 Mon Sep 17 00:00:00 2001 From: ngxson Date: Wed, 24 Apr 2024 15:31:46 +0200 Subject: [PATCH 05/10] clean up --- examples/server/utils.hpp | 2 +- llama.cpp | 19 ++++++++++++------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 1a22125028204..a7102b296a2d2 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -149,7 +149,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size()); } - const std::string formatted_chat(buf.data(), res); + const std::string formatted_chat(buf.data()); LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}}); diff --git a/llama.cpp b/llama.cpp index 76e0367b87db7..a42235234b377 100644 --- a/llama.cpp +++ b/llama.cpp @@ -17106,7 +17106,7 @@ LLAMA_API int32_t llama_chat_get_model_template( return -1; } else { snprintf(buf, length, "%s", model_template.c_str()); - return model_template.size() + 1; + return model_template.size(); } } @@ -17162,7 +17162,7 @@ LLAMA_API llama_chat_template llama_chat_get_template_type(const char * tmpl) { } LLAMA_API int32_t llama_chat_get_prefix( - const llama_chat_template tmpl, + const llama_chat_template ttmpl, const char * role, const char * prev_role, char * buf, @@ -17170,6 +17170,7 @@ LLAMA_API int32_t llama_chat_get_prefix( std::stringstream ss; std::string srole(role); std::string sprev_role(prev_role == nullptr ? "" : prev_role); + // str_toupper converts a string to all upper case, example: "abc" ==> "ABC" auto str_toupper = [](std::string & str) { std::string output(str); for (size_t i = 0; i < output.size(); i++) { @@ -17177,12 +17178,14 @@ LLAMA_API int32_t llama_chat_get_prefix( } return output; }; + // str_tofirstcap transforms first letter to uppercase, example: "abc" ==> "Abc" auto str_tofirstcap = [](std::string & str) { std::string output(str); output[0] = toupper(output[0]); return output; }; - switch (tmpl) { + // ttmpl means "typed template" + switch (ttmpl) { case LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED: return -1; case LLAMA_CHAT_TEMPLATE_CHATML: @@ -17267,7 +17270,7 @@ LLAMA_API int32_t llama_chat_get_prefix( } LLAMA_API int32_t llama_chat_get_postfix( - const llama_chat_template tmpl, + const llama_chat_template ttmpl, const char * role, const char * prev_role, char * buf, @@ -17275,7 +17278,7 @@ LLAMA_API int32_t llama_chat_get_postfix( std::stringstream ss; std::string srole(role); std::string sprev_role(prev_role == nullptr ? "" : prev_role); - switch (tmpl) { + switch (ttmpl) { case LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED: return -1; case LLAMA_CHAT_TEMPLATE_CHATML: @@ -17345,8 +17348,8 @@ LLAMA_API int32_t llama_chat_get_postfix( return output.size(); } -LLAMA_API bool llama_chat_support_system_message(const llama_chat_template tmpl) { - switch (tmpl) { +LLAMA_API bool llama_chat_support_system_message(const llama_chat_template ttmpl) { + switch (ttmpl) { case LLAMA_CHAT_TEMPLATE_CHATML: case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS: case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS: @@ -17371,6 +17374,8 @@ LLAMA_API int32_t llama_chat_apply_template( bool add_ass, char * buf, int32_t length) { + // either model or tmpl must be given + GGML_ASSERT(model != nullptr || tmpl != nullptr); std::string curr_tmpl(tmpl == nullptr ? 
"" : tmpl); if (tmpl == nullptr) { std::vector model_template(2048, 0); // longest known template is about 1200 bytes From 3222b4b8e55a81e55c6b60cf2f49295e5a7ac867 Mon Sep 17 00:00:00 2001 From: ngxson Date: Wed, 24 Apr 2024 18:06:41 +0200 Subject: [PATCH 06/10] add guide for adding template --- llama.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/llama.cpp b/llama.cpp index a42235234b377..185bbcfd3a856 100644 --- a/llama.cpp +++ b/llama.cpp @@ -17185,6 +17185,7 @@ LLAMA_API int32_t llama_chat_get_prefix( return output; }; // ttmpl means "typed template" + // before adding a new template, please see the guide here: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template#how-to-add-a-new-template switch (ttmpl) { case LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED: return -1; From 81b5903890a239703f75f96f5659d198ca8d31f8 Mon Sep 17 00:00:00 2001 From: ngxson Date: Wed, 24 Apr 2024 18:18:12 +0200 Subject: [PATCH 07/10] adapt phi3 template --- llama.cpp | 18 +++++++++--------- llama.h | 1 + 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/llama.cpp b/llama.cpp index d1676983da366..2d867bf3b5ec2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -17295,15 +17295,6 @@ LLAMA_API int32_t llama_chat_get_model_template( if (model_template.empty()) { model_template = get_meta(default_meta); } - } else if (tmpl == "phi3" || (tmpl.find("<|assistant|>") != std::string::npos && tmpl.find("<|end|>") != std::string::npos )) { - // Phi 3 - for (auto message : chat) { - std::string role(message->role); - ss << "<|" << role << "|>\n" << trim(message->content) << "<|end|>\n"; - } - if (add_ass) { - ss << "<|assistant|>\n"; - } } else { // default template model_template = get_meta(default_meta); @@ -17361,6 +17352,8 @@ LLAMA_API llama_chat_template llama_chat_get_template_type(const char * tmpl) { return LLAMA_CHAT_TEMPLATE_COMMAND_R; } else if (stmpl == "llama3" || (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>"))) { return LLAMA_CHAT_TEMPLATE_LLAMA3; + } else if (stmpl == "phi3" || (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>"))) { + return LLAMA_CHAT_TEMPLATE_PHI3; } else { // template not supported return LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED; @@ -17470,6 +17463,9 @@ LLAMA_API int32_t llama_chat_get_prefix( case LLAMA_CHAT_TEMPLATE_LLAMA3: ss << "<|start_header_id|>" << srole << "<|end_header_id|>\n\n"; break; + case LLAMA_CHAT_TEMPLATE_PHI3: + ss << "<|" << srole << "|>\n"; + break; } std::string output = ss.str(); snprintf(buf, length, "%s", output.c_str()); @@ -17549,6 +17545,9 @@ LLAMA_API int32_t llama_chat_get_postfix( case LLAMA_CHAT_TEMPLATE_LLAMA3: ss << "<|eot_id|>"; break; + case LLAMA_CHAT_TEMPLATE_PHI3: + ss << "<|end|>\n"; + break; } std::string output = ss.str(); snprintf(buf, length, "%s", output.c_str()); @@ -17567,6 +17566,7 @@ LLAMA_API bool llama_chat_support_system_message(const llama_chat_template ttmpl case LLAMA_CHAT_TEMPLATE_VICUNA_ORCA: case LLAMA_CHAT_TEMPLATE_COMMAND_R: case LLAMA_CHAT_TEMPLATE_LLAMA3: + case LLAMA_CHAT_TEMPLATE_PHI3: return true; default: return false; diff --git a/llama.h b/llama.h index 781fbc279efc8..f89301644e043 100644 --- a/llama.h +++ b/llama.h @@ -163,6 +163,7 @@ extern "C" { LLAMA_CHAT_TEMPLATE_DEEPSEEK = 12, // Example: deepseek-ai/deepseek-coder-33b-instruct LLAMA_CHAT_TEMPLATE_COMMAND_R = 13, // Example: CohereForAI/c4ai-command-r-plus LLAMA_CHAT_TEMPLATE_LLAMA3 = 14, // Example: meta-llama/Meta-Llama-3-8B-Instruct + LLAMA_CHAT_TEMPLATE_PHI3 = 15, // Example: 
microsoft/Phi-3-mini-128k-instruct }; typedef struct llama_token_data { From 0d3363e4e645066c324096f3fc9bc1a7d56a268a Mon Sep 17 00:00:00 2001 From: ngxson Date: Wed, 24 Apr 2024 18:27:39 +0200 Subject: [PATCH 08/10] llama_chat_get_typed_template --- llama.cpp | 4 ++-- llama.h | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/llama.cpp b/llama.cpp index 2d867bf3b5ec2..400713a5de3e6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -17307,7 +17307,7 @@ LLAMA_API int32_t llama_chat_get_model_template( } } -LLAMA_API llama_chat_template llama_chat_get_template_type(const char * tmpl) { +LLAMA_API llama_chat_template llama_chat_get_typed_template(const char * tmpl) { if (tmpl == nullptr) { return LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED; } @@ -17596,7 +17596,7 @@ LLAMA_API int32_t llama_chat_apply_template( } // detect template type - llama_chat_template ttmpl = llama_chat_get_template_type(curr_tmpl.c_str()); + llama_chat_template ttmpl = llama_chat_get_typed_template(curr_tmpl.c_str()); bool support_system_message = llama_chat_support_system_message(ttmpl); if (ttmpl == LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED) { return -1; diff --git a/llama.h b/llama.h index f89301644e043..4382dbe5812ad 100644 --- a/llama.h +++ b/llama.h @@ -892,12 +892,12 @@ extern "C" { char * buf, int32_t length); - /// Get the enum llama_chat_template based on Jinja template + /// Get the value of enum llama_chat_template based on given Jinja template /// @param tmpl Jinja template (a string) - /// @return The currect enum llama_chat_template - LLAMA_API llama_chat_template llama_chat_get_template_type(const char * tmpl); + /// @return The correct value of enum llama_chat_template + LLAMA_API llama_chat_template llama_chat_get_typed_template(const char * tmpl); - /// Get the format prefix for a given message + /// Get the format prefix for a given message (based on role) /// @param tmpl Use enum llama_chat_template /// @param role The role of the current message /// @param prev_role The role of the previous message, can be nullptr @@ -911,7 +911,7 @@ extern "C" { char * buf, int32_t length); - /// Get the format postfix for a given message + /// Get the format postfix for a given message (based on role) /// @param tmpl Use enum llama_chat_template /// @param role The role of the current message /// @param prev_role The role of the previous message, can be nullptr From 7f89803536193053e2214840f06c447319b133dd Mon Sep 17 00:00:00 2001 From: ngxson Date: Wed, 24 Apr 2024 18:38:20 +0200 Subject: [PATCH 09/10] add enum keyword --- llama.cpp | 10 +++++----- llama.h | 24 ++++++++++++------------ 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/llama.cpp b/llama.cpp index 400713a5de3e6..d1bbf7f8566ef 100644 --- a/llama.cpp +++ b/llama.cpp @@ -17307,7 +17307,7 @@ LLAMA_API int32_t llama_chat_get_model_template( } } -LLAMA_API llama_chat_template llama_chat_get_typed_template(const char * tmpl) { +LLAMA_API enum llama_chat_template llama_chat_get_typed_template(const char * tmpl) { if (tmpl == nullptr) { return LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED; } @@ -17361,7 +17361,7 @@ LLAMA_API llama_chat_template llama_chat_get_typed_template(const char * tmpl) { } LLAMA_API int32_t llama_chat_get_prefix( - const llama_chat_template ttmpl, + const enum llama_chat_template ttmpl, const char * role, const char * prev_role, char * buf, @@ -17473,7 +17473,7 @@ LLAMA_API int32_t llama_chat_get_prefix( } LLAMA_API int32_t llama_chat_get_postfix( - const llama_chat_template ttmpl, + const enum llama_chat_template ttmpl, const 
char * role, const char * prev_role, char * buf, @@ -17554,7 +17554,7 @@ LLAMA_API int32_t llama_chat_get_postfix( return output.size(); } -LLAMA_API bool llama_chat_support_system_message(const llama_chat_template ttmpl) { +LLAMA_API bool llama_chat_support_system_message(const enum llama_chat_template ttmpl) { switch (ttmpl) { case LLAMA_CHAT_TEMPLATE_CHATML: case LLAMA_CHAT_TEMPLATE_LLAMA2_SYS_BOS: @@ -17596,7 +17596,7 @@ LLAMA_API int32_t llama_chat_apply_template( } // detect template type - llama_chat_template ttmpl = llama_chat_get_typed_template(curr_tmpl.c_str()); + enum llama_chat_template ttmpl = llama_chat_get_typed_template(curr_tmpl.c_str()); bool support_system_message = llama_chat_support_system_message(ttmpl); if (ttmpl == LLAMA_CHAT_TEMPLATE_NOT_SUPPORTED) { return -1; diff --git a/llama.h b/llama.h index 4382dbe5812ad..6705ca3236c17 100644 --- a/llama.h +++ b/llama.h @@ -895,7 +895,7 @@ extern "C" { /// Get the value of enum llama_chat_template based on given Jinja template /// @param tmpl Jinja template (a string) /// @return The correct value of enum llama_chat_template - LLAMA_API llama_chat_template llama_chat_get_typed_template(const char * tmpl); + LLAMA_API enum llama_chat_template llama_chat_get_typed_template(const char * tmpl); /// Get the format prefix for a given message (based on role) /// @param tmpl Use enum llama_chat_template @@ -905,11 +905,11 @@ extern "C" { /// @param length The size of the allocated buffer /// @return The total number of bytes of the output string LLAMA_API int32_t llama_chat_get_prefix( - const llama_chat_template tmpl, - const char * role, - const char * prev_role, - char * buf, - int32_t length); + const enum llama_chat_template tmpl, + const char * role, + const char * prev_role, + char * buf, + int32_t length); /// Get the format postfix for a given message (based on role) /// @param tmpl Use enum llama_chat_template @@ -919,14 +919,14 @@ extern "C" { /// @param length The size of the allocated buffer /// @return The total number of bytes of the output string LLAMA_API int32_t llama_chat_get_postfix( - const llama_chat_template tmpl, - const char * role, - const char * prev_role, - char * buf, - int32_t length); + const enum llama_chat_template tmpl, + const char * role, + const char * prev_role, + char * buf, + int32_t length); /// Check if a given template support system message or not - LLAMA_API bool llama_chat_support_system_message(const llama_chat_template tmpl); + LLAMA_API bool llama_chat_support_system_message(const enum llama_chat_template tmpl); // // Grammar From 476d319fde0ae6c6a2ed9cfe54e548ad812fe5a5 Mon Sep 17 00:00:00 2001 From: ngxson Date: Wed, 24 Apr 2024 19:41:51 +0200 Subject: [PATCH 10/10] correct buffer size --- llama.cpp | 8 ++++---- llama.h | 8 ++++---- tests/test-chat-template.cpp | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/llama.cpp b/llama.cpp index d1bbf7f8566ef..676c085da1638 100644 --- a/llama.cpp +++ b/llama.cpp @@ -17303,7 +17303,7 @@ LLAMA_API int32_t llama_chat_get_model_template( return -1; } else { snprintf(buf, length, "%s", model_template.c_str()); - return model_template.size(); + return model_template.size() + 1; } } @@ -17469,7 +17469,7 @@ LLAMA_API int32_t llama_chat_get_prefix( } std::string output = ss.str(); snprintf(buf, length, "%s", output.c_str()); - return output.size(); + return output.size() + 1; } LLAMA_API int32_t llama_chat_get_postfix( @@ -17551,7 +17551,7 @@ LLAMA_API int32_t llama_chat_get_postfix( } std::string output = ss.str(); snprintf(buf, 
length, "%s", output.c_str()); - return output.size(); + return output.size() + 1; } LLAMA_API bool llama_chat_support_system_message(const enum llama_chat_template ttmpl) { @@ -17641,7 +17641,7 @@ LLAMA_API int32_t llama_chat_apply_template( if (buf && length > 0) { snprintf(buf, length, "%s", output.c_str()); } - return output.size(); + return output.size() + 1; } LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) { diff --git a/llama.h b/llama.h index 6705ca3236c17..cc25869a05808 100644 --- a/llama.h +++ b/llama.h @@ -870,7 +870,7 @@ extern "C" { /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message. /// @param buf A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages) /// @param length The size of the allocated buffer - /// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template. + /// @return The total number of bytes of the formatted prompt (null terminator included). If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template. LLAMA_API int32_t llama_chat_apply_template( const struct llama_model * model, const char * tmpl, @@ -885,7 +885,7 @@ extern "C" { /// @param name Template name (can be a nullptr for default template). See: https://github.com/ggerganov/llama.cpp/pull/6588 /// @param buf The output buffer /// @param length The size of the allocated buffer - /// @return The total number of bytes of the template. If a named template cannot be found, it will use default template. If no template can be found, it returns -1 + /// @return The total number of bytes of the template (null terminator included). If a named template cannot be found, it will use default template. If no template can be found, it returns -1 LLAMA_API int32_t llama_chat_get_model_template( const struct llama_model * model, const char * name, @@ -903,7 +903,7 @@ extern "C" { /// @param prev_role The role of the previous message, can be nullptr /// @param buf The output buffer /// @param length The size of the allocated buffer - /// @return The total number of bytes of the output string + /// @return The total number of bytes of the output string (null terminator included) LLAMA_API int32_t llama_chat_get_prefix( const enum llama_chat_template tmpl, const char * role, @@ -917,7 +917,7 @@ extern "C" { /// @param prev_role The role of the previous message, can be nullptr /// @param buf The output buffer /// @param length The size of the allocated buffer - /// @return The total number of bytes of the output string + /// @return The total number of bytes of the output string (null terminator included) LLAMA_API int32_t llama_chat_get_postfix( const enum llama_chat_template tmpl, const char * role, diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 9f1f9e5054b0d..c22f8d838fb2f 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -103,7 +103,7 @@ int main(void) { formatted_chat.size() ); formatted_chat.resize(res); - std::string output(formatted_chat.data(), formatted_chat.size()); + std::string output(formatted_chat.data()); std::cout << output << "\n-------------------------\n"; assert(output == expected); }