
Commit 9400bf6

ngxson authored and NeoZhangJianyu committed
Add chatml fallback for cpp llama_chat_apply_template (ggml-org#8160)
* add chatml fallback for cpp `llama_chat_apply_template`
* remove redundant code
1 parent b5c120d commit 9400bf6

File tree: 2 files changed, 20 additions & 1 deletion


common/common.cpp

Lines changed: 18 additions & 1 deletion
@@ -2618,6 +2618,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
         const std::vector<llama_chat_msg> & msgs,
         bool add_ass) {
     int alloc_size = 0;
+    bool fallback = false; // indicate if we must fallback to default chatml
     std::vector<llama_chat_message> chat;
     for (auto & msg : msgs) {
         chat.push_back({msg.role.c_str(), msg.content.c_str()});
@@ -2630,10 +2631,26 @@ std::string llama_chat_apply_template(const struct llama_model * model,
     // run the first time to get the total output length
     int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());

+    // error: chat template is not supported
+    if (res < 0) {
+        if (ptr_tmpl != nullptr) {
+            // if the custom "tmpl" is not supported, we throw an error
+            // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
+            throw std::runtime_error("this custom template is not supported");
+        } else {
+            // If the built-in template is not supported, we default to chatml
+            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+            fallback = true;
+        }
+    }
+
     // if it turns out that our buffer is too small, we resize it
     if ((size_t) res > buf.size()) {
         buf.resize(res);
-        res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+        res = llama_chat_apply_template(
+            fallback ? nullptr : model,
+            fallback ? "chatml" : ptr_tmpl,
+            chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }

     std::string formatted_chat(buf.data(), res);
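For context, a caller-side sketch of how the new behavior is meant to be used. It assumes the model handle was loaded elsewhere and that passing an empty tmpl string makes the wrapper use the model's built-in template (the ptr_tmpl derivation is not shown in this hunk):

    #include <string>
    #include <vector>
    #include "common.h"

    // Sketch: build a prompt from a short conversation. With this commit, an
    // unrecognized built-in template no longer fails the call; the wrapper
    // silently re-runs formatting with the "chatml" template instead.
    static std::string build_prompt(const struct llama_model * model) {
        std::vector<llama_chat_msg> msgs = {
            {"system", "You are a helpful assistant."},
            {"user",   "Hello!"},
        };
        return llama_chat_apply_template(model, "", msgs, /* add_ass = */ true);
    }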

common/common.h

Lines changed: 2 additions & 0 deletions
@@ -380,6 +380,8 @@ struct llama_chat_msg {
 bool llama_chat_verify_template(const std::string & tmpl);

 // CPP wrapper for llama_chat_apply_template
+// If the built-in template is not supported, we default to chatml
+// If the custom "tmpl" is not supported, we throw an error
 std::string llama_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
         const std::vector<llama_chat_msg> & chat,
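The new header comments describe the split behavior: a missing built-in template degrades to chatml, while an unsupported custom template raises an exception. Below is a hedged sketch of a defensive caller; the helper name apply_custom_template and the choice of "chatml" as the local fallback are illustrative, not part of the library, while llama_chat_verify_template() is the existing helper declared above:

    #include <stdexcept>
    #include <string>
    #include <vector>
    #include "common.h"

    // Sketch: apply a user-supplied template, falling back to chatml locally
    // when it is rejected either by the verifier or by the wrapper itself.
    static std::string apply_custom_template(const struct llama_model * model,
                                             const std::string & custom_tmpl,
                                             const std::vector<llama_chat_msg> & msgs) {
        if (!llama_chat_verify_template(custom_tmpl)) {
            // reject early instead of relying on the exception below
            return llama_chat_apply_template(model, "chatml", msgs, true);
        }
        try {
            return llama_chat_apply_template(model, custom_tmpl, msgs, true);
        } catch (const std::runtime_error &) {
            // the wrapper throws when the custom template is not supported
            return llama_chat_apply_template(model, "chatml", msgs, true);
        }
    }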

0 commit comments
