From 4e644408a56a75e61bfcb6a8419efc938ab14e3b Mon Sep 17 00:00:00 2001
From: ngxson
Date: Fri, 16 Feb 2024 16:27:05 +0100
Subject: [PATCH 1/8] llama: add llama_chat_apply_template

---
 Makefile                     |   4 ++
 llama.cpp                    | 108 +++++++++++++++++++++++++++++++++++
 llama.h                      |  22 +++++++
 tests/CMakeLists.txt         |   1 +
 tests/test-chat-template.cpp |  68 ++++++++++++++++++++++
 5 files changed, 203 insertions(+)
 create mode 100644 tests/test-chat-template.cpp

diff --git a/Makefile b/Makefile
index 0a2070b539df8..8f1535b55cb95 100644
--- a/Makefile
+++ b/Makefile
@@ -862,3 +862,7 @@ tests/test-model-load-cancel: tests/test-model-load-cancel.cpp ggml.o llama.o te
 tests/test-autorelease: tests/test-autorelease.cpp ggml.o llama.o tests/get-model.cpp $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
+
+tests/test-chat-template: tests/test-chat-template.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
+	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

diff --git a/llama.cpp b/llama.cpp
index 08e7b02b4cc1d..b13c238396f2b 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -12459,6 +12459,114 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
     return 0;
 }
 
+int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_template, std::vector<const llama_chat_message *> conversation, bool add_ass);
+
+// trim whitespace from the beginning and end of a string
+static std::string trim(const std::string & str) {
+    size_t start = 0;
+    size_t end = str.size();
+    while (start < end && isspace(str[start])) {
+        start += 1;
+    }
+    while (end > start && isspace(str[end - 1])) {
+        end -= 1;
+    }
+    return str.substr(start, end - start);
+}
+
+// Simple version of "llama_apply_chat_template" that only works with strings
+// This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
+int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_template, std::vector<const llama_chat_message *> conversation, bool add_ass) {
+    // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
+    std::stringstream ss;
+    if (chat_template.find("<|im_start|>") != std::string::npos) {
+        // chatml template
+        for (auto message : conversation) {
+            ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
+        }
+        if (add_ass) {
+            ss << "<|im_start|>assistant\n";
+        }
+    } else if (chat_template.find("[INST]") != std::string::npos) {
+        // llama2 template and its variants
+        // [variant] support system message
+        bool support_system_message = chat_template.find("<<SYS>>") != std::string::npos;
+        // [variant] space before + after response
+        bool space_around_response = chat_template.find("' ' + eos_token") != std::string::npos;
+        // [variant] add BOS inside history
+        bool add_bos_inside_history = chat_template.find("bos_token + '[INST]") != std::string::npos;
+        // [variant] trim spaces from the input message
+        bool strip_message = chat_template.find("content.strip()") != std::string::npos;
+        // construct the prompt
+        bool is_inside_turn = true; // skip BOS at the beginning
+        ss << "[INST] ";
+        for (auto message : conversation) {
+            std::string content = strip_message ? trim(message->content) : message->content;
+            std::string role(message->role);
+            if (!is_inside_turn) {
+                is_inside_turn = true;
+                ss << (add_bos_inside_history ? "<s>[INST] " : "[INST] ");
+            }
+            if (role == "system") {
+                if (support_system_message) {
+                    ss << "<<SYS>>\n" << content << "\n<</SYS>>\n\n";
+                } else {
+                    // if the model does not support system message, we still include it in the first message, but without <<SYS>>
+                    ss << content << "\n";
+                }
+            } else if (role == "user") {
+                ss << content << " [/INST]";
+            } else {
+                ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << "</s>";
+                is_inside_turn = false;
+            }
+        }
+        // llama2 templates seem to not care about "add_generation_prompt"
+    } else {
+        // template not supported
+        return -1;
+    }
+    dest = ss.str();
+    return dest.size();
+}
+
+LLAMA_API int32_t llama_chat_apply_template(
+        const struct llama_model * model,
+        const char * custom_template,
+        const struct llama_chat_message * msg,
+        size_t n_msg,
+        bool add_ass,
+        char * buf,
+        int32_t length) {
+    std::string current_template(custom_template == nullptr ? "" : custom_template);
+    if (custom_template == nullptr) {
+        GGML_ASSERT(model != nullptr);
+        // load template from model
+        current_template.resize(2048); // longest known template is about 1200 bytes
+        std::string template_key = "tokenizer.chat_template";
+        int32_t res = llama_model_meta_val_str(model, template_key.c_str(), &(*current_template.begin()), current_template.size());
+        if (res < 0) {
+            // worst case: there is no information about template, we will use chatml by default
+            current_template = "<|im_start|>"; // see llama_chat_apply_template_internal
+        } else {
+            current_template.resize(res);
+        }
+    }
+    // format the conversation to string
+    std::vector<const llama_chat_message *> conversation_vec;
+    conversation_vec.resize(n_msg);
+    for (size_t i = 0; i < n_msg; i++) {
+        conversation_vec[i] = &msg[i];
+    }
+    std::string formatted_chat;
+    int32_t res = llama_chat_apply_template_internal(formatted_chat, current_template, conversation_vec, add_ass);
+    if (res < 0) {
+        return res;
+    }
+    strncpy(buf, formatted_chat.c_str(), length);
+    return res;
+}
+
 struct llama_timings llama_get_timings(struct llama_context * ctx) {
     struct llama_timings result = {
         /*.t_start_ms =*/ 1e-3 * ctx->t_start_us,
diff --git a/llama.h b/llama.h
index f4ec6ea6394a3..ff2c470198d6e 100644
--- a/llama.h
+++ b/llama.h
@@ -304,6 +304,12 @@ extern "C" {
         int32_t n_eval;
     };
 
+    // used in chat template
+    typedef struct llama_chat_message {
+        const char * role;
+        const char * content;
+    } llama_chat_message;
+
     // Helpers for getting default parameters
     LLAMA_API struct llama_model_params llama_model_default_params(void);
     LLAMA_API struct llama_context_params llama_context_default_params(void);
@@ -698,6 +704,22 @@ extern "C" {
                                       char * buf,
                                      int32_t length);
 
+    /// Apply chat template and maybe tokenize it. Inspired by hf apply_chat_template() on python.
+    /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
+    /// @param custom_template A Jinja template to use for this conversion. If this is nullptr, the model’s default chat template will be used instead.
+    /// @param msg Pointer to a list of multiple llama_chat_message
+    /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message.
+    /// @return If "tokenize" is set to false, the "buf" must be a string (returned value will be the string length).
+    /// Otherwise, "buf" must be a list of tokens (returned value will be the number of tokens).
+ LLAMA_API int32_t llama_chat_apply_template( + const struct llama_model * model, + const char * custom_template, + const struct llama_chat_message * msg, + size_t n_msg, + bool add_ass, + char * buf, + int32_t length); + // // Grammar // diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 3e40a78cdeac9..10326d531e9fb 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -28,6 +28,7 @@ endfunction() llama_build_and_test_executable(test-quantize-fns.cpp) llama_build_and_test_executable(test-quantize-perf.cpp) llama_build_and_test_executable(test-sampling.cpp) +llama_build_and_test_executable(test-chat-template.cpp) llama_build_executable(test-tokenizer-0-llama.cpp) llama_test_executable (test-tokenizer-0-llama test-tokenizer-0-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf) diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp new file mode 100644 index 0000000000000..23d02a7d2c2d8 --- /dev/null +++ b/tests/test-chat-template.cpp @@ -0,0 +1,68 @@ +#include +#include +#include +#include + +#undef NDEBUG +#include + +#include "llama.h" + +int main(void) { + llama_chat_message conversation[] = { + {"system", "You are a helpful assistant"}, + {"user", "Hello"}, + {"assistant", "Hi there"}, + {"user", "Who are you"}, + {"assistant", " I am an assistant "}, + {"user", "Another question"}, + }; + size_t message_count = 6; + std::vector conversation_vec; + conversation_vec.resize(message_count); + for (size_t i = 0; i < message_count; i++) { + conversation_vec[i] = &conversation[i]; + } + std::vector templates = { + // teknium/OpenHermes-2.5-Mistral-7B + "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}", + // mistralai/Mistral-7B-Instruct-v0.2 + "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", + // TheBloke/FusionNet_34Bx2_MoE-AWQ + "{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <>\\\\n' + messages[idx]['content'] + '\\\\n<>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}", + // bofenghuang/vigogne-2-70b-chat + "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' 
%}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\\\n' + system_message + '\\\\n<>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\\\n' + content.strip() + '\\\\n<>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", + }; + std::vector expected_substr = { + "<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant", + "[/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + "[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + "[/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + }; + std::string formatted_chat; + int32_t res; + + // test invalid chat template + res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, &(*formatted_chat.begin()), formatted_chat.size()); + assert(res < 0); + + for (size_t i = 0; i < templates.size(); i++) { + std::string custom_template = templates[i]; + std::string substr = expected_substr[i]; + formatted_chat.resize(1024); + res = llama_chat_apply_template( + nullptr, + custom_template.c_str(), + conversation, + message_count, + true, + &(*formatted_chat.begin()), + formatted_chat.size() + ); + formatted_chat.resize(res); + std::cout << formatted_chat << "\n-------------------------\n"; + // expect the "formatted_chat" to contain pre-defined strings + assert(formatted_chat.find(substr) != std::string::npos); + } + return 0; +} From bba75c792f7bf7001a0f02a5a0fe2a7bfee15d45 Mon Sep 17 00:00:00 2001 From: ngxson Date: Fri, 16 Feb 2024 16:33:49 +0100 Subject: [PATCH 2/8] test-chat-template: remove dedundant vector --- tests/test-chat-template.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 23d02a7d2c2d8..23f576778666c 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -18,11 +18,6 @@ int main(void) { {"user", "Another question"}, }; size_t message_count = 6; - std::vector conversation_vec; - conversation_vec.resize(message_count); - for (size_t i = 0; i < message_count; i++) { - conversation_vec[i] = &conversation[i]; - } std::vector templates = { // teknium/OpenHermes-2.5-Mistral-7B "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}", From 9c4422fbe94c44da881075b892b774e3fe3e8440 Mon Sep 17 00:00:00 2001 From: ngxson Date: Fri, 16 Feb 2024 17:04:01 +0100 Subject: [PATCH 3/8] chat_template: do not use std::string for buffer --- llama.cpp | 6 +++--- tests/test-chat-template.cpp | 11 ++++++----- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/llama.cpp b/llama.cpp index b13c238396f2b..615a9b901bf41 100644 --- a/llama.cpp +++ b/llama.cpp @@ -12542,14 +12542,14 @@ LLAMA_API int32_t llama_chat_apply_template( if (custom_template == 
nullptr) {
         GGML_ASSERT(model != nullptr);
         // load template from model
-        current_template.resize(2048); // longest known template is about 1200 bytes
+        std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
         std::string template_key = "tokenizer.chat_template";
-        int32_t res = llama_model_meta_val_str(model, template_key.c_str(), &(*current_template.begin()), current_template.size());
+        int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), current_template.size());
         if (res < 0) {
             // worst case: there is no information about template, we will use chatml by default
             current_template = "<|im_start|>"; // see llama_chat_apply_template_internal
         } else {
-            current_template.resize(res);
+            current_template = std::string(model_template.data(), model_template.size());
         }
     }
     // format the conversation to string
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index 23f576778666c..9830650d4f8dd 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -34,11 +34,11 @@ int main(void) {
         "[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]",
         "[/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]",
     };
-    std::string formatted_chat;
+    std::vector<char> formatted_chat(1024);
     int32_t res;
 
     // test invalid chat template
-    res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, &(*formatted_chat.begin()), formatted_chat.size());
+    res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
     assert(res < 0);
 
     for (size_t i = 0; i < templates.size(); i++) {
@@ -51,13 +51,14 @@ int main(void) {
             conversation,
             message_count,
             true,
-            &(*formatted_chat.begin()),
+            formatted_chat.data(),
             formatted_chat.size()
         );
         formatted_chat.resize(res);
-        std::cout << formatted_chat << "\n-------------------------\n";
+        std::string output(formatted_chat.data(), formatted_chat.size());
+        std::cout << output << "\n-------------------------\n";
         // expect the "formatted_chat" to contain pre-defined strings
-        assert(formatted_chat.find(substr) != std::string::npos);
+        assert(output.find(substr) != std::string::npos);
     }
     return 0;
 }

From 6012ad651f85621b264c974de67af47de24e758d Mon Sep 17 00:00:00 2001
From: ngxson
Date: Sat, 17 Feb 2024 16:45:31 +0100
Subject: [PATCH 4/8] add clarification for llama_chat_apply_template

---
 llama.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/llama.h b/llama.h
index ff2c470198d6e..e0fe2e0bbc04d 100644
--- a/llama.h
+++ b/llama.h
@@ -706,6 +706,7 @@ extern "C" {
     /// Apply chat template and maybe tokenize it. Inspired by hf apply_chat_template() on python.
     /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
+    /// NOTE: This function only support some know jinja templates. It is not a jinja parser.
     /// @param custom_template A Jinja template to use for this conversion. If this is nullptr, the model’s default chat template will be used instead.
     /// @param msg Pointer to a list of multiple llama_chat_message
     /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message.
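Taken together, the heuristics in llama.cpp and the doc comments in llama.h describe the intended call pattern: pass either a model or an explicit template string, hand over an array of llama_chat_message, and read the formatted prompt back out of buf. The following caller-side sketch is for illustration only; it is not part of the patch series, and the message contents, the 256-byte starting buffer and the helper name format_chatml_example are made up. It relies on the behaviour shown above: the function returns the full formatted length even when strncpy truncates, so the caller can grow the buffer and apply the template again.

    #include <string>
    #include <vector>
    #include "llama.h"

    // Format a two-message conversation with an explicit ChatML template (model == nullptr).
    static std::string format_chatml_example() {
        const llama_chat_message msgs[] = {
            {"system", "You are a helpful assistant"}, // illustrative contents
            {"user",   "Hello"},
        };
        const char * tmpl = "<|im_start|>";  // enough to hit the chatml heuristic
        std::vector<char> buf(256);          // deliberately small to show the retry path
        int32_t res = llama_chat_apply_template(nullptr, tmpl, msgs, 2, true, buf.data(), (int32_t) buf.size());
        if (res < 0) {
            return "";                       // template not recognized
        }
        if ((size_t) res > buf.size()) {
            buf.resize(res);                 // grow to the reported size and re-apply
            res = llama_chat_apply_template(nullptr, tmpl, msgs, 2, true, buf.data(), (int32_t) buf.size());
        }
        return std::string(buf.data(), res);
    }

Because add_ass is true, the chatml branch ends the output with "<|im_start|>assistant\n", which is the prompt a caller would feed to generation.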
From 7a3eac8cb32961735e65613e088d0b29ee579a54 Mon Sep 17 00:00:00 2001 From: ngxson Date: Sat, 17 Feb 2024 16:54:30 +0100 Subject: [PATCH 5/8] llama_chat_apply_template: add zephyr template --- llama.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/llama.cpp b/llama.cpp index 615a9b901bf41..dee685c197297 100644 --- a/llama.cpp +++ b/llama.cpp @@ -12522,6 +12522,14 @@ int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_t } } // llama2 templates seem to not care about "add_generation_prompt" + } else if (chat_template.find("<|user|>") != std::string::npos) { + // zephyr template + for (auto message : conversation) { + ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n"; + } + if (add_ass) { + ss << "<|assistant|>\n"; + } } else { // template not supported return -1; From 011af99af8770744cc213675962d648ff3c8ee2f Mon Sep 17 00:00:00 2001 From: ngxson Date: Sat, 17 Feb 2024 20:24:05 +0100 Subject: [PATCH 6/8] llama_chat_apply_template: correct docs --- llama.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llama.h b/llama.h index e0fe2e0bbc04d..64140bde2e598 100644 --- a/llama.h +++ b/llama.h @@ -706,12 +706,12 @@ extern "C" { /// Apply chat template and maybe tokenize it. Inspired by hf apply_chat_template() on python. /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model" - /// NOTE: This function only support some know jinja templates. It is not a jinja parser. + /// NOTE: This function only support some known jinja templates. It is not a jinja parser. /// @param custom_template A Jinja template to use for this conversion. If this is nullptr, the model’s default chat template will be used instead. /// @param msg Pointer to a list of multiple llama_chat_message /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message. - /// @return If "tokenize" is set to false, the "buf" must be a string (returned value will be the string length). - /// Otherwise, "buf" must be a list of tokens (returned value will be the number of tokens). + /// @param buf A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages) + /// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template. 
LLAMA_API int32_t llama_chat_apply_template( const struct llama_model * model, const char * custom_template, From 73fbd6790110abb8790fb5d45febb0c32a1a29c9 Mon Sep 17 00:00:00 2001 From: ngxson Date: Sun, 18 Feb 2024 21:44:47 +0100 Subject: [PATCH 7/8] llama_chat_apply_template: use term "chat" everywhere --- llama.cpp | 28 +++++++++++++++++----------- llama.h | 10 ++++++---- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/llama.cpp b/llama.cpp index 199669517ee22..3c992d6f6976c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -12459,7 +12459,10 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token return 0; } -int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_template, std::vector conversation, bool add_ass); +int32_t llama_chat_apply_template_internal( + const std::string & chat_template, + const std::vector & chat, + std::string & dest, bool add_ass); // trim whitespace from the beginning and end of a string static std::string trim(const std::string & str) { @@ -12476,12 +12479,15 @@ static std::string trim(const std::string & str) { // Simple version of "llama_apply_chat_template" that only works with strings // This function uses heuristic checks to determine commonly used template. It is not a jinja parser. -int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_template, std::vector conversation, bool add_ass) { +int32_t llama_chat_apply_template_internal( + const std::string & chat_template, + const std::vector & chat, + std::string & dest, bool add_ass) { // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527 std::stringstream ss; if (chat_template.find("<|im_start|>") != std::string::npos) { // chatml template - for (auto message : conversation) { + for (auto message : chat) { ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n"; } if (add_ass) { @@ -12500,7 +12506,7 @@ int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_t // construct the prompt bool is_inside_turn = true; // skip BOS at the beginning ss << "[INST] "; - for (auto message : conversation) { + for (auto message : chat) { std::string content = strip_message ? 
trim(message->content) : message->content; std::string role(message->role); if (!is_inside_turn) { @@ -12524,7 +12530,7 @@ int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_t // llama2 templates seem to not care about "add_generation_prompt" } else if (chat_template.find("<|user|>") != std::string::npos) { // zephyr template - for (auto message : conversation) { + for (auto message : chat) { ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n"; } if (add_ass) { @@ -12541,7 +12547,7 @@ int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_t LLAMA_API int32_t llama_chat_apply_template( const struct llama_model * model, const char * custom_template, - const struct llama_chat_message * msg, + const struct llama_chat_message * chat, size_t n_msg, bool add_ass, char * buf, @@ -12560,14 +12566,14 @@ LLAMA_API int32_t llama_chat_apply_template( current_template = std::string(model_template.data(), model_template.size()); } } - // format the conversation to string - std::vector conversation_vec; - conversation_vec.resize(n_msg); + // format the chat to string + std::vector chat_vec; + chat_vec.resize(n_msg); for (size_t i = 0; i < n_msg; i++) { - conversation_vec[i] = &msg[i]; + chat_vec[i] = &chat[i]; } std::string formatted_chat; - int32_t res = llama_chat_apply_template_internal(formatted_chat, current_template, conversation_vec, add_ass); + int32_t res = llama_chat_apply_template_internal(current_template, chat_vec, formatted_chat, add_ass); if (res < 0) { return res; } diff --git a/llama.h b/llama.h index 64140bde2e598..3e53576852d7c 100644 --- a/llama.h +++ b/llama.h @@ -704,18 +704,20 @@ extern "C" { char * buf, int32_t length); - /// Apply chat template and maybe tokenize it. Inspired by hf apply_chat_template() on python. + /// Apply chat template. Inspired by hf apply_chat_template() on python. /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model" /// NOTE: This function only support some known jinja templates. It is not a jinja parser. - /// @param custom_template A Jinja template to use for this conversion. If this is nullptr, the model’s default chat template will be used instead. - /// @param msg Pointer to a list of multiple llama_chat_message + /// @param custom_template A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead. + /// @param chat Pointer to a list of multiple llama_chat_message + /// @param n_msg Number of llama_chat_message in this chat /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message. /// @param buf A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages) + /// @param length The size of the allocated buffer /// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template. 
LLAMA_API int32_t llama_chat_apply_template( const struct llama_model * model, const char * custom_template, - const struct llama_chat_message * msg, + const struct llama_chat_message * chat, size_t n_msg, bool add_ass, char * buf, From 649f6f829057849a3826da89a3a8543f275a0fd0 Mon Sep 17 00:00:00 2001 From: ngxson Date: Sun, 18 Feb 2024 22:06:44 +0100 Subject: [PATCH 8/8] llama_chat_apply_template: change variable name to "tmpl" --- llama.cpp | 37 ++++++++++++++++--------------------- llama.h | 4 ++-- 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/llama.cpp b/llama.cpp index 3c992d6f6976c..59d858e030abb 100644 --- a/llama.cpp +++ b/llama.cpp @@ -12459,11 +12459,6 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token return 0; } -int32_t llama_chat_apply_template_internal( - const std::string & chat_template, - const std::vector & chat, - std::string & dest, bool add_ass); - // trim whitespace from the beginning and end of a string static std::string trim(const std::string & str) { size_t start = 0; @@ -12479,13 +12474,13 @@ static std::string trim(const std::string & str) { // Simple version of "llama_apply_chat_template" that only works with strings // This function uses heuristic checks to determine commonly used template. It is not a jinja parser. -int32_t llama_chat_apply_template_internal( - const std::string & chat_template, +static int32_t llama_chat_apply_template_internal( + const std::string & tmpl, const std::vector & chat, std::string & dest, bool add_ass) { // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527 std::stringstream ss; - if (chat_template.find("<|im_start|>") != std::string::npos) { + if (tmpl.find("<|im_start|>") != std::string::npos) { // chatml template for (auto message : chat) { ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n"; @@ -12493,16 +12488,16 @@ int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|im_start|>assistant\n"; } - } else if (chat_template.find("[INST]") != std::string::npos) { + } else if (tmpl.find("[INST]") != std::string::npos) { // llama2 template and its variants // [variant] support system message - bool support_system_message = chat_template.find("<>") != std::string::npos; + bool support_system_message = tmpl.find("<>") != std::string::npos; // [variant] space before + after response - bool space_around_response = chat_template.find("' ' + eos_token") != std::string::npos; + bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos; // [variant] add BOS inside history - bool add_bos_inside_history = chat_template.find("bos_token + '[INST]") != std::string::npos; + bool add_bos_inside_history = tmpl.find("bos_token + '[INST]") != std::string::npos; // [variant] trim spaces from the input message - bool strip_message = chat_template.find("content.strip()") != std::string::npos; + bool strip_message = tmpl.find("content.strip()") != std::string::npos; // construct the prompt bool is_inside_turn = true; // skip BOS at the beginning ss << "[INST] "; @@ -12528,7 +12523,7 @@ int32_t llama_chat_apply_template_internal( } } // llama2 templates seem to not care about "add_generation_prompt" - } else if (chat_template.find("<|user|>") != std::string::npos) { + } else if (tmpl.find("<|user|>") != std::string::npos) { // zephyr template for (auto message : chat) { ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n"; @@ -12546,24 +12541,24 @@ int32_t 
llama_chat_apply_template_internal( LLAMA_API int32_t llama_chat_apply_template( const struct llama_model * model, - const char * custom_template, + const char * tmpl, const struct llama_chat_message * chat, size_t n_msg, bool add_ass, char * buf, int32_t length) { - std::string current_template(custom_template == nullptr ? "" : custom_template); - if (custom_template == nullptr) { + std::string curr_tmpl(tmpl == nullptr ? "" : tmpl); + if (tmpl == nullptr) { GGML_ASSERT(model != nullptr); // load template from model std::vector model_template(2048, 0); // longest known template is about 1200 bytes std::string template_key = "tokenizer.chat_template"; - int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), current_template.size()); + int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), curr_tmpl.size()); if (res < 0) { // worst case: there is no information about template, we will use chatml by default - current_template = "<|im_start|>"; // see llama_chat_apply_template_internal + curr_tmpl = "<|im_start|>"; // see llama_chat_apply_template_internal } else { - current_template = std::string(model_template.data(), model_template.size()); + curr_tmpl = std::string(model_template.data(), model_template.size()); } } // format the chat to string @@ -12573,7 +12568,7 @@ LLAMA_API int32_t llama_chat_apply_template( chat_vec[i] = &chat[i]; } std::string formatted_chat; - int32_t res = llama_chat_apply_template_internal(current_template, chat_vec, formatted_chat, add_ass); + int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass); if (res < 0) { return res; } diff --git a/llama.h b/llama.h index 3e53576852d7c..a3813d1ead035 100644 --- a/llama.h +++ b/llama.h @@ -707,7 +707,7 @@ extern "C" { /// Apply chat template. Inspired by hf apply_chat_template() on python. /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model" /// NOTE: This function only support some known jinja templates. It is not a jinja parser. - /// @param custom_template A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead. + /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead. /// @param chat Pointer to a list of multiple llama_chat_message /// @param n_msg Number of llama_chat_message in this chat /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message. @@ -716,7 +716,7 @@ extern "C" { /// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template. LLAMA_API int32_t llama_chat_apply_template( const struct llama_model * model, - const char * custom_template, + const char * tmpl, const struct llama_chat_message * chat, size_t n_msg, bool add_ass,