
server: add llama2 chat template #5425


Merged
6 commits merged on Feb 11, 2024
8 changes: 6 additions & 2 deletions examples/server/oai.hpp
@@ -15,9 +15,13 @@
using json = nlohmann::json;

inline static json oaicompat_completion_params_parse(
const json &body /* openai api json semantics */)
const json &body, /* openai api json semantics */
const std::string &chat_template)
{
json llama_params;
std::string formatted_prompt = chat_template == "chatml"
? format_chatml(body["messages"]) // OpenAI 'messages' to chatml (with <|im_start|>,...)
: format_llama2(body["messages"]); // OpenAI 'messages' to llama2 (with [INST],...)

llama_params["__oaicompat"] = true;

@@ -30,7 +34,7 @@ inline static json oaicompat_completion_params_parse(
// https://platform.openai.com/docs/api-reference/chat/create
llama_sampling_params default_sparams;
llama_params["model"] = json_value(body, "model", std::string("unknown"));
llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
llama_params["prompt"] = formatted_prompt;
llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
llama_params["temperature"] = json_value(body, "temperature", 0.0);
llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);
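For context, here is a minimal sketch of how the updated parser could be exercised; the include path, the request body, and the printed field are illustrative assumptions, not part of this diff:

```cpp
// Hypothetical test harness for oaicompat_completion_params_parse (sketch only),
// assuming the server example's headers and dependencies are on the include path.
#include "oai.hpp"   // assumed to bring in nlohmann::json and the parser
#include <iostream>

int main() {
    // OpenAI-style request body, as a client would send it to /v1/chat/completions.
    json body = {
        {"messages", {
            {{"role", "system"}, {"content", "You are a helpful assistant."}},
            {{"role", "user"},   {"content", "Hello!"}}
        }},
        {"temperature", 0.7}
    };

    // "chatml" selects format_chatml(); any other value (here "llama2")
    // routes the messages through format_llama2().
    json params = oaicompat_completion_params_parse(body, "llama2");

    // Expected (roughly): [INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\n\nHello! [/INST]
    std::cout << params["prompt"].get<std::string>() << std::endl;
    return 0;
}
```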
22 changes: 20 additions & 2 deletions examples/server/server.cpp
Expand Up @@ -36,6 +36,7 @@ struct server_params
std::string hostname = "127.0.0.1";
std::vector<std::string> api_keys;
std::string public_path = "examples/server/public";
std::string chat_template = "chatml";
int32_t port = 8080;
int32_t read_timeout = 600;
int32_t write_timeout = 600;
@@ -1859,6 +1860,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
printf(" -gan N, --grp-attn-n N set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
printf(" -gaw N, --grp-attn-w N set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
printf(" --chat-template FORMAT_NAME");
printf(" set chat template, possible valus is: llama2, chatml (default %s)", sparams.chat_template.c_str());
printf("\n");
}

@@ -2290,6 +2293,21 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
log_set_target(stdout);
LOG_INFO("logging to file is disabled.", {});
}
else if (arg == "--chat-template")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
std::string value(argv[i]);
if (value != "chatml" && value != "llama2") {
fprintf(stderr, "error: chat template can be \"llama2\" or \"chatml\", but got: %s\n", value.c_str());
invalid_param = true;
break;
}
sparams.chat_template = value;
}
else if (arg == "--override-kv")
{
if (++i >= argc) {
@@ -2743,13 +2761,13 @@ int main(int argc, char **argv)


// TODO: add mount point without "/v1" prefix -- how?
svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
svr.Post("/v1/chat/completions", [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) {
return;
}
json data = oaicompat_completion_params_parse(json::parse(req.body));
json data = oaicompat_completion_params_parse(json::parse(req.body), sparams.chat_template);

const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
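Note that the handler change also extends the lambda capture list: `sparams` must be captured (by reference here) so the configured chat template is visible inside the route callback; on the command line this corresponds to starting the server with `--chat-template llama2`. A minimal sketch of the capture pattern, independent of httplib and using a made-up struct name:

```cpp
#include <iostream>
#include <string>

// Stand-in for server_params; only the field relevant to this PR.
struct params_sketch {
    std::string chat_template = "chatml";
};

int main() {
    params_sketch sparams;
    sparams.chat_template = "llama2";   // e.g. the result of parsing --chat-template llama2

    // Capturing sparams by reference, as the /v1/chat/completions handler now does,
    // lets the callback read the live configuration when a request arrives.
    auto handler = [&sparams]() {
        std::cout << "formatting prompt with template: " << sparams.chat_template << std::endl;
    };

    handler();  // prints: formatting prompt with template: llama2
    return 0;
}
```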
30 changes: 30 additions & 0 deletions examples/server/utils.hpp
@@ -167,6 +167,34 @@ static T json_value(const json &body, const std::string &key, const T &default_v
: default_value;
}

inline std::string format_llama2(std::vector<json> messages)
{
std::ostringstream output;
bool is_inside_turn = false;

for (auto it = messages.begin(); it != messages.end(); ++it) {
if (!is_inside_turn) {
output << "[INST] ";
}
std::string role = json_value(*it, "role", std::string("user"));
std::string content = json_value(*it, "content", std::string(""));
if (role == "system") {
output << "<<SYS>>\n" << content << "\n<<SYS>>\n\n";
is_inside_turn = true;
} else if (role == "user") {
output << content << " [/INST]";
is_inside_turn = true;
} else {
output << " " << content << " </s>";
Contributor

I don't think there should be a space in front of the eos token.

Collaborator Author
As discussed in #5425 (comment), the space is not consistent among the different variations of this template. I leave it here to be safe; it changes nothing about the behavior of the model.

Ideally, if I had access to tokenizer.chat_template, I could adapt the template to exactly what the model needs (i.e. spaces, <<SYS>>, ...). But since I don't have access to that part for now, I have to make a "catch-all" template.

Also, what is important in this template is the placement of [INST], [/INST] and </s>. A model that was never trained to understand chatml will not understand that messages should be encapsulated between <|im_start|> and <|im_end|>; it will just repeat what's already in the conversation.

As long as the model sees [INST] some user prompt [/INST] some assistant response </s>, it will behave just fine. The spaces do not matter; that's what I experienced while testing this PR.

is_inside_turn = false;
}
}

LOG_VERBOSE("format_llama2", {{"text", output.str()}});

return output.str();
}

inline std::string format_chatml(std::vector<json> messages)
{
std::ostringstream chatml_msgs;
@@ -180,6 +208,8 @@ inline std::string format_chatml(std::vector<json> messages)

chatml_msgs << "<|im_start|>assistant" << '\n';

LOG_VERBOSE("format_chatml", {{"text", chatml_msgs.str()}});

return chatml_msgs.str();
}

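To make the turn-stitching in format_llama2 easier to follow, here is a simplified standalone mirror of its loop (no json or logging dependencies; the message struct, helper name, and sample conversation are assumptions) together with the prompt it would produce:

```cpp
// Simplified, standalone mirror of format_llama2, for illustration only.
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

struct msg { std::string role, content; };

static std::string format_llama2_sketch(const std::vector<msg> &messages) {
    std::ostringstream output;
    bool is_inside_turn = false;
    for (const auto &m : messages) {
        if (!is_inside_turn) {
            output << "[INST] ";                                  // open a new turn
        }
        if (m.role == "system") {
            output << "<<SYS>>\n" << m.content << "\n<</SYS>>\n\n";
            is_inside_turn = true;
        } else if (m.role == "user") {
            output << m.content << " [/INST]";
            is_inside_turn = true;
        } else {                                                  // assistant
            output << " " << m.content << " </s>";
            is_inside_turn = false;                               // next message opens a new [INST] block
        }
    }
    return output.str();
}

int main() {
    std::vector<msg> conversation = {
        {"system",    "You are a helpful assistant."},
        {"user",      "Hello!"},
        {"assistant", "Hi, how can I help?"},
        {"user",      "Tell me a joke."},
    };
    // Prints:
    // [INST] <<SYS>>
    // You are a helpful assistant.
    // <</SYS>>
    //
    // Hello! [/INST] Hi, how can I help? </s>[INST] Tell me a joke. [/INST]
    std::cout << format_llama2_sketch(conversation) << std::endl;
    return 0;
}
```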