Skip to content

Server: fallback to chatml, add AlphaMonarch chat template #5628

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,16 @@ struct llama_server_context
return true;
}

void validate_model_chat_template(server_params & sparams) {
llama_chat_message chat[] = {{"user", "test"}};
std::vector<char> buf(1);
int res = llama_chat_apply_template(model, nullptr, chat, 1, true, buf.data(), buf.size());
if (res < 0) {
LOG_ERROR("The chat template comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
sparams.chat_template = "<|im_start|>"; // llama_chat_apply_template only checks if <|im_start|> exist in the template
}
}

void initialize() {
// create slots
all_slots_are_idle = true;
Expand Down Expand Up @@ -2713,6 +2723,11 @@ int main(int argc, char **argv)
LOG_INFO("model loaded", {});
}

if (sparams.chat_template.empty()) { // custom chat template is not supplied
// check if the template comes with the model is supported by us
llama.validate_model_chat_template(sparams);
}

// Middleware for API key validation
auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
// If API key is not set, skip validation
Expand Down
9 changes: 9 additions & 0 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12580,6 +12580,15 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<|assistant|>\n";
}
} else if (tmpl.find("bos_token + message['role']") != std::string::npos) {
// mlabonne/AlphaMonarch-7B template (the <s> is included inside history)
for (auto message : chat) {
std::string bos = (message == chat.front()) ? "" : "<s>"; // skip BOS for first message
ss << bos << message->role << "\n" << message->content << "</s>\n";
}
if (add_ass) {
ss << "<s>assistant\n";
}
} else {
// template not supported
return -1;
Expand Down
23 changes: 15 additions & 8 deletions tests/test-chat-template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,20 @@ int main(void) {
"{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <<SYS>>\\\\n' + messages[idx]['content'] + '\\\\n<</SYS>>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}",
// bofenghuang/vigogne-2-70b-chat
"{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\\\n' + system_message + '\\\\n<</SYS>>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\\\n' + content.strip() + '\\\\n<</SYS>>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
// mlabonne/AlphaMonarch-7B
"{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}",
};
std::vector<std::string> expected_substr = {
"<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant",
"[/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
"</s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
"[/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
std::vector<std::string> expected_output = {
// teknium/OpenHermes-2.5-Mistral-7B
"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
// mistralai/Mistral-7B-Instruct-v0.2
"[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
// TheBloke/FusionNet_34Bx2_MoE-AWQ
"[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
// bofenghuang/vigogne-2-70b-chat
"[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
// mlabonne/AlphaMonarch-7B
"system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
};
std::vector<char> formatted_chat(1024);
int32_t res;
Expand All @@ -43,7 +51,7 @@ int main(void) {

for (size_t i = 0; i < templates.size(); i++) {
std::string custom_template = templates[i];
std::string substr = expected_substr[i];
std::string expected = expected_output[i];
formatted_chat.resize(1024);
res = llama_chat_apply_template(
nullptr,
Expand All @@ -57,8 +65,7 @@ int main(void) {
formatted_chat.resize(res);
std::string output(formatted_chat.data(), formatted_chat.size());
std::cout << output << "\n-------------------------\n";
// expect the "formatted_chat" to contain pre-defined strings
assert(output.find(substr) != std::string::npos);
assert(output == expected);
}
return 0;
}