From 44aa3be13f29a4bdbe32ce97e5e13a553692c0b3 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Mon, 2 Oct 2023 11:23:59 -0400 Subject: [PATCH 01/22] Add common grammars and json-schema-to-grammar utility function from llama.cpp --- llama_cpp/llama_grammar.py | 315 ++++++++++++++++++++++++++++++++++++- 1 file changed, 309 insertions(+), 6 deletions(-) diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 8ff15658e..29431d957 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -1,4 +1,5 @@ -"""C++ implementation of the llama grammar parser.""" +"""Python implementation of llama grammar parser directly translated from C++ source file in vendor/llama.cpp/common/grammar-parser.cpp.""" + # flake8: noqa from pathlib import Path import sys @@ -1056,8 +1057,7 @@ def print_rule( # fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END: raise RuntimeError( - "malformed rule, does not end with LLAMA_GRETYPE_END: " - + str(rule_id) + "malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id) ) print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ") # for (size_t i = 0, end = rule.size() - 1; i < end; i++) { @@ -1102,9 +1102,7 @@ def print_rule( for i, elem in enumerate(rule[:-1]): case = elem.type # type: llama_gretype if case is llama_gretype.LLAMA_GRETYPE_END: - raise RuntimeError( - "unexpected end of rule: " + str(rule_id) + "," + str(i) - ) + raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i)) elif case is llama_gretype.LLAMA_GRETYPE_ALT: print("| ", file=file, end="") elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF: @@ -1186,3 +1184,308 @@ def print_grammar(file: TextIO, state: parse_state) -> None: f"{print_grammar.__name__}: error printing grammar: {err}", file=sys.stderr, ) + + +"""llama.cpp gbnf rules from vendor/llama.cpp/grammars""" + +ARITHMETIC_GBNF = """\ +root ::= (expr "=" ws term "\n")+ +expr ::= term ([-+*/] term)* +term ::= ident | num | "(" ws expr ")" ws +ident ::= [a-z] [a-z0-9_]* ws +num ::= [0-9]+ ws +ws ::= [ \t\n]* +""" + +C_GBNF = """\ +root ::= (declaration)* + +declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}" + +dataType ::= "int" ws | "float" ws | "char" ws +identifier ::= [a-zA-Z_] [a-zA-Z_0-9]* + +parameter ::= dataType identifier + +statement ::= + ( dataType identifier ws "=" ws expression ";" ) | + ( identifier ws "=" ws expression ";" ) | + ( identifier ws "(" argList? ")" ";" ) | + ( "return" ws expression ";" ) | + ( "while" "(" condition ")" "{" statement* "}" ) | + ( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) | + ( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) | + ( singleLineComment ) | + ( multiLineComment ) + +forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression +forUpdate ::= identifier ws "=" ws expression + +condition ::= expression relationOperator expression +relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">") + +expression ::= term (("+" | "-") term)* +term ::= factor(("*" | "/") factor)* + +factor ::= identifier | number | unaryTerm | funcCall | parenExpression +unaryTerm ::= "-" factor +funcCall ::= identifier "(" argList? ")" +parenExpression ::= "(" ws expression ws ")" + +argList ::= expression ("," ws expression)* + +number ::= [0-9]+ + +singleLineComment ::= "//" [^\n]* "\n" +multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/" + +ws ::= ([ \t\n]+) +""" + +CHESS_GBNF = """\ +root ::= object +value ::= object | array | string | number | ("true" | "false" | "null") ws + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= ([ \t\n] ws)? +""" + +JAPANESE_GBNF = """\ +root ::= object +value ::= object | array | string | number | ("true" | "false" | "null") ws + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= ([ \t\n] ws)? +""" + +JSON_ARR_GBNF = """\ +# This is the same as json.gbnf but we restrict whitespaces at the end of the root array +# Useful for generating JSON arrays + +root ::= arr +value ::= object | array | string | number | ("true" | "false" | "null") ws + +arr ::= + "[\n" ws ( + value + (",\n" ws value)* + )? "]" + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= ([ \t\n] ws)? +""" + + +JSON_GBNF = """\ +root ::= object +value ::= object | array | string | number | ("true" | "false" | "null") ws + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= ([ \t\n] ws)?""" + +LIST_GBNF = """\ +root ::= item+ + +# Excludes various line break characters +item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n" +""" + +"""llama.cpp json-schema to grammar converter from vendor/llama.cpp/examples/json-schema-to-grammar.py""" +import json +import re +from typing import List, Optional + +# whitespace is constrained to a single space char to prevent model "running away" in +# whitespace. Also maybe improves generation quality? +SPACE_RULE = '" "?' + +PRIMITIVE_RULES = { + "boolean": '("true" | "false") space', + "number": '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space', + "integer": '("-"? ([0-9] | [1-9] [0-9]*)) space', + "string": r""" "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) + )* "\"" space """, + "null": '"null" space', +} + +INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+") +GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') +GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'} + + +class SchemaConverter: + def __init__(self, prop_order): + self._prop_order = prop_order + self._rules = {"space": SPACE_RULE} + + def _format_literal(self, literal): + escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( + lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), json.dumps(literal) + ) + return f'"{escaped}"' + + def _add_rule(self, name, rule): + esc_name = INVALID_RULE_CHARS_RE.sub("-", name) + if esc_name not in self._rules or self._rules[esc_name] == rule: + key = esc_name + else: + i = 0 + while f"{esc_name}{i}" in self._rules: + i += 1 + key = f"{esc_name}{i}" + self._rules[key] = rule + return key + + def visit(self, schema, name): + schema_type = schema.get("type") + rule_name = name or "root" + + if "oneOf" in schema or "anyOf" in schema: + rule = " | ".join( + ( + self.visit(alt_schema, f'{name}{"-" if name else ""}{i}') + for i, alt_schema in enumerate( + schema.get("oneOf") or schema["anyOf"] + ) + ) + ) + return self._add_rule(rule_name, rule) + + elif "const" in schema: + return self._add_rule(rule_name, self._format_literal(schema["const"])) + + elif "enum" in schema: + rule = " | ".join((self._format_literal(v) for v in schema["enum"])) + return self._add_rule(rule_name, rule) + + elif schema_type == "object" and "properties" in schema: + # TODO: `required` keyword + prop_order = self._prop_order + prop_pairs = sorted( + schema["properties"].items(), + # sort by position in prop_order (if specified) then by key + key=lambda kv: (prop_order.get(kv[0], len(prop_order)), kv[0]), + ) + + rule = '"{" space' + for i, (prop_name, prop_schema) in enumerate(prop_pairs): + prop_rule_name = self.visit( + prop_schema, f'{name}{"-" if name else ""}{prop_name}' + ) + if i > 0: + rule += ' "," space' + rule += rf' {self._format_literal(prop_name)} space ":" space {prop_rule_name}' + rule += ' "}" space' + + return self._add_rule(rule_name, rule) + + elif schema_type == "array" and "items" in schema: + # TODO `prefixItems` keyword + item_rule_name = self.visit( + schema["items"], f'{name}{"-" if name else ""}item' + ) + rule = ( + f'"[" space ({item_rule_name} ("," space {item_rule_name})*)? "]" space' + ) + return self._add_rule(rule_name, rule) + + else: + assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}" + return self._add_rule( + "root" if rule_name == "root" else schema_type, + PRIMITIVE_RULES[schema_type], + ) + + def format_grammar(self): + return "\n".join((f"{name} ::= {rule}" for name, rule in self._rules.items())) + + +def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None): + prop_order = prop_order or [] + schema = json.load(schema) + prop_order = {name: idx for idx, name in enumerate(prop_order)} + converter = SchemaConverter(prop_order) + converter.visit(schema, "") + return converter.format_grammar() From 855d34c516ca16f4bcb8c6863e5cb4c2130b7b48 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Mon, 2 Oct 2023 11:24:50 -0400 Subject: [PATCH 02/22] Pass functions to format function --- llama_cpp/llama.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index fdde7ea01..e6ad83da5 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1625,6 +1625,8 @@ def create_chat_completion( format = llama_chat_format.get_chat_format(self.chat_format) result = format( messages=messages, + functions=functions, + function_call=function_call, ) prompt = result.prompt if result.stop is not None: From 34c1f46fb67115d26971a3db10b018ee47df32b5 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Mon, 2 Oct 2023 11:29:52 -0400 Subject: [PATCH 03/22] Add basic functionary formatting --- llama_cpp/llama_chat_format.py | 121 +++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 9a09a28ee..5bb151667 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -320,3 +320,124 @@ def format_chatml( _messages.append((_roles["assistant"], None)) _prompt = _format_chatml(system_message, _messages, _sep) return ChatFormatterResponse(prompt=_prompt) + + +@register_chat_format("functionary") +def format_functionary( + messages: List[llama_types.ChatCompletionRequestMessage], + functions: Optional[List[llama_types.ChatCompletionFunctions]] = None, + **kwargs: Any, +) -> ChatFormatterResponse: + SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary""" + + def generate_schema_from_functions( + functions: List[llama_types.ChatCompletionFunctions], + namespace: str = "functions", + ): + """ + Convert functions schema to a schema that language models can understand. + """ + + schema = ( + "// Supported function definitions that should be called when necessary.\n" + ) + schema += f"namespace {namespace} {{\n\n" + + for function in functions: + # Convert a Function object to dict, if necessary + function_name = function["name"] + description = function.get("description", "") + schema += f"// {description}\n" + schema += f"type {function_name}" + + parameters = function.get("parameters", None) + schema += " = (_: {\n" + required_params = parameters.get("required", []) + for param_name, param in parameters.get("properties", {}).items(): + # Param Description + description = param.get("description") + if description is not None: + schema += f"// {description}\n" + + # Param Name + schema += f"{param_name}" + if param_name not in required_params: + schema += "?" + + # Param Type + param_type = param.get("type", "any") + if param_type == "integer": + param_type = "number" + if "enum" in param: + param_type = " | ".join([f'"{v}"' for v in param["enum"]]) + schema += f": {param_type},\n" + + schema += "}) => any;\n\n" + + schema += f"}} // namespace {namespace}" + + return schema + + def prepare_messages_for_inference( + messages: List[llama_types.ChatCompletionRequestMessage], + functions: Optional[List[llama_types.ChatCompletionFunctions]] = None, + ): + all_messages: List[llama_types.ChatCompletionRequestMessage] = [] + if functions is not None: + all_messages.append( + llama_types.ChatCompletionRequestMessage( + role="system", content=generate_schema_from_functions(functions) + ) + ) + + all_messages.append( + llama_types.ChatCompletionRequestMessage( + role="system", content=SYSTEM_MESSAGE + ) + ) + + for message in messages: + # Function call responses + if message["role"] == "function" and "name" in message: + message["name"] = f"functions.{message['name']}" + # Function call requests by assistant + if "function_call" in message: + message["function_call"][ + "name" + ] = f"functions.{message['function_call']['name']}" + all_messages.append(message) + + all_messages.append( + llama_types.ChatCompletionRequestMessage(role="assistant", content=None) + ) + + def message_to_str(msg: llama_types.ChatCompletionRequestMessage): + if msg["role"] == "system": + return f"system:\n{msg['content']}\n" + + elif msg["role"] == "function" and "name" in msg: + return f"function name={msg['name']}:\n{msg['content']}\n" + elif msg["role"] == "user": + if msg["content"] is None: + return "user:\n" + else: + return f"user:\n{msg['content']}\n" + elif msg["role"] == "assistant": + if msg["content"] is not None and "function_call" in msg: + return f"assistant:\n{msg['content']}\nassistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}" + elif "function_call" in msg: + return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}" + elif msg["content"] is None: + return "assistant" + else: + return f"assistant:\n{msg['content']}\n" + else: + raise ValueError(f"Unsupported role: {msg['role']}") + + return "".join([message_to_str(msg) for msg in all_messages]) + + prompt = prepare_messages_for_inference(messages, functions) + return ChatFormatterResponse( + prompt=prompt, + stop=["user:", ""], + ) From f93a2bb1db7c2f3eeff9ee6b6d26980cea7ce459 Mon Sep 17 00:00:00 2001 From: Joe Still Date: Tue, 3 Oct 2023 13:17:42 -0500 Subject: [PATCH 04/22] #717: Add support for Huggingface Autotokenizer --- llama_cpp/llama_chat_format.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 9a09a28ee..c7b75b75e 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -320,3 +320,24 @@ def format_chatml( _messages.append((_roles["assistant"], None)) _prompt = _format_chatml(system_message, _messages, _sep) return ChatFormatterResponse(prompt=_prompt) + +# eg, export HF_MODEL=mistralai/Mistral-7B-Instruct-v0.1 +@register_chat_format("autotokenizer") +def format_autotokenizer( + messages: List[llama_types.ChatCompletionRequestMessage], + **kwargs: Any, +) -> ChatFormatterResponse: + # https://huggingface.co/docs/transformers/main/chat_templating + # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format + # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json + import os + from transformers import AutoTokenizer + huggingFaceModel = os.getenv("HF_MODEL") # eg, mistralai/Mistral-7B-Instruct-v0.1 + print(huggingFaceModel) + if not huggingFaceModel: + raise Exception("HF_MODEL needs to be set in env to use chat format 'autotokenizer'") + tokenizer = AutoTokenizer.from_pretrained(huggingFaceModel) + tokenizer.use_default_system_prompt = False + _prompt = tokenizer.apply_chat_template(messages, tokenize=False) + # Return formatted prompt and eos token by default + return ChatFormatterResponse(prompt=_prompt, stop=tokenizer.eos_token) From b9763918fc51fb37d5599ff1ec2eb792dfb6223f Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Mon, 9 Oct 2023 23:45:39 -0400 Subject: [PATCH 05/22] refactor: Streamline message formatting - Introduce `BASE_TEMPLATE` for common chat formatting structure. - Implement a protocol-based `ChatFormatterTemplate` for custom formatters. - Add `Llama2Formatter` to handle specific Llama-2 formatting. - Create `ChatFormatter` class for registering and retrieving formatters. - Remove redundant functions like `_format_llama2`. Refactored the chat message formatting to use a more structured and extensible approach. Now supports multiple templates and ensures a cleaner codebase. --- llama_cpp/llama_chat_format.py | 187 ++++++++++++++++++++++++++------- 1 file changed, 148 insertions(+), 39 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 5bb151667..161082804 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -1,7 +1,154 @@ import dataclasses -from typing import Any, Dict, List, Optional, Tuple, Union, Protocol +from typing import Any, Dict, List, Optional, Protocol, Tuple, Union + from . import llama_types +BASE_TEMPLATE = { + "roles": { + "system": { + "prefix": "<>", + "postfix": "<>", + "format": None, + }, + "user": { + "prefix": "[INST] ", + "postfix": " [/INST]", + "format": None, + }, + "assistant": { + "prefix": "", + "postfix": "", + "format": None, + }, + }, + "separators": { + "after_system": "\n", + "between_messages": "\n", + "end_of_response": "", + }, + "special_tokens": { + "bos_token": "", + "eos_token": "", + "unk_token": "", + }, + "default_termination": { + "role": "assistant", + "message": None, + }, +} + + +@dataclasses.dataclass +class ChatFormatterResponse: + prompt: str + stop: Optional[Union[str, List[str]]] = None + + +class ChatFormatterTemplate(Protocol): + def __init__(self, template: Dict[str, Any] = BASE_TEMPLATE): + self.template = template + + # NOTE: Override private methods in inheriting classes as needed. + def _get_system_message( + self, messages: List[Dict[str, llama_types.ChatCompletionRequestMessage]] + ) -> str: + """Get the first system message.""" + # NOTE: The system message is always the first element in a sequence, + # any other order should be considered undefined. + # If we always set the first element in the sequence to a system role, + # it makes sense to simply check the first element and test to see if it is a system role. + # This allows us to extract and return the system message from the list of messages + # with a constant time complexity. + try: + if messages[0]["role"] == "system": + # Retrieve role-specific formatting + role_prefix = self.template["roles"]["system"]["prefix"] + role_postfix = self.template["roles"]["system"]["postfix"] + # Extract the role-based message content + content = messages[0]["content"] + # Format the message content with the role's prefix and postfix + return role_prefix + content + role_postfix + return "" + except (IndexError, KeyError): + return "" + + def _map_roles( + self, messages: List[Dict[str, llama_types.ChatCompletionRequestMessage]] + ) -> List[Tuple[str, Optional[str]]]: + """Map the message roles.""" + # Convert the messages into a list of (role, message) tuples + mapped_sequence = [] + for message in messages: + if message["role"] in ["user", "assistant"]: + # Retrieve role-specific formatting + role_prefix = self.template["roles"][message["role"]]["prefix"] + role_postfix = self.template["roles"][message["role"]]["postfix"] + # Format the message content with the role's prefix and postfix + formatted_message = role_prefix + message["content"] + role_postfix + # Map the formatted message to the sequence as a tuple + mapped_sequence.append((message["role"], formatted_message)) + return mapped_sequence + + def _format_messages( + self, messages: List[Dict[str, llama_types.ChatCompletionRequestMessage]] + ) -> str: + """Transforms a list of messages into the appropriate format for the model.""" + ... + + def parse_response( + self, + messages: List[Dict[str, llama_types.ChatCompletionRequestMessage]], + **kwargs, + ) -> ChatFormatterResponse: + ... + + +class Llama2Formatter(ChatFormatterTemplate): + def _format_messages( + self, messages: List[Dict[str, llama_types.ChatCompletionRequestMessage]] + ) -> str: + """Private method to format messages based on Llama2 template.""" + system_message = self._get_system_message(messages) + mapped_messages = self._map_roles(messages) + separator = self.template["separators"]["between_messages"] + end_of_response = self.template["separators"]["end_of_response"] + + formatted_msg = separator.join([msg for role, msg in mapped_messages if msg]) + return system_message + separator + formatted_msg + end_of_response + + def parse_messages( + self, + messages: List[Dict[str, llama_types.ChatCompletionRequestMessage]], + **kwargs, + ) -> ChatFormatterResponse: + """Parse messages and wrap in ChatFormatterResponse.""" + formatted_content = self._format_messages(messages) + return ChatFormatterResponse(prompt=formatted_content) + + +class ChatFormatter: + _chat_formatters: Dict[str, ChatFormatterTemplate] = {"llama-2": Llama2Formatter} + + def register_chat_format(self, cls, name: str): + self._chat_formatters[name] = cls + + def get_chat_format(self, name: str): + try: + return self._chat_formatters[name]() + except KeyError: + valid_formats = list(self._chat_formatters.keys()) + raise ValueError( + f"Invalid chat format: {name}. Valid formats: {valid_formats}" + ) + + def format(self, name: str, messages: List[dict]) -> str: + formatter = self.get_chat_format(name) + return formatter.format_messages(messages) + + def parse(self, name: str, raw_response: str) -> Tuple[str, List[dict]]: + formatter = self.get_chat_format(name) + return formatter.parse_response(raw_response) + def _get_system_message( messages: List[llama_types.ChatCompletionRequestMessage], @@ -25,19 +172,6 @@ def _map_roles( return output -def _format_llama2( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str -) -> str: - """Format the prompt with the llama2 style.""" - ret = system_message + sep - for role, message in messages: - if message: - ret += message + " " - else: - ret += role + " " - return ret - - def _format_add_colon_single( system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str ) -> str: @@ -110,15 +244,6 @@ class ChatFormatterResponse: stop: Optional[Union[str, List[str]]] = None -class ChatFormatter(Protocol): - def __call__( - self, - messages: List[llama_types.ChatCompletionRequestMessage], - **kwargs: Any, - ) -> ChatFormatterResponse: - ... - - _CHAT_FORMATS: Dict[str, ChatFormatter] = {} @@ -139,22 +264,6 @@ def get_chat_format(name: str): ) -@register_chat_format("llama-2") -def format_llama2( - messages: List[llama_types.ChatCompletionRequestMessage], - **kwargs: Any, -) -> ChatFormatterResponse: - _system_template = "[INST] <>\n{system_message}\n<>\n\n" - _roles = dict(user="[INST]", assistant="[/INST]") - _sep = "\n\n" - system_message = _get_system_message(messages) - system_message = _system_template.format(system_message=system_message) - _messages = _map_roles(messages, _roles) - _messages.append((_roles["assistant"], None)) - _prompt = _format_llama2(system_message, _messages, _sep) - return ChatFormatterResponse(prompt=_prompt) - - @register_chat_format("alpaca") def format_alpaca( messages: List[llama_types.ChatCompletionRequestMessage], From 5c3b892238e5eaeaf275c197e2c1b2fc62e00abb Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Mon, 9 Oct 2023 23:52:17 -0400 Subject: [PATCH 06/22] test: Add Llama2Formatter tests - Introduce `test_llama_chat_formatters.py` for testing chat formatters. - Implement `test_llama2_formatter` to validate Llama2 message formatting. Added unit tests to ensure the correctness of the newly refactored Llama2Formatter. This ensures that message formatting adheres to the expected template. --- tests/test_llama_chat_formatters.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 tests/test_llama_chat_formatters.py diff --git a/tests/test_llama_chat_formatters.py b/tests/test_llama_chat_formatters.py new file mode 100644 index 000000000..313b29a74 --- /dev/null +++ b/tests/test_llama_chat_formatters.py @@ -0,0 +1,29 @@ +from typing import List + +from llama_cpp import ChatCompletionMessage +from llama_cpp.llama_chat_format import Llama2Formatter + +messages: List[ChatCompletionMessage] = [ + ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"), + ChatCompletionMessage( + role="user", content="Hi there! I need some help with Python." + ), + ChatCompletionMessage( + role="assistant", content="Of course! What do you need help with in Python?" + ), + ChatCompletionMessage( + role="user", + content="I'm trying to write a function to find the factorial of a number, but I'm stuck.", + ), + ChatCompletionMessage( + role="assistant", + content="I can help with that! Would you like a recursive or iterative solution?", + ), + ChatCompletionMessage(role="user", content="Let's go with a recursive solution."), +] + + +def test_llama2_formatter(): + prompt = """<>Welcome to CodeHelp Bot!<>\n[INST] Hi there! I need some help with Python. [/INST]\nOf course! What do you need help with in Python?\n[INST] I'm trying to write a function to find the factorial of a number, but I'm stuck. [/INST]\nI can help with that! Would you like a recursive or iterative solution?\n[INST] Let's go with a recursive solution. [/INST]""" + llama2formatter = Llama2Formatter() + assert prompt == llama2formatter._format_messages(messages) From 63518122f4ca94e24331e67adcb02c3a2a1df413 Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Mon, 9 Oct 2023 23:58:21 -0400 Subject: [PATCH 07/22] test: Add pytest fixture for message sequence - Introduce pytest fixture `sequence_of_messages` in `test_llama_chat_formatters.py`. - Refactor `test_llama2_formatter` to use the new fixture. Utilizing pytest fixtures enhances the modularity of our test suite, allowing for cleaner test cases and potential reusability across multiple tests. --- tests/test_llama_chat_formatters.py | 47 +++++++++++++++++------------ 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/tests/test_llama_chat_formatters.py b/tests/test_llama_chat_formatters.py index 313b29a74..30e042bfa 100644 --- a/tests/test_llama_chat_formatters.py +++ b/tests/test_llama_chat_formatters.py @@ -1,29 +1,36 @@ from typing import List +import pytest + from llama_cpp import ChatCompletionMessage from llama_cpp.llama_chat_format import Llama2Formatter -messages: List[ChatCompletionMessage] = [ - ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"), - ChatCompletionMessage( - role="user", content="Hi there! I need some help with Python." - ), - ChatCompletionMessage( - role="assistant", content="Of course! What do you need help with in Python?" - ), - ChatCompletionMessage( - role="user", - content="I'm trying to write a function to find the factorial of a number, but I'm stuck.", - ), - ChatCompletionMessage( - role="assistant", - content="I can help with that! Would you like a recursive or iterative solution?", - ), - ChatCompletionMessage(role="user", content="Let's go with a recursive solution."), -] + +@pytest.fixture +def sequence_of_messages() -> List[ChatCompletionMessage]: + return [ + ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"), + ChatCompletionMessage( + role="user", content="Hi there! I need some help with Python." + ), + ChatCompletionMessage( + role="assistant", content="Of course! What do you need help with in Python?" + ), + ChatCompletionMessage( + role="user", + content="I'm trying to write a function to find the factorial of a number, but I'm stuck.", + ), + ChatCompletionMessage( + role="assistant", + content="I can help with that! Would you like a recursive or iterative solution?", + ), + ChatCompletionMessage( + role="user", content="Let's go with a recursive solution." + ), + ] -def test_llama2_formatter(): +def test_llama2_formatter(sequence_of_messages): prompt = """<>Welcome to CodeHelp Bot!<>\n[INST] Hi there! I need some help with Python. [/INST]\nOf course! What do you need help with in Python?\n[INST] I'm trying to write a function to find the factorial of a number, but I'm stuck. [/INST]\nI can help with that! Would you like a recursive or iterative solution?\n[INST] Let's go with a recursive solution. [/INST]""" llama2formatter = Llama2Formatter() - assert prompt == llama2formatter._format_messages(messages) + assert prompt == llama2formatter._format_messages(sequence_of_messages) From c8ffb0f479f04c2c5be21a82a5dc30e8feb48442 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Mon, 2 Oct 2023 11:23:59 -0400 Subject: [PATCH 08/22] Add common grammars and json-schema-to-grammar utility function from llama.cpp --- llama_cpp/llama_grammar.py | 315 ++++++++++++++++++++++++++++++++++++- 1 file changed, 309 insertions(+), 6 deletions(-) diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 8ff15658e..29431d957 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -1,4 +1,5 @@ -"""C++ implementation of the llama grammar parser.""" +"""Python implementation of llama grammar parser directly translated from C++ source file in vendor/llama.cpp/common/grammar-parser.cpp.""" + # flake8: noqa from pathlib import Path import sys @@ -1056,8 +1057,7 @@ def print_rule( # fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END: raise RuntimeError( - "malformed rule, does not end with LLAMA_GRETYPE_END: " - + str(rule_id) + "malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id) ) print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ") # for (size_t i = 0, end = rule.size() - 1; i < end; i++) { @@ -1102,9 +1102,7 @@ def print_rule( for i, elem in enumerate(rule[:-1]): case = elem.type # type: llama_gretype if case is llama_gretype.LLAMA_GRETYPE_END: - raise RuntimeError( - "unexpected end of rule: " + str(rule_id) + "," + str(i) - ) + raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i)) elif case is llama_gretype.LLAMA_GRETYPE_ALT: print("| ", file=file, end="") elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF: @@ -1186,3 +1184,308 @@ def print_grammar(file: TextIO, state: parse_state) -> None: f"{print_grammar.__name__}: error printing grammar: {err}", file=sys.stderr, ) + + +"""llama.cpp gbnf rules from vendor/llama.cpp/grammars""" + +ARITHMETIC_GBNF = """\ +root ::= (expr "=" ws term "\n")+ +expr ::= term ([-+*/] term)* +term ::= ident | num | "(" ws expr ")" ws +ident ::= [a-z] [a-z0-9_]* ws +num ::= [0-9]+ ws +ws ::= [ \t\n]* +""" + +C_GBNF = """\ +root ::= (declaration)* + +declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}" + +dataType ::= "int" ws | "float" ws | "char" ws +identifier ::= [a-zA-Z_] [a-zA-Z_0-9]* + +parameter ::= dataType identifier + +statement ::= + ( dataType identifier ws "=" ws expression ";" ) | + ( identifier ws "=" ws expression ";" ) | + ( identifier ws "(" argList? ")" ";" ) | + ( "return" ws expression ";" ) | + ( "while" "(" condition ")" "{" statement* "}" ) | + ( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) | + ( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) | + ( singleLineComment ) | + ( multiLineComment ) + +forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression +forUpdate ::= identifier ws "=" ws expression + +condition ::= expression relationOperator expression +relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">") + +expression ::= term (("+" | "-") term)* +term ::= factor(("*" | "/") factor)* + +factor ::= identifier | number | unaryTerm | funcCall | parenExpression +unaryTerm ::= "-" factor +funcCall ::= identifier "(" argList? ")" +parenExpression ::= "(" ws expression ws ")" + +argList ::= expression ("," ws expression)* + +number ::= [0-9]+ + +singleLineComment ::= "//" [^\n]* "\n" +multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/" + +ws ::= ([ \t\n]+) +""" + +CHESS_GBNF = """\ +root ::= object +value ::= object | array | string | number | ("true" | "false" | "null") ws + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= ([ \t\n] ws)? +""" + +JAPANESE_GBNF = """\ +root ::= object +value ::= object | array | string | number | ("true" | "false" | "null") ws + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= ([ \t\n] ws)? +""" + +JSON_ARR_GBNF = """\ +# This is the same as json.gbnf but we restrict whitespaces at the end of the root array +# Useful for generating JSON arrays + +root ::= arr +value ::= object | array | string | number | ("true" | "false" | "null") ws + +arr ::= + "[\n" ws ( + value + (",\n" ws value)* + )? "]" + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= ([ \t\n] ws)? +""" + + +JSON_GBNF = """\ +root ::= object +value ::= object | array | string | number | ("true" | "false" | "null") ws + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + +number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= ([ \t\n] ws)?""" + +LIST_GBNF = """\ +root ::= item+ + +# Excludes various line break characters +item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n" +""" + +"""llama.cpp json-schema to grammar converter from vendor/llama.cpp/examples/json-schema-to-grammar.py""" +import json +import re +from typing import List, Optional + +# whitespace is constrained to a single space char to prevent model "running away" in +# whitespace. Also maybe improves generation quality? +SPACE_RULE = '" "?' + +PRIMITIVE_RULES = { + "boolean": '("true" | "false") space', + "number": '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space', + "integer": '("-"? ([0-9] | [1-9] [0-9]*)) space', + "string": r""" "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) + )* "\"" space """, + "null": '"null" space', +} + +INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+") +GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') +GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'} + + +class SchemaConverter: + def __init__(self, prop_order): + self._prop_order = prop_order + self._rules = {"space": SPACE_RULE} + + def _format_literal(self, literal): + escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( + lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), json.dumps(literal) + ) + return f'"{escaped}"' + + def _add_rule(self, name, rule): + esc_name = INVALID_RULE_CHARS_RE.sub("-", name) + if esc_name not in self._rules or self._rules[esc_name] == rule: + key = esc_name + else: + i = 0 + while f"{esc_name}{i}" in self._rules: + i += 1 + key = f"{esc_name}{i}" + self._rules[key] = rule + return key + + def visit(self, schema, name): + schema_type = schema.get("type") + rule_name = name or "root" + + if "oneOf" in schema or "anyOf" in schema: + rule = " | ".join( + ( + self.visit(alt_schema, f'{name}{"-" if name else ""}{i}') + for i, alt_schema in enumerate( + schema.get("oneOf") or schema["anyOf"] + ) + ) + ) + return self._add_rule(rule_name, rule) + + elif "const" in schema: + return self._add_rule(rule_name, self._format_literal(schema["const"])) + + elif "enum" in schema: + rule = " | ".join((self._format_literal(v) for v in schema["enum"])) + return self._add_rule(rule_name, rule) + + elif schema_type == "object" and "properties" in schema: + # TODO: `required` keyword + prop_order = self._prop_order + prop_pairs = sorted( + schema["properties"].items(), + # sort by position in prop_order (if specified) then by key + key=lambda kv: (prop_order.get(kv[0], len(prop_order)), kv[0]), + ) + + rule = '"{" space' + for i, (prop_name, prop_schema) in enumerate(prop_pairs): + prop_rule_name = self.visit( + prop_schema, f'{name}{"-" if name else ""}{prop_name}' + ) + if i > 0: + rule += ' "," space' + rule += rf' {self._format_literal(prop_name)} space ":" space {prop_rule_name}' + rule += ' "}" space' + + return self._add_rule(rule_name, rule) + + elif schema_type == "array" and "items" in schema: + # TODO `prefixItems` keyword + item_rule_name = self.visit( + schema["items"], f'{name}{"-" if name else ""}item' + ) + rule = ( + f'"[" space ({item_rule_name} ("," space {item_rule_name})*)? "]" space' + ) + return self._add_rule(rule_name, rule) + + else: + assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}" + return self._add_rule( + "root" if rule_name == "root" else schema_type, + PRIMITIVE_RULES[schema_type], + ) + + def format_grammar(self): + return "\n".join((f"{name} ::= {rule}" for name, rule in self._rules.items())) + + +def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None): + prop_order = prop_order or [] + schema = json.load(schema) + prop_order = {name: idx for idx, name in enumerate(prop_order)} + converter = SchemaConverter(prop_order) + converter.visit(schema, "") + return converter.format_grammar() From b8224445301559fe8141f1e80ffa3a320b1aeba3 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Mon, 2 Oct 2023 11:24:50 -0400 Subject: [PATCH 09/22] Pass functions to format function --- llama_cpp/llama.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index fdde7ea01..e6ad83da5 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1625,6 +1625,8 @@ def create_chat_completion( format = llama_chat_format.get_chat_format(self.chat_format) result = format( messages=messages, + functions=functions, + function_call=function_call, ) prompt = result.prompt if result.stop is not None: From 7cfa5a7d9665248688425e32a57dc0b1f6c4b6ef Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Mon, 2 Oct 2023 11:29:52 -0400 Subject: [PATCH 10/22] Add basic functionary formatting --- llama_cpp/llama_chat_format.py | 133 ++++++++++++++++++++++++++++++++- 1 file changed, 130 insertions(+), 3 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index c7b75b75e..c6716d95b 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -1,5 +1,6 @@ import dataclasses -from typing import Any, Dict, List, Optional, Tuple, Union, Protocol +from typing import Any, Dict, List, Optional, Protocol, Tuple, Union + from . import llama_types @@ -321,6 +322,7 @@ def format_chatml( _prompt = _format_chatml(system_message, _messages, _sep) return ChatFormatterResponse(prompt=_prompt) + # eg, export HF_MODEL=mistralai/Mistral-7B-Instruct-v0.1 @register_chat_format("autotokenizer") def format_autotokenizer( @@ -331,13 +333,138 @@ def format_autotokenizer( # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json import os + from transformers import AutoTokenizer - huggingFaceModel = os.getenv("HF_MODEL") # eg, mistralai/Mistral-7B-Instruct-v0.1 + + huggingFaceModel = os.getenv("HF_MODEL") # eg, mistralai/Mistral-7B-Instruct-v0.1 print(huggingFaceModel) if not huggingFaceModel: - raise Exception("HF_MODEL needs to be set in env to use chat format 'autotokenizer'") + raise Exception( + "HF_MODEL needs to be set in env to use chat format 'autotokenizer'" + ) tokenizer = AutoTokenizer.from_pretrained(huggingFaceModel) tokenizer.use_default_system_prompt = False _prompt = tokenizer.apply_chat_template(messages, tokenize=False) # Return formatted prompt and eos token by default return ChatFormatterResponse(prompt=_prompt, stop=tokenizer.eos_token) + + +@register_chat_format("functionary") +def format_functionary( + messages: List[llama_types.ChatCompletionRequestMessage], + functions: Optional[List[llama_types.ChatCompletionFunctions]] = None, + **kwargs: Any, +) -> ChatFormatterResponse: + SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary""" + + def generate_schema_from_functions( + functions: List[llama_types.ChatCompletionFunctions], + namespace: str = "functions", + ): + """ + Convert functions schema to a schema that language models can understand. + """ + + schema = ( + "// Supported function definitions that should be called when necessary.\n" + ) + schema += f"namespace {namespace} {{\n\n" + + for function in functions: + # Convert a Function object to dict, if necessary + function_name = function["name"] + description = function.get("description", "") + schema += f"// {description}\n" + schema += f"type {function_name}" + + parameters = function.get("parameters", None) + schema += " = (_: {\n" + required_params = parameters.get("required", []) + for param_name, param in parameters.get("properties", {}).items(): + # Param Description + description = param.get("description") + if description is not None: + schema += f"// {description}\n" + + # Param Name + schema += f"{param_name}" + if param_name not in required_params: + schema += "?" + + # Param Type + param_type = param.get("type", "any") + if param_type == "integer": + param_type = "number" + if "enum" in param: + param_type = " | ".join([f'"{v}"' for v in param["enum"]]) + schema += f": {param_type},\n" + + schema += "}) => any;\n\n" + + schema += f"}} // namespace {namespace}" + + return schema + + def prepare_messages_for_inference( + messages: List[llama_types.ChatCompletionRequestMessage], + functions: Optional[List[llama_types.ChatCompletionFunctions]] = None, + ): + all_messages: List[llama_types.ChatCompletionRequestMessage] = [] + if functions is not None: + all_messages.append( + llama_types.ChatCompletionRequestMessage( + role="system", content=generate_schema_from_functions(functions) + ) + ) + + all_messages.append( + llama_types.ChatCompletionRequestMessage( + role="system", content=SYSTEM_MESSAGE + ) + ) + + for message in messages: + # Function call responses + if message["role"] == "function" and "name" in message: + message["name"] = f"functions.{message['name']}" + # Function call requests by assistant + if "function_call" in message: + message["function_call"][ + "name" + ] = f"functions.{message['function_call']['name']}" + all_messages.append(message) + + all_messages.append( + llama_types.ChatCompletionRequestMessage(role="assistant", content=None) + ) + + def message_to_str(msg: llama_types.ChatCompletionRequestMessage): + if msg["role"] == "system": + return f"system:\n{msg['content']}\n" + + elif msg["role"] == "function" and "name" in msg: + return f"function name={msg['name']}:\n{msg['content']}\n" + elif msg["role"] == "user": + if msg["content"] is None: + return "user:\n" + else: + return f"user:\n{msg['content']}\n" + elif msg["role"] == "assistant": + if msg["content"] is not None and "function_call" in msg: + return f"assistant:\n{msg['content']}\nassistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}" + elif "function_call" in msg: + return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}" + elif msg["content"] is None: + return "assistant" + else: + return f"assistant:\n{msg['content']}\n" + else: + raise ValueError(f"Unsupported role: {msg['role']}") + + return "".join([message_to_str(msg) for msg in all_messages]) + + prompt = prepare_messages_for_inference(messages, functions) + return ChatFormatterResponse( + prompt=prompt, + stop=["user:", ""], + ) From edce4523b57e3c4637999b78bf07310ba3ae476b Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 3 Oct 2023 15:23:35 -0400 Subject: [PATCH 11/22] Update llama.cpp --- llama_cpp/llama_cpp.py | 61 ++++++++++++++++++++++++++++++++++++++++-- vendor/llama.cpp | 2 +- 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index d2a35c13f..41c87e20b 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -102,8 +102,8 @@ def _load_shared_library(lib_base_name: str): # define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN -# define LLAMA_SESSION_VERSION 1 -LLAMA_SESSION_VERSION = 1 +# define LLAMA_SESSION_VERSION 2 +LLAMA_SESSION_VERSION = 2 # struct llama_model; @@ -624,6 +624,16 @@ def llama_n_embd(model: llama_model_p) -> int: _lib.llama_n_embd.restype = c_int +# // Get the model's RoPE frequency scaling factor +# LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model); +def llama_rope_freq_scale_train(model: llama_model_p) -> float: + return _lib.llama_rope_freq_scale_train(model) + + +_lib.llama_rope_freq_scale_train.argtypes = [llama_model_p] +_lib.llama_rope_freq_scale_train.restype = c_float + + # // Get a string describing the model type # LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size); def llama_model_desc( @@ -768,6 +778,8 @@ def llama_get_kv_cache_token_count(ctx: llama_context_p) -> int: # // Remove all tokens data of cells in [c0, c1) +# // c0 < 0 : [0, c1] +# // c1 < 0 : [c0, inf) # LLAMA_API void llama_kv_cache_tokens_rm( # struct llama_context * ctx, # int32_t c0, @@ -783,6 +795,8 @@ def llama_kv_cache_tokens_rm( # // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) +# // p0 < 0 : [0, p1] +# // p1 < 0 : [p0, inf) # LLAMA_API void llama_kv_cache_seq_rm( # struct llama_context * ctx, # llama_seq_id seq_id, @@ -808,6 +822,8 @@ def llama_kv_cache_seq_rm( # // Copy all tokens that belong to the specified sequence to another sequence # // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence +# // p0 < 0 : [0, p1] +# // p1 < 0 : [p0, inf) # LLAMA_API void llama_kv_cache_seq_cp( # struct llama_context * ctx, # llama_seq_id seq_id_src, @@ -851,6 +867,8 @@ def llama_kv_cache_seq_keep( # // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) # // If the KV cache is RoPEd, the KV data is updated accordingly +# // p0 < 0 : [0, p1] +# // p1 < 0 : [p0, inf) # LLAMA_API void llama_kv_cache_seq_shift( # struct llama_context * ctx, # llama_seq_id seq_id, @@ -1215,6 +1233,43 @@ def llama_token_nl(ctx: llama_context_p) -> int: _lib.llama_token_nl.restype = llama_token +# // codellama infill tokens +# LLAMA_API llama_token llama_token_prefix(const struct llama_context * ctx); // Beginning of infill prefix +def llama_token_prefix(ctx: llama_context_p) -> int: + return _lib.llama_token_prefix(ctx) + + +_lib.llama_token_prefix.argtypes = [llama_context_p] +_lib.llama_token_prefix.restype = llama_token + + +# LLAMA_API llama_token llama_token_middle(const struct llama_context * ctx); // Beginning of infill middle +def llama_token_middle(ctx: llama_context_p) -> int: + return _lib.llama_token_middle(ctx) + + +_lib.llama_token_middle.argtypes = [llama_context_p] +_lib.llama_token_middle.restype = llama_token + + +# LLAMA_API llama_token llama_token_suffix(const struct llama_context * ctx); // Beginning of infill suffix +def llama_token_suffix(ctx: llama_context_p) -> int: + return _lib.llama_token_suffix(ctx) + + +_lib.llama_token_suffix.argtypes = [llama_context_p] +_lib.llama_token_suffix.restype = llama_token + + +# LLAMA_API llama_token llama_token_eot (const struct llama_context * ctx); // End of infill middle +def llama_token_eot(ctx: llama_context_p) -> int: + return _lib.llama_token_eot(ctx) + + +_lib.llama_token_eot.argtypes = [llama_context_p] +_lib.llama_token_eot.restype = llama_token + + # // # // Tokenization # // @@ -1728,6 +1783,7 @@ def llama_grammar_accept_token( # struct llama_beam_view { # const llama_token * tokens; + # size_t n_tokens; # float p; // Cumulative beam probability (renormalized relative to all beams) # bool eob; // Callback should set this to true when a beam is at end-of-beam. @@ -1794,6 +1850,7 @@ def llama_beam_search( ctx, callback, callback_data, n_beams, n_past, n_predict ) + _lib.llama_beam_search.argtypes = [ llama_context_p, llama_beam_search_callback_fn_t, diff --git a/vendor/llama.cpp b/vendor/llama.cpp index f5ef5cfb1..79f34abdd 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit f5ef5cfb18148131fcf45bdd2331f0db5ab7c3d0 +Subproject commit 79f34abddb72ac5ddbf118f3d87520b611a10a7d From 0d4a0bf5c84dc91d0fb3c6e858c17b83f2ac6338 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 4 Oct 2023 20:19:31 -0400 Subject: [PATCH 12/22] Update llama.cpp --- vendor/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 79f34abdd..019ba1dcd 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 79f34abddb72ac5ddbf118f3d87520b611a10a7d +Subproject commit 019ba1dcd0c7775a5ac0f7442634a330eb0173cc From 8e71ac85467f930007b7e6062059ada3de03b90e Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 5 Oct 2023 16:07:49 -0400 Subject: [PATCH 13/22] Update llama.cpp --- llama_cpp/llama_cpp.py | 2 +- vendor/llama.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 41c87e20b..42e57a69c 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -491,7 +491,7 @@ def llama_backend_free(): # LLAMA_API struct llama_model * llama_load_model_from_file( # const char * path_model, -# struct llama_context_params params); +# struct llama_model_params params); def llama_load_model_from_file( path_model: bytes, params: llama_model_params ) -> llama_model_p: diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 019ba1dcd..48edda30e 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 019ba1dcd0c7775a5ac0f7442634a330eb0173cc +Subproject commit 48edda30ee545fdac2e7a33d505382888f748bbf From 361e25460aa5c9f1f2bcba3d6116faefe41e66e3 Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Mon, 9 Oct 2023 23:45:39 -0400 Subject: [PATCH 14/22] refactor: Streamline message formatting - Introduce `BASE_TEMPLATE` for common chat formatting structure. - Implement a protocol-based `ChatFormatterTemplate` for custom formatters. - Add `Llama2Formatter` to handle specific Llama-2 formatting. - Create `ChatFormatter` class for registering and retrieving formatters. - Remove redundant functions like `_format_llama2`. Refactored the chat message formatting to use a more structured and extensible approach. Now supports multiple templates and ensures a cleaner codebase. --- llama_cpp/llama_chat_format.py | 184 ++++++++++++++++++++++++++------- 1 file changed, 146 insertions(+), 38 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index c6716d95b..e5c7022b1 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -3,6 +3,152 @@ from . import llama_types +BASE_TEMPLATE = { + "roles": { + "system": { + "prefix": "<>", + "postfix": "<>", + "format": None, + }, + "user": { + "prefix": "[INST] ", + "postfix": " [/INST]", + "format": None, + }, + "assistant": { + "prefix": "", + "postfix": "", + "format": None, + }, + }, + "separators": { + "after_system": "\n", + "between_messages": "\n", + "end_of_response": "", + }, + "special_tokens": { + "bos_token": "", + "eos_token": "", + "unk_token": "", + }, + "default_termination": { + "role": "assistant", + "message": None, + }, +} + + +@dataclasses.dataclass +class ChatFormatterResponse: + prompt: str + stop: Optional[Union[str, List[str]]] = None + + +class ChatFormatterTemplate(Protocol): + def __init__(self, template: Dict[str, Any] = BASE_TEMPLATE): + self.template = template + + # NOTE: Override private methods in inheriting classes as needed. + def _get_system_message( + self, messages: List[Dict[str, llama_types.ChatCompletionRequestMessage]] + ) -> str: + """Get the first system message.""" + # NOTE: The system message is always the first element in a sequence, + # any other order should be considered undefined. + # If we always set the first element in the sequence to a system role, + # it makes sense to simply check the first element and test to see if it is a system role. + # This allows us to extract and return the system message from the list of messages + # with a constant time complexity. + try: + if messages[0]["role"] == "system": + # Retrieve role-specific formatting + role_prefix = self.template["roles"]["system"]["prefix"] + role_postfix = self.template["roles"]["system"]["postfix"] + # Extract the role-based message content + content = messages[0]["content"] + # Format the message content with the role's prefix and postfix + return role_prefix + content + role_postfix + return "" + except (IndexError, KeyError): + return "" + + def _map_roles( + self, messages: List[Dict[str, llama_types.ChatCompletionRequestMessage]] + ) -> List[Tuple[str, Optional[str]]]: + """Map the message roles.""" + # Convert the messages into a list of (role, message) tuples + mapped_sequence = [] + for message in messages: + if message["role"] in ["user", "assistant"]: + # Retrieve role-specific formatting + role_prefix = self.template["roles"][message["role"]]["prefix"] + role_postfix = self.template["roles"][message["role"]]["postfix"] + # Format the message content with the role's prefix and postfix + formatted_message = role_prefix + message["content"] + role_postfix + # Map the formatted message to the sequence as a tuple + mapped_sequence.append((message["role"], formatted_message)) + return mapped_sequence + + def _format_messages( + self, messages: List[Dict[str, llama_types.ChatCompletionRequestMessage]] + ) -> str: + """Transforms a list of messages into the appropriate format for the model.""" + ... + + def parse_response( + self, + messages: List[Dict[str, llama_types.ChatCompletionRequestMessage]], + **kwargs, + ) -> ChatFormatterResponse: + ... + + +class Llama2Formatter(ChatFormatterTemplate): + def _format_messages( + self, messages: List[Dict[str, llama_types.ChatCompletionRequestMessage]] + ) -> str: + """Private method to format messages based on Llama2 template.""" + system_message = self._get_system_message(messages) + mapped_messages = self._map_roles(messages) + separator = self.template["separators"]["between_messages"] + end_of_response = self.template["separators"]["end_of_response"] + + formatted_msg = separator.join([msg for role, msg in mapped_messages if msg]) + return system_message + separator + formatted_msg + end_of_response + + def parse_messages( + self, + messages: List[Dict[str, llama_types.ChatCompletionRequestMessage]], + **kwargs, + ) -> ChatFormatterResponse: + """Parse messages and wrap in ChatFormatterResponse.""" + formatted_content = self._format_messages(messages) + return ChatFormatterResponse(prompt=formatted_content) + + +class ChatFormatter: + _chat_formatters: Dict[str, ChatFormatterTemplate] = {"llama-2": Llama2Formatter} + + def register_chat_format(self, cls, name: str): + self._chat_formatters[name] = cls + + def get_chat_format(self, name: str): + try: + return self._chat_formatters[name]() + except KeyError: + valid_formats = list(self._chat_formatters.keys()) + raise ValueError( + f"Invalid chat format: {name}. Valid formats: {valid_formats}" + ) + + def format(self, name: str, messages: List[dict]) -> str: + formatter = self.get_chat_format(name) + return formatter.format_messages(messages) + + def parse(self, name: str, raw_response: str) -> Tuple[str, List[dict]]: + formatter = self.get_chat_format(name) + return formatter.parse_response(raw_response) + def _get_system_message( messages: List[llama_types.ChatCompletionRequestMessage], @@ -26,19 +172,6 @@ def _map_roles( return output -def _format_llama2( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str -) -> str: - """Format the prompt with the llama2 style.""" - ret = system_message + sep - for role, message in messages: - if message: - ret += message + " " - else: - ret += role + " " - return ret - - def _format_add_colon_single( system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str ) -> str: @@ -111,15 +244,6 @@ class ChatFormatterResponse: stop: Optional[Union[str, List[str]]] = None -class ChatFormatter(Protocol): - def __call__( - self, - messages: List[llama_types.ChatCompletionRequestMessage], - **kwargs: Any, - ) -> ChatFormatterResponse: - ... - - _CHAT_FORMATS: Dict[str, ChatFormatter] = {} @@ -140,22 +264,6 @@ def get_chat_format(name: str): ) -@register_chat_format("llama-2") -def format_llama2( - messages: List[llama_types.ChatCompletionRequestMessage], - **kwargs: Any, -) -> ChatFormatterResponse: - _system_template = "[INST] <>\n{system_message}\n<>\n\n" - _roles = dict(user="[INST]", assistant="[/INST]") - _sep = "\n\n" - system_message = _get_system_message(messages) - system_message = _system_template.format(system_message=system_message) - _messages = _map_roles(messages, _roles) - _messages.append((_roles["assistant"], None)) - _prompt = _format_llama2(system_message, _messages, _sep) - return ChatFormatterResponse(prompt=_prompt) - - @register_chat_format("alpaca") def format_alpaca( messages: List[llama_types.ChatCompletionRequestMessage], From 10583854548f3d3f6e8b7fd9a9cfdb52606aebe0 Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Mon, 9 Oct 2023 23:52:17 -0400 Subject: [PATCH 15/22] test: Add Llama2Formatter tests - Introduce `test_llama_chat_formatters.py` for testing chat formatters. - Implement `test_llama2_formatter` to validate Llama2 message formatting. Added unit tests to ensure the correctness of the newly refactored Llama2Formatter. This ensures that message formatting adheres to the expected template. --- tests/test_llama_chat_formatters.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 tests/test_llama_chat_formatters.py diff --git a/tests/test_llama_chat_formatters.py b/tests/test_llama_chat_formatters.py new file mode 100644 index 000000000..313b29a74 --- /dev/null +++ b/tests/test_llama_chat_formatters.py @@ -0,0 +1,29 @@ +from typing import List + +from llama_cpp import ChatCompletionMessage +from llama_cpp.llama_chat_format import Llama2Formatter + +messages: List[ChatCompletionMessage] = [ + ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"), + ChatCompletionMessage( + role="user", content="Hi there! I need some help with Python." + ), + ChatCompletionMessage( + role="assistant", content="Of course! What do you need help with in Python?" + ), + ChatCompletionMessage( + role="user", + content="I'm trying to write a function to find the factorial of a number, but I'm stuck.", + ), + ChatCompletionMessage( + role="assistant", + content="I can help with that! Would you like a recursive or iterative solution?", + ), + ChatCompletionMessage(role="user", content="Let's go with a recursive solution."), +] + + +def test_llama2_formatter(): + prompt = """<>Welcome to CodeHelp Bot!<>\n[INST] Hi there! I need some help with Python. [/INST]\nOf course! What do you need help with in Python?\n[INST] I'm trying to write a function to find the factorial of a number, but I'm stuck. [/INST]\nI can help with that! Would you like a recursive or iterative solution?\n[INST] Let's go with a recursive solution. [/INST]""" + llama2formatter = Llama2Formatter() + assert prompt == llama2formatter._format_messages(messages) From 8cd236ea2f4360265f3c18175c9c82e5a9a28185 Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Mon, 9 Oct 2023 23:58:21 -0400 Subject: [PATCH 16/22] test: Add pytest fixture for message sequence - Introduce pytest fixture `sequence_of_messages` in `test_llama_chat_formatters.py`. - Refactor `test_llama2_formatter` to use the new fixture. Utilizing pytest fixtures enhances the modularity of our test suite, allowing for cleaner test cases and potential reusability across multiple tests. --- tests/test_llama_chat_formatters.py | 47 +++++++++++++++++------------ 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/tests/test_llama_chat_formatters.py b/tests/test_llama_chat_formatters.py index 313b29a74..30e042bfa 100644 --- a/tests/test_llama_chat_formatters.py +++ b/tests/test_llama_chat_formatters.py @@ -1,29 +1,36 @@ from typing import List +import pytest + from llama_cpp import ChatCompletionMessage from llama_cpp.llama_chat_format import Llama2Formatter -messages: List[ChatCompletionMessage] = [ - ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"), - ChatCompletionMessage( - role="user", content="Hi there! I need some help with Python." - ), - ChatCompletionMessage( - role="assistant", content="Of course! What do you need help with in Python?" - ), - ChatCompletionMessage( - role="user", - content="I'm trying to write a function to find the factorial of a number, but I'm stuck.", - ), - ChatCompletionMessage( - role="assistant", - content="I can help with that! Would you like a recursive or iterative solution?", - ), - ChatCompletionMessage(role="user", content="Let's go with a recursive solution."), -] + +@pytest.fixture +def sequence_of_messages() -> List[ChatCompletionMessage]: + return [ + ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"), + ChatCompletionMessage( + role="user", content="Hi there! I need some help with Python." + ), + ChatCompletionMessage( + role="assistant", content="Of course! What do you need help with in Python?" + ), + ChatCompletionMessage( + role="user", + content="I'm trying to write a function to find the factorial of a number, but I'm stuck.", + ), + ChatCompletionMessage( + role="assistant", + content="I can help with that! Would you like a recursive or iterative solution?", + ), + ChatCompletionMessage( + role="user", content="Let's go with a recursive solution." + ), + ] -def test_llama2_formatter(): +def test_llama2_formatter(sequence_of_messages): prompt = """<>Welcome to CodeHelp Bot!<>\n[INST] Hi there! I need some help with Python. [/INST]\nOf course! What do you need help with in Python?\n[INST] I'm trying to write a function to find the factorial of a number, but I'm stuck. [/INST]\nI can help with that! Would you like a recursive or iterative solution?\n[INST] Let's go with a recursive solution. [/INST]""" llama2formatter = Llama2Formatter() - assert prompt == llama2formatter._format_messages(messages) + assert prompt == llama2formatter._format_messages(sequence_of_messages) From 171a8d66572101215339b6bfa03682b1a9efef3c Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Tue, 10 Oct 2023 12:00:02 -0400 Subject: [PATCH 17/22] refactor: Chat formatting for enhanced flexibility - Introduced `TokenizerCache` to efficiently reuse tokenizers. - Merged specific formatter classes into a generic `ChatFormatterTemplate` leveraging HuggingFace's `AutoTokenizer` and Jinja2 template capabilities. - Simplified the `ChatFormatter` class to manage chat format registrations and perform formatting and parsing operations. - Reduced overall source lines of code while enhancing code clarity and maintainability. Note: This refactor aims to provide a more flexible and extensible approach to chat formatting, making it easier to add and manage different model templates in the future. --- llama_cpp/llama_chat_format.py | 324 +++++++++------------------------ 1 file changed, 81 insertions(+), 243 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index e5c7022b1..72677f095 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -1,269 +1,107 @@ import dataclasses -from typing import Any, Dict, List, Optional, Protocol, Tuple, Union +from typing import Dict, List, Optional -from . import llama_types - -BASE_TEMPLATE = { - "roles": { - "system": { - "prefix": "<>", - "postfix": "<>", - "format": None, - }, - "user": { - "prefix": "[INST] ", - "postfix": " [/INST]", - "format": None, - }, - "assistant": { - "prefix": "", - "postfix": "", - "format": None, - }, - }, - "separators": { - "after_system": "\n", - "between_messages": "\n", - "end_of_response": "", - }, - "special_tokens": { - "bos_token": "", - "eos_token": "", - "unk_token": "", - }, - "default_termination": { - "role": "assistant", - "message": None, - }, -} +from transformers import AutoTokenizer +from . import llama_types +# NOTE: Custom Templates use Jinja2. +# If no template is given, then should default to hf's tokenizer template. +# We can define the model and template on a model-to-model basis, +# however, this should be allowed to be overridden for flexibility and extensibility. +# We only need 2 keys, the model name and the jinja2 template. +# +# template = {"model": "meta-llama/Llama-2-7b-chat-hf", "template": None} +# +# or +# +# chat_template = { +# "model": "meta-llama/Llama-2-7b-chat-hf", +# "jinja": "{% for message in messages %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + message['content'].strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ '[ASST] ' + message ['content'] + ' [/ASST]' + eos_token }}{% endif %}{% endfor %}", +# } +# +# We can probably employ some kind of method for reading a template it in from a file in necessary. +# +# We leave template empty here because HuggingFace defined it already. +# +# Source: https://huggingface.co/docs/transformers/main/chat_templating +# +# Special Thanks and Credit goes to bioshazard for the idea and preliminary implementation. +# Source: https://github.com/abetlen/llama-cpp-python/pull/790 + + +# NOTE: We can still use this for reverse compatibility with the currently employed API. +# This can be modified, if needed, in the future. @dataclasses.dataclass class ChatFormatterResponse: prompt: str - stop: Optional[Union[str, List[str]]] = None - - -class ChatFormatterTemplate(Protocol): - def __init__(self, template: Dict[str, Any] = BASE_TEMPLATE): - self.template = template - - # NOTE: Override private methods in inheriting classes as needed. - def _get_system_message( - self, messages: List[Dict[str, llama_types.ChatCompletionRequestMessage]] - ) -> str: - """Get the first system message.""" - # NOTE: The system message is always the first element in a sequence, - # any other order should be considered undefined. - # If we always set the first element in the sequence to a system role, - # it makes sense to simply check the first element and test to see if it is a system role. - # This allows us to extract and return the system message from the list of messages - # with a constant time complexity. - try: - if messages[0]["role"] == "system": - # Retrieve role-specific formatting - role_prefix = self.template["roles"]["system"]["prefix"] - role_postfix = self.template["roles"]["system"]["postfix"] - # Extract the role-based message content - content = messages[0]["content"] - # Format the message content with the role's prefix and postfix - return role_prefix + content + role_postfix - return "" - except (IndexError, KeyError): - return "" - - def _map_roles( - self, messages: List[Dict[str, llama_types.ChatCompletionRequestMessage]] - ) -> List[Tuple[str, Optional[str]]]: - """Map the message roles.""" - # Convert the messages into a list of (role, message) tuples - mapped_sequence = [] - for message in messages: - if message["role"] in ["user", "assistant"]: - # Retrieve role-specific formatting - role_prefix = self.template["roles"][message["role"]]["prefix"] - role_postfix = self.template["roles"][message["role"]]["postfix"] - # Format the message content with the role's prefix and postfix - formatted_message = role_prefix + message["content"] + role_postfix - # Map the formatted message to the sequence as a tuple - mapped_sequence.append((message["role"], formatted_message)) - return mapped_sequence - - def _format_messages( - self, messages: List[Dict[str, llama_types.ChatCompletionRequestMessage]] - ) -> str: - """Transforms a list of messages into the appropriate format for the model.""" - ... - - def parse_response( - self, - messages: List[Dict[str, llama_types.ChatCompletionRequestMessage]], - **kwargs, - ) -> ChatFormatterResponse: - ... - - -class Llama2Formatter(ChatFormatterTemplate): - def _format_messages( - self, messages: List[Dict[str, llama_types.ChatCompletionRequestMessage]] - ) -> str: - """Private method to format messages based on Llama2 template.""" - system_message = self._get_system_message(messages) - mapped_messages = self._map_roles(messages) - separator = self.template["separators"]["between_messages"] - end_of_response = self.template["separators"]["end_of_response"] - - formatted_msg = separator.join([msg for role, msg in mapped_messages if msg]) - return system_message + separator + formatted_msg + end_of_response - - def parse_messages( - self, - messages: List[Dict[str, llama_types.ChatCompletionRequestMessage]], - **kwargs, - ) -> ChatFormatterResponse: - """Parse messages and wrap in ChatFormatterResponse.""" - formatted_content = self._format_messages(messages) - return ChatFormatterResponse(prompt=formatted_content) + stop: Optional[List[str]] = None -class ChatFormatter: - _chat_formatters: Dict[str, ChatFormatterTemplate] = {"llama-2": Llama2Formatter} - - def register_chat_format(self, cls, name: str): - self._chat_formatters[name] = cls - - def get_chat_format(self, name: str): - try: - return self._chat_formatters[name]() - except KeyError: - valid_formats = list(self._chat_formatters.keys()) - raise ValueError( - f"Invalid chat format: {name}. Valid formats: {valid_formats}" - ) - - def format(self, name: str, messages: List[dict]) -> str: - formatter = self.get_chat_format(name) - return formatter.format_messages(messages) +class TokenizerCache: + _cache: Dict[str, AutoTokenizer] = {} - def parse(self, name: str, raw_response: str) -> Tuple[str, List[dict]]: - formatter = self.get_chat_format(name) - return formatter.parse_response(raw_response) + @classmethod + def get_tokenizer(cls, model_name: str) -> AutoTokenizer: + if model_name not in cls._cache: + cls._cache[model_name] = AutoTokenizer.from_pretrained(model_name) + return cls._cache[model_name] -def _get_system_message( - messages: List[llama_types.ChatCompletionRequestMessage], -) -> str: - """Get the first system message.""" - for message in messages: - if message["role"] == "system": - return message["content"] or "" - return "" - - -def _map_roles( - messages: List[llama_types.ChatCompletionRequestMessage], role_map: Dict[str, str] -) -> List[Tuple[str, Optional[str]]]: - """Map the message roles.""" - output: List[Tuple[str, Optional[str]]] = [] - for message in messages: - role = message["role"] - if role in role_map: - output.append((role_map[role], message["content"])) - return output - - -def _format_add_colon_single( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str -) -> str: - """Format the prompt with the add-colon-single style.""" - ret = system_message + sep - for role, message in messages: - if message: - ret += role + ": " + message + sep - else: - ret += role + ":" - return ret - - -def _format_add_colon_two( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str -) -> str: - """Format the prompt with the add-colon-two style.""" - seps = [sep, sep2] - ret = system_message + seps[0] - for i, (role, message) in enumerate(messages): - if message: - ret += role + ": " + message + seps[i % 2] - else: - ret += role + ":" - return ret - - -def _format_no_colon_single( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str -) -> str: - """Format the prompt with the no-colon-single style.""" - ret = system_message - for role, message in messages: - if message: - ret += role + message + sep +class ChatFormatterTemplate: + def __init__(self, template: Optional[Dict[str, str]] = None): + if template: + self.template = template else: - ret += role - return ret - - -def _format_add_colon_space_single( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str -) -> str: - """Format the prompt with the add-colon-space-single style.""" - ret = system_message + sep - for role, message in messages: - if message: - ret += role + ": " + message + sep - else: - ret += role + ": " # must be end with a space - return ret - - -def _format_chatml( - system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str -) -> str: - """Format the prompt with the chatml style.""" - ret = "" if system_message == "" else system_message + sep + "\n" - for role, message in messages: - if message: - ret += role + "\n" + message + sep + "\n" - else: - ret += role + "\n" - return ret - + self.template = { + "model": "meta-llama/Llama-2-7b-chat-hf", + "jinja": None, + "tokenize": False, + } + self.tokenizer = TokenizerCache.get_tokenizer(self.template["model"]) + + def _format_messages(self, messages: List[Dict[str, str]]) -> str: + # If a custom template is provided, override the tokenizer's default template + if self.template.get("jinja"): + self.tokenizer.chat_template = self.template["jinja"] + + return self.tokenizer.apply_chat_template( + messages, tokenize=self.template["tokenize"] + ) -@dataclasses.dataclass -class ChatFormatterResponse: - prompt: str - stop: Optional[Union[str, List[str]]] = None + def parse_response(self, messages: List[Dict[str, str]]) -> ChatFormatterResponse: + formatted_content = self._format_messages(messages) + return ChatFormatterResponse( + prompt=formatted_content, stop=[self.tokenizer.eos_token] + ) -_CHAT_FORMATS: Dict[str, ChatFormatter] = {} +class ChatFormatter: + _chat_formatters: Dict[str, ChatFormatterTemplate] = {} + def register_chat_format( + self, model_name: str, template: Optional[Dict[str, str]] = None + ): + self._chat_formatters[model_name] = ChatFormatterTemplate(template) -def register_chat_format(name: str): - def decorator(f: ChatFormatter): - _CHAT_FORMATS[name] = f - return f + def get_chat_format(self, model_name: str) -> ChatFormatterTemplate: + if model_name not in self._chat_formatters: + raise ValueError(f"Model {model_name} is not registered.") - return decorator + return self._chat_formatters[model_name] + def format(self, model_name: str, messages: List[Dict[str, str]]) -> str: + formatter = self.get_chat_format(model_name) + return formatter._format_messages(messages) -def get_chat_format(name: str): - try: - return _CHAT_FORMATS[name] - except KeyError: - raise ValueError( - f"Invalid chat format: {name} (valid formats: {list(_CHAT_FORMATS.keys())})" - ) + def parse( + self, model_name: str, messages: List[Dict[str, str]] + ) -> ChatFormatterResponse: + formatter = self.get_chat_format(model_name) + return formatter.parse_response(messages) +# NOTE: Template registration is currently a WIP (work in progress) @register_chat_format("alpaca") def format_alpaca( messages: List[llama_types.ChatCompletionRequestMessage], From 9a3434121cf0d2e8105132312c00717b5b050fc4 Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Fri, 13 Oct 2023 00:27:55 -0400 Subject: [PATCH 18/22] feat: Enhance Chat Formatting and Template Customization - Added introductory comment explaining module purpose - Defined default templates for HuggingFace and common roles - Introduced Llama2Formatter and AlpacaFormatter classes - Registered predefined chat format models - Implemented ChatFormatterFactory for managing formatters These changes enhance the flexibility and customization of chat formatting, allowing for the registration of custom formatters and providing default templates for different chat models. --- llama_cpp/llama_chat_format.py | 328 +++++++++++++++++++++++---------- 1 file changed, 227 insertions(+), 101 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 72677f095..036691aed 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -1,41 +1,182 @@ +""" +llama_cpp/llama_chat_format.py + +This module provides a chat formatting system that allows for custom templates and HuggingFace's jinja2-based chat templating. + +To extend or customize, simply inherit from the ChatFormatter class and override the necessary methods. Registered formatters can be accessed using the ChatFormatterFactory. + +NOTE: The system message is always assumed to be the first element in a sequence. + +# Usage example: +# Registering a custom formatter +@ChatFormatterFactory.register_predefined_model("llama-2") +class Llama2Formatter(ChatFormatter): + def __init__(self): + super().__init__(llama2_template) + +# Obtaining a registered formatter +chat_formatter_factory = ChatFormatterFactory() +llama2_formatter = chat_formatter_factory.get_formatter_by_name("alpaca") + +# Formatting messages +messages = [{"role": "user", "content": "Hello, World!"}] +response = llama2_formatter(messages) +print(response) +""" import dataclasses -from typing import Dict, List, Optional +import os +from typing import Any, Dict, List, Optional, Protocol, Type, Union +from huggingface_hub import login from transformers import AutoTokenizer from . import llama_types -# NOTE: Custom Templates use Jinja2. -# If no template is given, then should default to hf's tokenizer template. -# We can define the model and template on a model-to-model basis, -# however, this should be allowed to be overridden for flexibility and extensibility. -# We only need 2 keys, the model name and the jinja2 template. -# -# template = {"model": "meta-llama/Llama-2-7b-chat-hf", "template": None} -# -# or -# -# chat_template = { -# "model": "meta-llama/Llama-2-7b-chat-hf", -# "jinja": "{% for message in messages %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + message['content'].strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ '[ASST] ' + message ['content'] + ' [/ASST]' + eos_token }}{% endif %}{% endfor %}", -# } -# -# We can probably employ some kind of method for reading a template it in from a file in necessary. -# -# We leave template empty here because HuggingFace defined it already. -# -# Source: https://huggingface.co/docs/transformers/main/chat_templating -# -# Special Thanks and Credit goes to bioshazard for the idea and preliminary implementation. -# Source: https://github.com/abetlen/llama-cpp-python/pull/790 - - -# NOTE: We can still use this for reverse compatibility with the currently employed API. -# This can be modified, if needed, in the future. +# NOTE: The default templates are defined here for reusability. +huggingface_template = { + "model": "meta-llama/Llama-2-7b-chat-hf", + "jinja": None, + "tokenize": False, +} + +common_template = { + "separators": { + "after_system": "\n", + "between_messages": "\n", + "end_of_response": "", + }, + "special_tokens": { + "bos_token": "", + "eos_token": "", + "unk_token": "", + }, + "default_termination": { + "role": "assistant", + "message": None, + }, + "include_prompt": False, +} + +# Templates can be reused, modified, or overriden as needed on a model-by-model basis. +# This reduces noise in the code and ideally keeps the code base DRY. +llama2_template = { + "roles": { + "system": { + "prefix": "<>", + "postfix": "<>", + "format": None, # Optionally specify an custom format + }, + "assistant": { + "prefix": "", # No prefix for assistant role by default + "postfix": "", # No postfix for assistant role by default + "format": None, + }, + "user": { + "prefix": "[INST] ", + "postfix": " [/INST]", # Model starts generating from here + "format": None, + }, + } +} +# NOTE: The merge operator requires Python 3.9+ +# Other options are to use `dict.update()` or to create a custom function that merges them. +# Source: https://docs.python.org/3/library/stdtypes.html?highlight=dict#dict +llama2_template |= common_template + +# NOTE: If `include_prompt` is set to `True`, it will append the user prefix/postfix to the prompts output. +alpaca_template = { + "roles": { + "system": { + "prefix": "", + "postfix": "\n", + "format": None, + }, + "user": { + "prefix": "### Instruction:\n", + "postfix": "\n", + "format": None, + }, + "input": { + "prefix": "### Input:\n", + "postfix": "\n", + "format": None, + }, + "assistant": { + "prefix": "### Response:\n", + "postfix": "", # Model starts generating from here + "format": None, + }, + } +} +alpaca_template |= common_template + + @dataclasses.dataclass class ChatFormatterResponse: prompt: str - stop: Optional[List[str]] = None + stop: Optional[Union[str, List[str]]] = None + + +# Base Chat Formatter Protocol +class ChatFormatterInterface(Protocol): + def __init__(self, template: Optional[Dict[str, Any]] = None): + raise NotImplementedError + + def __call__( + self, + messages: List[Dict[str, str]], + **kwargs, + ) -> ChatFormatterResponse: + raise NotImplementedError + + +# Core Chat Formatter class +# NOTE: Methods can be overridden as needed on a model-by-model basis. +class ChatFormatter(ChatFormatterInterface): + def __init__(self, template: Optional[Dict[str, Any]] = None): + self.template = template or llama2_template + + def __call__( + self, + messages: List[Dict[str, str]], + **kwargs: Any, + ) -> ChatFormatterResponse: + formatted_messages = [ + self.format_message(msg["content"], msg["role"]) for msg in messages + ] + separator = self.format_separator("between_messages") + formatted_sequence = separator.join(formatted_messages) + # NOTE: Optionally include a prompt at the end + if self.template["include_prompt"]: + formatted_sequence += self.get_prompt() + # NOTE: `stop` is handled within completion methods + return ChatFormatterResponse(prompt=formatted_sequence) + + def format_message(self, message, role) -> str: + """Format a message based on the specified role.""" + try: + role_info = self.template["roles"][role] + except KeyError: + raise KeyError( + f"The role '{role}' is not defined in the template. Please check your template configuration." + ) + + prefix = role_info.get("prefix", "") + postfix = role_info.get("postfix", "") + formatted_message = f"{prefix}{message}{postfix}" + return formatted_message + + def format_separator(self, separator_type) -> str: + """Format separators based on the specified type.""" + return self.template["separators"].get(separator_type, "") + + def format_special_token(self, token_type) -> str: + """Format special tokens based on the specified type.""" + return self.template["special_tokens"].get(token_type, "") + + def get_prompt(self) -> str: + # Implement logic to generate a prompt, if needed + return self.template["roles"]["user"]["prefix"] class TokenizerCache: @@ -48,19 +189,31 @@ def get_tokenizer(cls, model_name: str) -> AutoTokenizer: return cls._cache[model_name] -class ChatFormatterTemplate: +class AutoTokenizerFormatter(ChatFormatterInterface): def __init__(self, template: Optional[Dict[str, str]] = None): - if template: - self.template = template - else: - self.template = { - "model": "meta-llama/Llama-2-7b-chat-hf", - "jinja": None, - "tokenize": False, - } + self.template = template or huggingface_template + self.token = os.getenv("HF_TOKEN") + if self.token is None: + raise AttributeError( + "Failed to login to huggingface. " + "Did you forget to set the `HF_TOKEN` environment variable with your huggingface token?" + ) + login(self.token) self.tokenizer = TokenizerCache.get_tokenizer(self.template["model"]) - def _format_messages(self, messages: List[Dict[str, str]]) -> str: + def __call__( + self, + messages: List[llama_types.ChatCompletionRequestMessage], + **kwargs, + ) -> ChatFormatterResponse: + formatted_content = self.format_messages(messages) + return ChatFormatterResponse( + prompt=formatted_content, stop=[self.tokenizer.eos_token] + ) + + def format_messages( + self, messages: List[llama_types.ChatCompletionRequestMessage] + ) -> str: # If a custom template is provided, override the tokenizer's default template if self.template.get("jinja"): self.tokenizer.chat_template = self.template["jinja"] @@ -69,51 +222,50 @@ def _format_messages(self, messages: List[Dict[str, str]]) -> str: messages, tokenize=self.template["tokenize"] ) - def parse_response(self, messages: List[Dict[str, str]]) -> ChatFormatterResponse: - formatted_content = self._format_messages(messages) - return ChatFormatterResponse( - prompt=formatted_content, stop=[self.tokenizer.eos_token] - ) +# NOTE: Template registration is currently a WIP (work in progress). +class FormatterNotFoundException(Exception): + pass -class ChatFormatter: - _chat_formatters: Dict[str, ChatFormatterTemplate] = {} - def register_chat_format( - self, model_name: str, template: Optional[Dict[str, str]] = None - ): - self._chat_formatters[model_name] = ChatFormatterTemplate(template) +# External developers can now use the `@ChatFormatter.register_predefined_model` +# method to register their own custom formatters. +class ChatFormatterFactory: + _chat_formatters: Dict[str, ChatFormatterInterface] = {} - def get_chat_format(self, model_name: str) -> ChatFormatterTemplate: - if model_name not in self._chat_formatters: - raise ValueError(f"Model {model_name} is not registered.") + @staticmethod + def register_predefined_model(name: str): + def decorator(cls: Type[ChatFormatterInterface]): + ChatFormatterFactory._chat_formatters[name] = cls() + return cls - return self._chat_formatters[model_name] + return decorator - def format(self, model_name: str, messages: List[Dict[str, str]]) -> str: - formatter = self.get_chat_format(model_name) - return formatter._format_messages(messages) + def register_custom_model(self, name: str, formatter: ChatFormatterInterface): + self._chat_formatters[name] = formatter - def parse( - self, model_name: str, messages: List[Dict[str, str]] - ) -> ChatFormatterResponse: - formatter = self.get_chat_format(model_name) - return formatter.parse_response(messages) + def get_formatter_by_name(self, name: str) -> ChatFormatterInterface: + try: + return self._chat_formatters[name] + except KeyError: + raise FormatterNotFoundException( + f"Invalid chat format: {name} (valid formats: {list(self._chat_formatters.keys())})" + ) -# NOTE: Template registration is currently a WIP (work in progress) -@register_chat_format("alpaca") -def format_alpaca( - messages: List[llama_types.ChatCompletionRequestMessage], - **kwargs: Any, -) -> ChatFormatterResponse: - _roles = dict(user="### Instruction", assistant="### Response") - _sep = "\n\n" - _sep2 = "" - system_message = _get_system_message(messages) - _messages = _map_roles(messages, _roles) - _prompt = _format_add_colon_two(system_message, _messages, _sep, _sep2) - return ChatFormatterResponse(prompt=_prompt) +# Define a chat format class and register it +@ChatFormatterFactory.register_predefined_model("llama-2") +class Llama2Formatter(ChatFormatter): + def __init__(self): + super().__init__(llama2_template) + + +# Define a chat format class and register it +@ChatFormatterFactory.register_predefined_model("alpaca") +class AlpacaFormatter(ChatFormatter): + def __init__(self): + # Define the Alpaca template + super().__init__(alpaca_template) @register_chat_format("vicuna") @@ -269,32 +421,6 @@ def format_chatml( return ChatFormatterResponse(prompt=_prompt) -# eg, export HF_MODEL=mistralai/Mistral-7B-Instruct-v0.1 -@register_chat_format("autotokenizer") -def format_autotokenizer( - messages: List[llama_types.ChatCompletionRequestMessage], - **kwargs: Any, -) -> ChatFormatterResponse: - # https://huggingface.co/docs/transformers/main/chat_templating - # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format - # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json - import os - - from transformers import AutoTokenizer - - huggingFaceModel = os.getenv("HF_MODEL") # eg, mistralai/Mistral-7B-Instruct-v0.1 - print(huggingFaceModel) - if not huggingFaceModel: - raise Exception( - "HF_MODEL needs to be set in env to use chat format 'autotokenizer'" - ) - tokenizer = AutoTokenizer.from_pretrained(huggingFaceModel) - tokenizer.use_default_system_prompt = False - _prompt = tokenizer.apply_chat_template(messages, tokenize=False) - # Return formatted prompt and eos token by default - return ChatFormatterResponse(prompt=_prompt, stop=tokenizer.eos_token) - - @register_chat_format("functionary") def format_functionary( messages: List[llama_types.ChatCompletionRequestMessage], From 9d8db9cc2dd80b4d0441d51e5d8354ad1834225a Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Fri, 13 Oct 2023 10:20:09 -0400 Subject: [PATCH 19/22] refactor: Improve HuggingFace Login and Formatter Name - Consolidated and isolated the HuggingFace login process for improved security. - Used the module name `huggingface_hub` instead of directly importing the `login` function for clarity. - Corrected the formatter name to "llama-2" for consistency. These changes enhance security by isolating the login process and improve code clarity by using the module name for HuggingFace operations. --- llama_cpp/llama_chat_format.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 036691aed..07bd8292f 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -16,7 +16,7 @@ def __init__(self): # Obtaining a registered formatter chat_formatter_factory = ChatFormatterFactory() -llama2_formatter = chat_formatter_factory.get_formatter_by_name("alpaca") +llama2_formatter = chat_formatter_factory.get_formatter_by_name("llama-2") # Formatting messages messages = [{"role": "user", "content": "Hello, World!"}] @@ -27,7 +27,7 @@ def __init__(self): import os from typing import Any, Dict, List, Optional, Protocol, Type, Union -from huggingface_hub import login +import huggingface_hub from transformers import AutoTokenizer from . import llama_types @@ -192,13 +192,7 @@ def get_tokenizer(cls, model_name: str) -> AutoTokenizer: class AutoTokenizerFormatter(ChatFormatterInterface): def __init__(self, template: Optional[Dict[str, str]] = None): self.template = template or huggingface_template - self.token = os.getenv("HF_TOKEN") - if self.token is None: - raise AttributeError( - "Failed to login to huggingface. " - "Did you forget to set the `HF_TOKEN` environment variable with your huggingface token?" - ) - login(self.token) + self.huggingface_login() self.tokenizer = TokenizerCache.get_tokenizer(self.template["model"]) def __call__( @@ -211,6 +205,15 @@ def __call__( prompt=formatted_content, stop=[self.tokenizer.eos_token] ) + def huggingface_login(self) -> None: + token = os.getenv("HF_TOKEN") + if token is None: + raise AttributeError( + "Failed to login to huggingface. " + "Did you forget to set the `HF_TOKEN` environment variable with your huggingface token?" + ) + huggingface_hub.login(token) + def format_messages( self, messages: List[llama_types.ChatCompletionRequestMessage] ) -> str: @@ -219,7 +222,7 @@ def format_messages( self.tokenizer.chat_template = self.template["jinja"] return self.tokenizer.apply_chat_template( - messages, tokenize=self.template["tokenize"] + messages, tokenize=self.template.get("tokenize", False) ) From 0ebfc1f1c0bdeffd7f46c85df1d4434e3f127995 Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Sat, 14 Oct 2023 00:04:30 -0400 Subject: [PATCH 20/22] Refactor: Update important notes, chat templates, and code in llama_chat_format.py - Updated the important notes section with clear and concise information about special tokens and Python version compatibility. - Anonymized the example templates by replacing names with "Llama" and "User" for clarity. - Made formatting changes to improve code readability and organization in llama_chat_format.py. - Added the Vicuna model template. This commit enhances the clarity of important notes, anonymizes example templates, and improves code formatting in llama_chat_format.py. --- llama_cpp/llama_chat_format.py | 122 +++++++++++++++++++++++++-------- 1 file changed, 93 insertions(+), 29 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 07bd8292f..75b15c965 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -7,6 +7,54 @@ NOTE: The system message is always assumed to be the first element in a sequence. +NOTE: Users should avoid tampering with special tokens to prevent model issues. + +--- + +# IMPORTANT NOTES: + +- The use of the merge operator (|) for dictionaries requires Python 3.9 or higher. Keep in mind that llama-cpp-python supports Python 3.8 and later versions. If you are working with an earlier Python version, consider alternatives such as `dict.update()` or creating a custom function to merge dictionaries. For Python 3.9 or higher, the merge operator simplifies dictionary merging. +Source: https://docs.python.org/3/library/stdtypes.html?highlight=dict#dict + +- Special tokens are crucial for the model's underlying operations, impacting pre-training, fine-tuning, and low-level inference processes. Users should avoid modifying special tokens to prevent issues in the model's output during inference. These issues may manifest as token fixation, repetitive language patterns, contextual derailment, and hallucinations. Improper use of separators and templates can exacerbate these problems. + +Example using the llama-2 model and its templating schema: + +# 1 <>My name is Llama and I am a helpful assistant.<>$ +# 2 [INST] Hello Llama, my name is User. What's your name? [/INST]$ +# 3 Hello User, my name is Llama. Nice to meet you!$ +# 4 [INST] What can you do? [/INST]$ +# 5 I can assist you with various tasks, including providing structured output for certain queries.$ +# 6 [INST] How can you assist me in my programming projects? [/INST]$ +# 7 $ + +This initial example is a proper template format that the model understands. It results in proper output and does not confuse the model. + +# 1 <>My name is Llama and I am a helpful assistant.<>$ +# 2 [INST] Hello Llama, my name is User. What's your name? [/INST]$ +# 3 Hello User, my name is Llama. Nice to meet you!$ +# 4 [INST] What can you do? [/INST]$ +# 5 I can assist you with various tasks, including providing structured output for certain queries.$ +# 6 [INST] How can you assist me in my programming projects? [/INST]$ +# 7 $ + +This example includes the use of special tokens, and the model may or may not use these tokens as a result. The model is not expecting them during inference, which causes unexpected behavior. + +# 1 <>My name is Llama and I am a helpful assistant.<>$ +# 2 $ +# 3 [INST] Hello Llama, my name is User. What's your name? [/INST]$ +# 4 Hello User, my name is Llama. Nice to meet you!$ +# 5 $ +# 6 [INST] What can you do? [/INST]$ +# 7 I can assist you with various tasks, including providing structured output for certain queries.$ +# 8 $ +# 9 [INST] How can you assist me in my programming projects? [/INST]$ +# 10 $ + +This example is improperly formatted and causes the model to become confused. The model begins to fixate on tokens, uses language repetition, and eventually derails. + +--- + # Usage example: # Registering a custom formatter @ChatFormatterFactory.register_predefined_model("llama-2") @@ -32,58 +80,54 @@ def __init__(self): from . import llama_types -# NOTE: The default templates are defined here for reusability. +# Default chat formatting templates for reusability. +# These templates can be reused or modified on a model-by-model basis. + +# Template for HuggingFace-based models. huggingface_template = { "model": "meta-llama/Llama-2-7b-chat-hf", "jinja": None, "tokenize": False, } +# Common formatting settings applicable to all roles in chat models. common_template = { "separators": { "after_system": "\n", "between_messages": "\n", "end_of_response": "", }, - "special_tokens": { - "bos_token": "", - "eos_token": "", - "unk_token": "", - }, "default_termination": { - "role": "assistant", - "message": None, + "role": "assistant", # Default role for termination + "message": None, # Default termination message (None for assistant) }, - "include_prompt": False, + "include_prompt": False, # Whether to include user prefix/postfix in prompts } -# Templates can be reused, modified, or overriden as needed on a model-by-model basis. -# This reduces noise in the code and ideally keeps the code base DRY. +# Template for Llama-2 model. llama2_template = { "roles": { "system": { - "prefix": "<>", - "postfix": "<>", - "format": None, # Optionally specify an custom format - }, - "assistant": { - "prefix": "", # No prefix for assistant role by default - "postfix": "", # No postfix for assistant role by default - "format": None, + "prefix": "<>", # System message prefix + "postfix": "<>", # System message postfix + "format": None, # Optionally specify a custom format }, "user": { "prefix": "[INST] ", - "postfix": " [/INST]", # Model starts generating from here + "postfix": " [/INST]", # Model generates from here "format": None, }, + "assistant": { + "prefix": "", # No prefix for assistant role by default + "postfix": "", # No postfix for assistant role by default + "format": None, # Custom format for assistant (if needed) + }, } } -# NOTE: The merge operator requires Python 3.9+ -# Other options are to use `dict.update()` or to create a custom function that merges them. -# Source: https://docs.python.org/3/library/stdtypes.html?highlight=dict#dict +# Merge common settings into the llama2_template to reduce code duplication. llama2_template |= common_template -# NOTE: If `include_prompt` is set to `True`, it will append the user prefix/postfix to the prompts output. +# Template for Alpaca model. alpaca_template = { "roles": { "system": { @@ -103,13 +147,37 @@ def __init__(self): }, "assistant": { "prefix": "### Response:\n", - "postfix": "", # Model starts generating from here + "postfix": "", # Model generates from here "format": None, }, } } +# Merge common settings into the alpaca_template to reduce code duplication. alpaca_template |= common_template +# Template for Vicuna model. +vicuna_template = { + "roles": { + "system": { + "prefix": "", + "postfix": "\n", + "format": None, + }, + "user": { + "prefix": "USER: ", + "postfix": "", + "format": None, + }, + "assistant": { + "prefix": "ASSISTANT: ", # Model generates from here + "postfix": "", + "format": None, + }, + } +} +# Merge common settings into the alpaca_template to reduce code duplication. +vicuna_template |= common_template + @dataclasses.dataclass class ChatFormatterResponse: @@ -170,10 +238,6 @@ def format_separator(self, separator_type) -> str: """Format separators based on the specified type.""" return self.template["separators"].get(separator_type, "") - def format_special_token(self, token_type) -> str: - """Format special tokens based on the specified type.""" - return self.template["special_tokens"].get(token_type, "") - def get_prompt(self) -> str: # Implement logic to generate a prompt, if needed return self.template["roles"]["user"]["prefix"] From 97a117f9419c51a5d09525dba0811bccf4608d9d Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Sat, 14 Oct 2023 01:08:48 -0400 Subject: [PATCH 21/22] Refactor: Update chat templates, code, and types in llama_chat_format.py - Added basic type definitions for better code clarity. - Removed repetitive comments in the code. - Added a note about the Vicuna template being version 1.5 and differing from v0. - Applied new type definitions to chat templates. - Introduced a new VicunaFormatter class to replace the older one, improving code readability. This commit enhances code clarity, maintains consistency, and improves the structure of the codebase. --- llama_cpp/llama_chat_format.py | 31 +++++++++++-------------------- llama_cpp/llama_types.py | 21 +++++++++++++++++++-- 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 75b15c965..7515a9536 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -91,7 +91,7 @@ def __init__(self): } # Common formatting settings applicable to all roles in chat models. -common_template = { +common_template: llama_types.CommonTemplate = { "separators": { "after_system": "\n", "between_messages": "\n", @@ -105,7 +105,7 @@ def __init__(self): } # Template for Llama-2 model. -llama2_template = { +llama2_template: llama_types.ChatMLTemplate = { "roles": { "system": { "prefix": "<>", # System message prefix @@ -128,7 +128,7 @@ def __init__(self): llama2_template |= common_template # Template for Alpaca model. -alpaca_template = { +alpaca_template: llama_types.ChatMLTemplate = { "roles": { "system": { "prefix": "", @@ -152,11 +152,12 @@ def __init__(self): }, } } -# Merge common settings into the alpaca_template to reduce code duplication. alpaca_template |= common_template # Template for Vicuna model. -vicuna_template = { +# NOTE: The v0 template differs from the v1.1, v1.3, and v1.5. +# This is the v1.5 Vicuna Template. +vicuna_template: llama_types.ChatMLTemplate = { "roles": { "system": { "prefix": "", @@ -175,7 +176,6 @@ def __init__(self): }, } } -# Merge common settings into the alpaca_template to reduce code duplication. vicuna_template |= common_template @@ -335,20 +335,11 @@ def __init__(self): super().__init__(alpaca_template) -@register_chat_format("vicuna") -def format( - messages: List[llama_types.ChatCompletionRequestMessage], - **kwargs: Any, -) -> ChatFormatterResponse: - _system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." - _roles = dict(user="USER", assistant="ASSISTANT") - _sep = " " - _sep2 = "" - system_message = _system_message - _messages = _map_roles(messages, _roles) - _messages.append((_roles["assistant"], None)) - _prompt = _format_add_colon_two(system_message, _messages, _sep, _sep2) - return ChatFormatterResponse(prompt=_prompt) +@ChatFormatterFactory.register_predefined_model("vicuna") +class VicunaFormatter(ChatFormatter): + def __init__(self): + # Define the Vicuna template + super().__init__(vicuna_template) @register_chat_format("oasst_llama") diff --git a/llama_cpp/llama_types.py b/llama_cpp/llama_types.py index 6ee7ef914..70c79c321 100644 --- a/llama_cpp/llama_types.py +++ b/llama_cpp/llama_types.py @@ -4,8 +4,9 @@ https://github.com/openai/openai-openapi/blob/master/openapi.yaml """ -from typing import Any, List, Optional, Dict, Union -from typing_extensions import TypedDict, NotRequired, Literal +from typing import Any, Dict, List, Optional, Union + +from typing_extensions import Literal, NotRequired, TypedDict class EmbeddingUsage(TypedDict): @@ -170,3 +171,19 @@ class ChatCompletionRequestMessage(TypedDict): content: Optional[str] name: NotRequired[str] funcion_call: NotRequired[ChatCompletionFunctionCall] + + +class RoleTemplate(TypedDict, total=False): + prefix: str + postfix: str + format: Optional[str] + + +class CommonTemplate(TypedDict): + separators: Dict[str, str] + default_termination: Dict[str, Optional[str]] + include_prompt: bool + + +class ChatMLTemplate(TypedDict): + roles: Dict[str, RoleTemplate] From a9901c49959b5028719c55700015fda252f7df42 Mon Sep 17 00:00:00 2001 From: teleprint-me <77757836+teleprint-me@users.noreply.github.com> Date: Sat, 14 Oct 2023 11:41:56 -0400 Subject: [PATCH 22/22] Refactor: Replace Open Assistant Hybrid chat template with classical template in llama_chat_format.py - Replaced the Open Assistant Hybrid chat template with the classical template. - Added the original Open Assistant chat template for non-hybrid models. This commit streamlines the chat templates by using the classical template for Open Assistant and adds the original template for non-hybrid models, reducing duplication. --- llama_cpp/llama_chat_format.py | 44 +++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 7515a9536..a3e850ef6 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -178,6 +178,30 @@ def __init__(self): } vicuna_template |= common_template +# NOTE: Open Assistant uses multiple custom prompts. +# The oasst-llama hybrids utilize ChatML templates. +# The base template is defined here for convenience. +oasst_template: llama_types.ChatMLTemplate = { + "roles": { + "system": { + "prefix": "<|system|>", + "postfix": "<|endoftext|>", + "format": None, + }, + "user": { + "prefix": "<|prompter|>", + "postfix": "<|endoftext|>", + "format": None, + }, + "assistant": { + "prefix": "<|assistant|>", # Model generates from here + "postfix": "<|endoftext|>", + "format": None, + }, + } +} +oasst_template |= common_template + @dataclasses.dataclass class ChatFormatterResponse: @@ -342,20 +366,12 @@ def __init__(self): super().__init__(vicuna_template) -@register_chat_format("oasst_llama") -def format_oasst_llama( - messages: List[llama_types.ChatCompletionRequestMessage], - **kwargs: Any, -) -> ChatFormatterResponse: - _system_template = "[INST] <>\n{system_message}\n<>\n\n" - _roles = dict(user="<|prompter|>", assistant="<|assistant|>") - _sep = "" - system_message = _get_system_message(messages) - system_message = _system_template.format(system_message=system_message) - _messages = _map_roles(messages, _roles) - _messages.append((_roles["assistant"], None)) - _prompt = _format_no_colon_single(system_message, _messages, _sep) - return ChatFormatterResponse(prompt=_prompt) +# NOTE: Refer to `oasst_template` note for more information. +@ChatFormatterFactory.register_predefined_model("oasst") +class OpenAssistantFormatter(ChatFormatter): + def __init__(self): + # Define the Open Assistant template + super().__init__(oasst_template) @register_chat_format("openbuddy")