Add llama_chat_apply_template() #5538
Merged
Commits (9)
- 4e64440 llama: add llama_chat_apply_template (ngxson)
- bba75c7 test-chat-template: remove dedundant vector (ngxson)
- 9c4422f chat_template: do not use std::string for buffer (ngxson)
- 6012ad6 add clarification for llama_chat_apply_template (ngxson)
- 7a3eac8 llama_chat_apply_template: add zephyr template (ngxson)
- 011af99 llama_chat_apply_template: correct docs (ngxson)
- dba4337 Merge branch 'master' into xsn/chat_apply_template (ngxson)
- 73fbd67 llama_chat_apply_template: use term "chat" everywhere (ngxson)
- 649f6f8 llama_chat_apply_template: change variable name to "tmpl" (ngxson)
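For context before the diff: this PR adds one public function to llama.h, which formats an array of llama_chat_message according to a Jinja-style chat template. Below is a minimal usage sketch, not code from this PR; it assumes a llama_model * loaded elsewhere, and the grow-and-retry pattern is our illustration of the function's documented return value.

```cpp
#include <string>
#include <vector>
#include "llama.h"

// Sketch: format a two-message chat using the template stored in the model.
// Assumes `model` was loaded elsewhere (e.g. with llama_load_model_from_file).
std::string format_chat(const llama_model * model) {
    llama_chat_message msgs[] = {
        {"system", "You are a helpful assistant"},
        {"user",   "Hello"},
    };
    std::vector<char> buf(1024);
    // tmpl == nullptr selects the chat template embedded in the model's GGUF
    // metadata; add_ass == true appends the assistant generation prompt.
    int32_t res = llama_chat_apply_template(model, nullptr, msgs, 2, true,
                                            buf.data(), buf.size());
    if (res > (int32_t) buf.size()) {
        // the return value is the required length, so grow the buffer and retry
        buf.resize(res);
        res = llama_chat_apply_template(model, nullptr, msgs, 2, true,
                                        buf.data(), buf.size());
    }
    return res < 0 ? std::string() : std::string(buf.data(), res);
}
```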
tests/test-chat-template.cpp (new file, 64 lines)

```cpp
#include <iostream>
#include <string>
#include <vector>
#include <sstream>

#undef NDEBUG
#include <cassert>

#include "llama.h"

int main(void) {
    llama_chat_message conversation[] = {
        {"system", "You are a helpful assistant"},
        {"user", "Hello"},
        {"assistant", "Hi there"},
        {"user", "Who are you"},
        {"assistant", " I am an assistant "},
        {"user", "Another question"},
    };
    size_t message_count = 6;
    std::vector<std::string> templates = {
        // teknium/OpenHermes-2.5-Mistral-7B
        "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
        // mistralai/Mistral-7B-Instruct-v0.2
        "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
        // TheBloke/FusionNet_34Bx2_MoE-AWQ
        "{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <<SYS>>\\\\n' + messages[idx]['content'] + '\\\\n<</SYS>>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}",
        // bofenghuang/vigogne-2-70b-chat
        "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\\\n' + system_message + '\\\\n<</SYS>>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\\\n' + content.strip() + '\\\\n<</SYS>>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
    };
    // each formatted conversation must contain the matching substring below
    std::vector<std::string> expected_substr = {
        "<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant",
        "[/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
        "</s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
        "[/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
    };
    std::vector<char> formatted_chat(1024);
    int32_t res;

    // test invalid chat template
    res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
    assert(res < 0);

    for (size_t i = 0; i < templates.size(); i++) {
        std::string custom_template = templates[i];
        std::string substr = expected_substr[i];
        formatted_chat.resize(1024); // reset the buffer to full capacity before each run
        res = llama_chat_apply_template(
            nullptr,
            custom_template.c_str(),
            conversation,
            message_count,
            true,
            formatted_chat.data(),
            formatted_chat.size()
        );
        formatted_chat.resize(res); // res is the number of bytes written
        std::string output(formatted_chat.data(), formatted_chat.size());
        std::cout << output << "\n-------------------------\n";
        // expect the "formatted_chat" to contain pre-defined strings
        assert(output.find(substr) != std::string::npos);
    }
    return 0;
}
```
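The first template above is ChatML (used by OpenHermes). To make the expected output concrete, here is a hand-written expansion of that template as plain C++ string building; render_chatml is a hypothetical helper for illustration, not part of the test or of llama.cpp.

```cpp
#include <string>
#include <vector>

struct chat_msg { std::string role, content; };

// Renders the ChatML template from the test by hand: each message becomes
// "<|im_start|>{role}\n{content}<|im_end|>\n", optionally followed by the
// assistant generation prompt.
static std::string render_chatml(const std::vector<chat_msg> & msgs, bool add_generation_prompt) {
    std::string out;
    for (const auto & m : msgs) {
        out += "<|im_start|>" + m.role + "\n" + m.content + "<|im_end|>\n";
    }
    if (add_generation_prompt) {
        out += "<|im_start|>assistant\n"; // cue the model to reply as the assistant
    }
    return out;
}
```

Applied to the six-message conversation in the test with add_generation_prompt = true, the tail of the returned string is exactly the first entry of expected_substr, which is what the assert checks.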
ngxson: I noticed that I made a mistake in this line: it should be `model_template.size()`, not `curr_tmpl.size()`. I'm fixing it in the next PR (which uses this function in the server).

P/s: It took me almost an hour to figure out this error. Sorry if I accidentally put somebody else in the same situation as me.
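For context, here is a sketch of the pattern the comment refers to, reconstructed from the comment itself rather than copied from llama.cpp, so everything beyond the two variable names is an assumption. llama_model_meta_val_str is the real llama.h call for reading GGUF metadata such as tokenizer.chat_template.

```cpp
// Reconstructed sketch of the bug described above (not the exact source).
// The model's chat template is read from GGUF metadata into a fixed buffer.
std::vector<char> model_template(2048, 0);
std::string curr_tmpl;

// BUG: curr_tmpl is still empty at this point, so passing curr_tmpl.size()
// hands the reader a zero-length buffer and nothing useful is copied:
//   llama_model_meta_val_str(model, "tokenizer.chat_template",
//                            model_template.data(), curr_tmpl.size());

// FIX: pass the capacity of the buffer actually being written to.
int32_t res = llama_model_meta_val_str(model, "tokenizer.chat_template",
                                       model_template.data(), model_template.size());
if (res >= 0) {
    curr_tmpl = std::string(model_template.data(), res);
}
```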