From 0552a74995a3338aab7e68f011e48cee07c9b076 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 14 Sep 2023 03:50:56 -0400 Subject: [PATCH 1/4] Add configurable default chat completion format. --- llama_cpp/llama.py | 82 +++++++++++++++----------- llama_cpp/llama_chat_templates.py | 98 +++++++++++++++++++++++++++++++ 2 files changed, 146 insertions(+), 34 deletions(-) create mode 100644 llama_cpp/llama_chat_templates.py diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 5d093bef9..ebe75ca27 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -21,6 +21,8 @@ import diskcache import ctypes +from llama_cpp.llama_chat_templates import ChatCompletionFormat, DefaultChatCompletionFormat + from . import llama_cpp from .llama_types import * from .llama_grammar import LlamaGrammar @@ -30,6 +32,7 @@ from .utils import suppress_stdout_stderr + class BaseLlamaCache(ABC): """Base cache class for a llama.cpp model.""" @@ -237,8 +240,9 @@ def __init__( lora_base: Optional[str] = None, lora_path: Optional[str] = None, numa: bool = False, + chat_completion_template: Optional[ChatCompletionFormat] = None, verbose: bool = True, - **kwargs # type: ignore + **kwargs, # type: ignore ): """Load a llama.cpp model from `model_path`. @@ -290,7 +294,9 @@ def __init__( self.params = llama_cpp.llama_context_default_params() self.params.seed = seed self.params.n_ctx = n_ctx - self.params.n_gpu_layers = 0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers # 0x7FFFFFFF is INT32 max, will be auto set to all layers + self.params.n_gpu_layers = ( + 0x7FFFFFFF if n_gpu_layers == -1 else n_gpu_layers + ) # 0x7FFFFFFF is INT32 max, will be auto set to all layers self.params.main_gpu = main_gpu self.params.rope_freq_base = rope_freq_base self.params.rope_freq_scale = rope_freq_scale @@ -314,10 +320,11 @@ def __init__( ) # keep a reference to the array so it is not gc'd self.params.tensor_split = self._c_tensor_split - self.last_n_tokens_size = last_n_tokens_size self.n_batch = min(n_ctx, n_batch) + self.chat_completion_template = chat_completion_template or DefaultChatCompletionFormat() + self.cache: Optional[BaseLlamaCache] = None self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1) @@ -471,7 +478,9 @@ def detokenize(self, tokens: List[int]) -> bytes: output += bytes(buffer[:n]) # NOTE: Llama1 models automatically added a space at the start of the prompt # this line removes a leading space if the first token is a beginning of sentence token - return output[1:] if len(tokens) > 0 and tokens[0] == self.token_bos() else output + return ( + output[1:] if len(tokens) > 0 and tokens[0] == self.token_bos() else output + ) def set_cache(self, cache: Optional[BaseLlamaCache]): """Set the cache. 
@@ -543,11 +552,7 @@ def _sample( n_vocab = self._n_vocab n_ctx = self._n_ctx top_k = n_vocab if top_k <= 0 else top_k - last_n_tokens_size = ( - n_ctx - if last_n_tokens_size < 0 - else last_n_tokens_size - ) + last_n_tokens_size = n_ctx if last_n_tokens_size < 0 else last_n_tokens_size logits: npt.NDArray[np.single] = self._scores[-1, :] if logits_processor is not None: @@ -608,7 +613,7 @@ def _sample( mu=llama_cpp.ctypes.byref(mirostat_mu), # type: ignore m=mirostat_m, ) - elif mirostat_mode== 2: + elif mirostat_mode == 2: mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau) llama_cpp.llama_sample_temperature( ctx=self.ctx, @@ -898,7 +903,11 @@ def _create_completion( created: int = int(time.time()) completion_tokens: List[int] = [] # Add blank space to start of prompt to match OG llama tokenizer - prompt_tokens: List[int] = self.tokenize(prompt.encode("utf-8")) if prompt != "" else [self.token_bos()] + prompt_tokens: List[int] = ( + self.tokenize(prompt.encode("utf-8")) + if prompt != "" + else [self.token_bos()] + ) text: bytes = b"" returned_tokens: int = 0 stop = ( @@ -1023,7 +1032,9 @@ def _create_completion( for token in remaining_tokens: token_end_position += len(self.detokenize([token])) # Check if stop sequence is in the token - if token_end_position > (remaining_length - first_stop_position): + if token_end_position > ( + remaining_length - first_stop_position + ): break token_str = self.detokenize([token]).decode( "utf-8", errors="ignore" @@ -1080,7 +1091,7 @@ def _create_completion( for i in range(1, len(remaining_tokens) + 1): try: bs = self.detokenize(remaining_tokens[:i]) - ts = bs.decode('utf-8') + ts = bs.decode("utf-8") decode_success = True break except UnicodeError: @@ -1091,7 +1102,9 @@ def _create_completion( # all remaining tokens cannot be decoded to a UTF-8 character break token_end_position += len(bs) - if token_end_position > (remaining_length - first_stop_position): + if token_end_position > ( + remaining_length - first_stop_position + ): break remaining_tokens = remaining_tokens[i:] returned_tokens += i @@ -1396,7 +1409,7 @@ def create_completion( model=model, stopping_criteria=stopping_criteria, logits_processor=logits_processor, - grammar=grammar + grammar=grammar, ) if stream: chunks: Iterator[CompletionChunk] = completion_or_chunks @@ -1534,6 +1547,18 @@ def _convert_text_completion_chunks_to_chat( ], } + def _convert_completion_to_chat( + self, + completion_or_chunks: Union[Completion, Iterator[CompletionChunk]], + stream: bool = False, + ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: + if stream: + chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore + return self._convert_text_completion_chunks_to_chat(chunks) + else: + completion: Completion = completion_or_chunks # type: ignore + return self._convert_text_completion_to_chat(completion) + def create_chat_completion( self, messages: List[ChatCompletionMessage], @@ -1571,26 +1596,20 @@ def create_chat_completion( Returns: Generated chat completion or a stream of chat completion chunks. 
""" - stop = ( - stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else [] - ) - chat_history = "".join( - f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}' - for message in messages - ) - PROMPT = chat_history + "### Assistant:" - PROMPT_STOP = ["### Assistant:", "### Human:"] - completion_or_chunks = self( - prompt=PROMPT, - stop=PROMPT_STOP + stop, + completion_or_chunks = self.chat_completion_template.create_chat_completion( + self, + messages=messages, + functions=functions, + function_call=function_call, temperature=temperature, top_p=top_p, top_k=top_k, stream=stream, + stop=stop, max_tokens=max_tokens, - repeat_penalty=repeat_penalty, presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, + repeat_penalty=repeat_penalty, tfs_z=tfs_z, mirostat_mode=mirostat_mode, mirostat_tau=mirostat_tau, @@ -1599,12 +1618,7 @@ def create_chat_completion( logits_processor=logits_processor, grammar=grammar, ) - if stream: - chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore - return self._convert_text_completion_chunks_to_chat(chunks) - else: - completion: Completion = completion_or_chunks # type: ignore - return self._convert_text_completion_to_chat(completion) + return self._convert_completion_to_chat(completion_or_chunks, stream=stream) # type: ignore def __del__(self): if hasattr(self, "model") and self.model is not None: diff --git a/llama_cpp/llama_chat_templates.py b/llama_cpp/llama_chat_templates.py new file mode 100644 index 000000000..e01a6151f --- /dev/null +++ b/llama_cpp/llama_chat_templates.py @@ -0,0 +1,98 @@ +from abc import ABC, abstractmethod +from typing import Iterator, List, Optional, Union +from llama_cpp.llama import Llama, LogitsProcessorList +from llama_cpp.llama_grammar import LlamaGrammar + +from llama_cpp.llama_types import ( + ChatCompletionFunction, + ChatCompletionFunctionCall, + ChatCompletionMessage, + Completion, + CompletionChunk, +) + + +class ChatCompletionFormat(ABC): + """Base class for chat completion templates.""" + + @abstractmethod + def create_chat_completion( + self, + llama: Llama, + messages: List[ChatCompletionMessage], + functions: Optional[List[ChatCompletionFunction]] = None, + function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None, + temperature: float = 0.2, + top_p: float = 0.95, + top_k: int = 40, + stream: bool = False, + stop: Optional[Union[str, List[str]]] = [], + max_tokens: int = 256, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + repeat_penalty: float = 1.1, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, + logits_processor: Optional[LogitsProcessorList] = None, + grammar: Optional[LlamaGrammar] = None, + ) -> Union[Completion, Iterator[CompletionChunk]]: + raise NotImplementedError + + +class DefaultChatCompletionFormat(ABC): + """Base class for chat completion templates.""" + + def create_chat_completion( + self, + llama: Llama, + messages: List[ChatCompletionMessage], + functions: Optional[List[ChatCompletionFunction]] = None, + function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None, + temperature: float = 0.2, + top_p: float = 0.95, + top_k: int = 40, + stream: bool = False, + stop: Optional[Union[str, List[str]]] = [], + max_tokens: int = 256, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + repeat_penalty: float = 1.1, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: 
float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, + logits_processor: Optional[LogitsProcessorList] = None, + grammar: Optional[LlamaGrammar] = None, + ) -> Union[Completion, Iterator[CompletionChunk]]: + stop = ( + stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else [] + ) + chat_history = "".join( + f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}' + for message in messages + ) + PROMPT = chat_history + "### Assistant:" + PROMPT_STOP = ["### Assistant:", "### Human:"] + return llama.create_completion( + prompt=PROMPT, + stop=PROMPT_STOP + stop, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stream=stream, + max_tokens=max_tokens, + repeat_penalty=repeat_penalty, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + tfs_z=tfs_z, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + model=model, + logits_processor=logits_processor, + grammar=grammar, + ) From 179f45c0be20b156f75e4dd8d854bffb6d4f656a Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 14 Sep 2023 04:04:47 -0400 Subject: [PATCH 2/4] Remove chat_template file to avoid circular import --- llama_cpp/llama.py | 94 +++++++++++++++++++++++++++-- llama_cpp/llama_chat_templates.py | 98 ------------------------------- 2 files changed, 90 insertions(+), 102 deletions(-) delete mode 100644 llama_cpp/llama_chat_templates.py diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index ebe75ca27..03c262fc8 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -21,8 +21,6 @@ import diskcache import ctypes -from llama_cpp.llama_chat_templates import ChatCompletionFormat, DefaultChatCompletionFormat - from . import llama_cpp from .llama_types import * from .llama_grammar import LlamaGrammar @@ -240,7 +238,7 @@ def __init__( lora_base: Optional[str] = None, lora_path: Optional[str] = None, numa: bool = False, - chat_completion_template: Optional[ChatCompletionFormat] = None, + chat_completion_template: Optional["ChatCompletionFormat"] = None, verbose: bool = True, **kwargs, # type: ignore ): @@ -323,7 +321,9 @@ def __init__( self.last_n_tokens_size = last_n_tokens_size self.n_batch = min(n_ctx, n_batch) - self.chat_completion_template = chat_completion_template or DefaultChatCompletionFormat() + self.chat_completion_template = ( + chat_completion_template or DefaultChatCompletionFormat() + ) self.cache: Optional[BaseLlamaCache] = None @@ -1783,3 +1783,89 @@ def decode(self, tokens: List[int]) -> str: @classmethod def from_ggml_file(cls, path: str) -> "LlamaTokenizer": return cls(Llama(model_path=path, vocab_only=True)) + + +class ChatCompletionFormat(ABC): + """Base class for chat completion templates.""" + + @abstractmethod + def create_chat_completion( + self, + llama: Llama, + messages: List[ChatCompletionMessage], + functions: Optional[List[ChatCompletionFunction]] = None, + function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None, + temperature: float = 0.2, + top_p: float = 0.95, + top_k: int = 40, + stream: bool = False, + stop: Optional[Union[str, List[str]]] = [], + max_tokens: int = 256, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + repeat_penalty: float = 1.1, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, + logits_processor: Optional[LogitsProcessorList] = None, + grammar: Optional[LlamaGrammar] = None, + ) -> Union[Completion, 
Iterator[CompletionChunk]]: + raise NotImplementedError + + +class DefaultChatCompletionFormat(ABC): + """Base class for chat completion templates.""" + + def create_chat_completion( + self, + llama: Llama, + messages: List[ChatCompletionMessage], + functions: Optional[List[ChatCompletionFunction]] = None, + function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None, + temperature: float = 0.2, + top_p: float = 0.95, + top_k: int = 40, + stream: bool = False, + stop: Optional[Union[str, List[str]]] = [], + max_tokens: int = 256, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + repeat_penalty: float = 1.1, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, + logits_processor: Optional[LogitsProcessorList] = None, + grammar: Optional[LlamaGrammar] = None, + ) -> Union[Completion, Iterator[CompletionChunk]]: + stop = ( + stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else [] + ) + chat_history = "".join( + f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}' + for message in messages + ) + PROMPT = chat_history + "### Assistant:" + PROMPT_STOP = ["### Assistant:", "### Human:"] + return llama.create_completion( + prompt=PROMPT, + stop=PROMPT_STOP + stop, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stream=stream, + max_tokens=max_tokens, + repeat_penalty=repeat_penalty, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + tfs_z=tfs_z, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + model=model, + logits_processor=logits_processor, + grammar=grammar, + ) diff --git a/llama_cpp/llama_chat_templates.py b/llama_cpp/llama_chat_templates.py deleted file mode 100644 index e01a6151f..000000000 --- a/llama_cpp/llama_chat_templates.py +++ /dev/null @@ -1,98 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Iterator, List, Optional, Union -from llama_cpp.llama import Llama, LogitsProcessorList -from llama_cpp.llama_grammar import LlamaGrammar - -from llama_cpp.llama_types import ( - ChatCompletionFunction, - ChatCompletionFunctionCall, - ChatCompletionMessage, - Completion, - CompletionChunk, -) - - -class ChatCompletionFormat(ABC): - """Base class for chat completion templates.""" - - @abstractmethod - def create_chat_completion( - self, - llama: Llama, - messages: List[ChatCompletionMessage], - functions: Optional[List[ChatCompletionFunction]] = None, - function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None, - temperature: float = 0.2, - top_p: float = 0.95, - top_k: int = 40, - stream: bool = False, - stop: Optional[Union[str, List[str]]] = [], - max_tokens: int = 256, - presence_penalty: float = 0.0, - frequency_penalty: float = 0.0, - repeat_penalty: float = 1.1, - tfs_z: float = 1.0, - mirostat_mode: int = 0, - mirostat_tau: float = 5.0, - mirostat_eta: float = 0.1, - model: Optional[str] = None, - logits_processor: Optional[LogitsProcessorList] = None, - grammar: Optional[LlamaGrammar] = None, - ) -> Union[Completion, Iterator[CompletionChunk]]: - raise NotImplementedError - - -class DefaultChatCompletionFormat(ABC): - """Base class for chat completion templates.""" - - def create_chat_completion( - self, - llama: Llama, - messages: List[ChatCompletionMessage], - functions: Optional[List[ChatCompletionFunction]] = None, - function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None, - temperature: float = 
0.2, - top_p: float = 0.95, - top_k: int = 40, - stream: bool = False, - stop: Optional[Union[str, List[str]]] = [], - max_tokens: int = 256, - presence_penalty: float = 0.0, - frequency_penalty: float = 0.0, - repeat_penalty: float = 1.1, - tfs_z: float = 1.0, - mirostat_mode: int = 0, - mirostat_tau: float = 5.0, - mirostat_eta: float = 0.1, - model: Optional[str] = None, - logits_processor: Optional[LogitsProcessorList] = None, - grammar: Optional[LlamaGrammar] = None, - ) -> Union[Completion, Iterator[CompletionChunk]]: - stop = ( - stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else [] - ) - chat_history = "".join( - f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}' - for message in messages - ) - PROMPT = chat_history + "### Assistant:" - PROMPT_STOP = ["### Assistant:", "### Human:"] - return llama.create_completion( - prompt=PROMPT, - stop=PROMPT_STOP + stop, - temperature=temperature, - top_p=top_p, - top_k=top_k, - stream=stream, - max_tokens=max_tokens, - repeat_penalty=repeat_penalty, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - tfs_z=tfs_z, - mirostat_mode=mirostat_mode, - mirostat_tau=mirostat_tau, - mirostat_eta=mirostat_eta, - model=model, - logits_processor=logits_processor, - grammar=grammar, - ) From 428b64ec50128301a922bc5545994017e9ca05e8 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 20 Sep 2023 03:30:04 -0400 Subject: [PATCH 3/4] Update llama_types --- llama_cpp/llama_types.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama_types.py b/llama_cpp/llama_types.py index 21a0d4f1b..63e8af9b6 100644 --- a/llama_cpp/llama_types.py +++ b/llama_cpp/llama_types.py @@ -58,9 +58,11 @@ class Completion(TypedDict): class ChatCompletionMessage(TypedDict): - role: Literal["assistant", "user", "system"] - content: str + role: Literal["assistant", "user", "system", "function"] + content: Optional[str] user: NotRequired[str] + name: NotRequired[str] + function_call: NotRequired[str] class ChatCompletionFunction(TypedDict): @@ -71,6 +73,7 @@ class ChatCompletionFunction(TypedDict): class ChatCompletionFunctionCall(TypedDict): name: str + arguments: str class ChatCompletionChoice(TypedDict): From ae47e4fa5ff2f61ba3ff8dcf68daad7061da2e74 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 29 Sep 2023 19:45:18 -0400 Subject: [PATCH 4/4] Add chat format --- llama_cpp/llama.py | 114 +++---------- llama_cpp/llama_chat_format.py | 292 +++++++++++++++++++++++++++++++++ 2 files changed, 315 insertions(+), 91 deletions(-) create mode 100644 llama_cpp/llama_chat_format.py diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 0859fec8f..7c75b7f27 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -24,6 +24,7 @@ from . import llama_cpp from .llama_types import * from .llama_grammar import LlamaGrammar +from . import llama_chat_format import numpy as np import numpy.typing as npt @@ -243,6 +244,8 @@ def __init__( lora_path: Optional[str] = None, # Backend Params numa: bool = False, + # Chat Format Params + chat_format: str = "llama-2", # Misc verbose: bool = True, # Extra Params @@ -273,6 +276,7 @@ def __init__( lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model. lora_path: Path to a LoRA file to apply to the model. numa: Enable NUMA support. 
(NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init) + chat_format: String specifying the chat format to use when calling create_chat_completion. verbose: Print verbose output to stderr. kwargs: Unused keyword arguments (for additional backwards compatibility). @@ -387,6 +391,8 @@ def __init__( if self.verbose: print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr) + + self.chat_format = chat_format self._n_vocab = self.n_vocab() self._n_ctx = self.n_ctx() @@ -1578,7 +1584,7 @@ def _convert_completion_to_chat( def create_chat_completion( self, - messages: List[ChatCompletionMessage], + messages: List[ChatCompletionRequestMessage], functions: Optional[List[ChatCompletionFunction]] = None, function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None, temperature: float = 0.2, @@ -1613,11 +1619,19 @@ def create_chat_completion( Returns: Generated chat completion or a stream of chat completion chunks. """ - completion_or_chunks = self.chat_completion_template.create_chat_completion( - self, + + format = llama_chat_format.get_chat_format(self.chat_format) + result = format( messages=messages, - functions=functions, - function_call=function_call, + ) + prompt = result.prompt + if result.stop is not None: + stop = [] if stop is None else [stop] if isinstance(stop, str) else stop + rstop = result.stop if isinstance(result.stop, list) else [result.stop] + stop = stop + rstop + + completion_or_chunks = self.create_completion( + prompt=prompt, temperature=temperature, top_p=top_p, top_k=top_k, @@ -1675,6 +1689,8 @@ def __getstate__(self): lora_path=self.lora_path, # Backend Params numa=self.numa, + # Chat Format Params + chat_format=self.chat_format, # Misc verbose=self.verbose, ) @@ -1708,6 +1724,8 @@ def __setstate__(self, state): lora_path=state["lora_path"], # Backend Params numa=state["numa"], + # Chat Format Params + chat_format=state["chat_format"], # Misc verbose=state["verbose"], ) @@ -1821,89 +1839,3 @@ def decode(self, tokens: List[int]) -> str: @classmethod def from_ggml_file(cls, path: str) -> "LlamaTokenizer": return cls(Llama(model_path=path, vocab_only=True)) - - -class ChatCompletionFormat(ABC): - """Base class for chat completion templates.""" - - @abstractmethod - def create_chat_completion( - self, - llama: Llama, - messages: List[ChatCompletionMessage], - functions: Optional[List[ChatCompletionFunction]] = None, - function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None, - temperature: float = 0.2, - top_p: float = 0.95, - top_k: int = 40, - stream: bool = False, - stop: Optional[Union[str, List[str]]] = [], - max_tokens: int = 256, - presence_penalty: float = 0.0, - frequency_penalty: float = 0.0, - repeat_penalty: float = 1.1, - tfs_z: float = 1.0, - mirostat_mode: int = 0, - mirostat_tau: float = 5.0, - mirostat_eta: float = 0.1, - model: Optional[str] = None, - logits_processor: Optional[LogitsProcessorList] = None, - grammar: Optional[LlamaGrammar] = None, - ) -> Union[Completion, Iterator[CompletionChunk]]: - raise NotImplementedError - - -class DefaultChatCompletionFormat(ABC): - """Base class for chat completion templates.""" - - def create_chat_completion( - self, - llama: Llama, - messages: List[ChatCompletionMessage], - functions: Optional[List[ChatCompletionFunction]] = None, - function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None, - temperature: float = 0.2, - top_p: float = 0.95, - top_k: int = 40, - stream: bool = False, - stop: 
Optional[Union[str, List[str]]] = [], - max_tokens: int = 256, - presence_penalty: float = 0.0, - frequency_penalty: float = 0.0, - repeat_penalty: float = 1.1, - tfs_z: float = 1.0, - mirostat_mode: int = 0, - mirostat_tau: float = 5.0, - mirostat_eta: float = 0.1, - model: Optional[str] = None, - logits_processor: Optional[LogitsProcessorList] = None, - grammar: Optional[LlamaGrammar] = None, - ) -> Union[Completion, Iterator[CompletionChunk]]: - stop = ( - stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else [] - ) - chat_history = "".join( - f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}' - for message in messages - ) - PROMPT = chat_history + "### Assistant:" - PROMPT_STOP = ["### Assistant:", "### Human:"] - return llama.create_completion( - prompt=PROMPT, - stop=PROMPT_STOP + stop, - temperature=temperature, - top_p=top_p, - top_k=top_k, - stream=stream, - max_tokens=max_tokens, - repeat_penalty=repeat_penalty, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - tfs_z=tfs_z, - mirostat_mode=mirostat_mode, - mirostat_tau=mirostat_tau, - mirostat_eta=mirostat_eta, - model=model, - logits_processor=logits_processor, - grammar=grammar, - ) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py new file mode 100644 index 000000000..bd8110f2c --- /dev/null +++ b/llama_cpp/llama_chat_format.py @@ -0,0 +1,292 @@ +import dataclasses +from typing import Any, Dict, List, Optional, Tuple, Union, Protocol +from . import llama_types + + +def _get_system_message( + messages: List[llama_types.ChatCompletionRequestMessage], +) -> str: + """Get the first system message.""" + for message in messages: + if message["role"] == "system": + return message["content"] or "" + return "" + + +def _map_roles( + messages: List[llama_types.ChatCompletionRequestMessage], role_map: Dict[str, str] +) -> List[Tuple[str, Optional[str]]]: + """Map the message roles.""" + output: List[Tuple[str, Optional[str]]] = [] + for message in messages: + role = message["role"] + if role in role_map: + output.append((role_map[role], message["content"])) + return output + + +def _format_llama2( + system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str +) -> str: + """Format the prompt with the llama2 style.""" + ret = system_message + sep + for role, message in messages: + if message: + ret += message + " " + else: + ret += role + " " + return ret + + +def _format_add_colon_single( + system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str +) -> str: + """Format the prompt with the add-colon-single style.""" + ret = system_message + sep + for role, message in messages: + if message: + ret += role + ": " + message + sep + else: + ret += role + ":" + return ret + + +def _format_add_colon_two( + system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str, sep2: str +) -> str: + """Format the prompt with the add-colon-two style.""" + seps = [sep, sep2] + ret = system_message + seps[0] + for i, (role, message) in enumerate(messages): + if message: + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + + +def _format_no_colon_single( + system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str +) -> str: + """Format the prompt with the no-colon-single style.""" + ret = system_message + for role, message in messages: + if message: + ret += role + message + sep + else: + ret += role + return ret + + +def _format_add_colon_space_single( + 
system_message: str, messages: List[Tuple[str, Optional[str]]], sep: str +) -> str: + """Format the prompt with the add-colon-space-single style.""" + ret = system_message + sep + for role, message in messages: + if message: + ret += role + ": " + message + sep + else: + ret += role + ": " # must be end with a space + return ret + + +@dataclasses.dataclass +class ChatFormatterResponse: + prompt: str + stop: Optional[Union[str, List[str]]] = None + + +class ChatFormatter(Protocol): + def __call__( + self, + messages: List[llama_types.ChatCompletionRequestMessage], + **kwargs: Any, + ) -> ChatFormatterResponse: + ... + + +_CHAT_FORMATS: Dict[str, ChatFormatter] = {} + + +def register_chat_format(name: str): + def decorator(f: ChatFormatter): + _CHAT_FORMATS[name] = f + return f + + return decorator + + +def get_chat_format(name: str): + try: + return _CHAT_FORMATS[name] + except KeyError: + raise ValueError( + f"Invalid chat format: {name} (valid formats: {list(_CHAT_FORMATS.keys())})" + ) + + +@register_chat_format("llama-2") +def format_llama2( + messages: List[llama_types.ChatCompletionRequestMessage], + **kwargs: Any, +) -> ChatFormatterResponse: + _system_template = "[INST] <>\n{system_message}\n<>\n\n" + _roles = dict(user="[INST]", assistant="[/INST]") + _sep = "\n\n" + system_message = _get_system_message(messages) + system_message = _system_template.format(system_message=system_message) + _messages = _map_roles(messages, _roles) + _messages.append((_roles["assistant"], None)) + _prompt = _format_llama2(system_message, _messages, _sep) + return ChatFormatterResponse(prompt=_prompt) + + +@register_chat_format("alpaca") +def format_alpaca( + messages: List[llama_types.ChatCompletionRequestMessage], + **kwargs: Any, +) -> ChatFormatterResponse: + _roles = dict(user="### Instruction", assistant="### Response") + _sep = "\n\n" + _sep2 = "" + system_message = _get_system_message(messages) + _messages = _map_roles(messages, _roles) + _prompt = _format_add_colon_two(system_message, _messages, _sep, _sep2) + return ChatFormatterResponse(prompt=_prompt) + + +@register_chat_format("vicuna") +def format( + messages: List[llama_types.ChatCompletionRequestMessage], + **kwargs: Any, +) -> ChatFormatterResponse: + _system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." 
+ _roles = dict(user="USER", assistant="ASSISTANT") + _sep = " " + _sep2 = "" + system_message = _system_message + _messages = _map_roles(messages, _roles) + _messages.append((_roles["assistant"], None)) + _prompt = _format_add_colon_two(system_message, _messages, _sep, _sep2) + return ChatFormatterResponse(prompt=_prompt) + + +@register_chat_format("oasst_llama") +def format_oasst_llama( + messages: List[llama_types.ChatCompletionRequestMessage], + **kwargs: Any, +) -> ChatFormatterResponse: + _system_template = "[INST] <>\n{system_message}\n<>\n\n" + _roles = dict(user="<|prompter|>", assistant="<|assistant|>") + _sep = "" + system_message = _get_system_message(messages) + system_message = _system_template.format(system_message=system_message) + _messages = _map_roles(messages, _roles) + _messages.append((_roles["assistant"], None)) + _prompt = _format_no_colon_single(system_message, _messages, _sep) + return ChatFormatterResponse(prompt=_prompt) + + +@register_chat_format("openbuddy") +def format_openbuddy( + messages: List[llama_types.ChatCompletionRequestMessage], + **kwargs: Any, +) -> ChatFormatterResponse: + _system_message = """Consider a conversation between User (a human) and Assistant (named Buddy). +Buddy is an INTP-T, a friendly, intelligent and multilingual AI assistant, by OpenBuddy team. GitHub: https://github.com/OpenBuddy/OpenBuddy +Buddy cannot access the Internet. +Buddy can fluently speak the user's language (e.g. English, Chinese). +Buddy can generate poems, stories, code, essays, songs, parodies, and more. +Buddy possesses vast knowledge about the world, history, and culture. +Buddy's responses are always safe, creative, high-quality, human-like, and interesting. +Buddy strictly refuses to discuss political, NSFW, or other unsafe topics. + +User: Hi. +Assistant: Hi, I'm Buddy, your AI assistant. How can I help you today?""" + _roles = dict(user="User", assistant="Assistant") + _sep = "\n" + system_message = _system_message + _messages = _map_roles(messages, _roles) + _messages.append((_roles["assistant"], None)) + _prompt = _format_add_colon_single(system_message, _messages, _sep) + return ChatFormatterResponse(prompt=_prompt) + + +@register_chat_format("redpajama-incite") +def format_redpajama_incite( + messages: List[llama_types.ChatCompletionRequestMessage], + **kwargs: Any, +) -> ChatFormatterResponse: + _system_message = _get_system_message(messages) + _roles = dict(user="", assistant="") + _sep = "\n" + _stop = "" + system_message = _system_message + _messages = _map_roles(messages, _roles) + _messages.append((_roles["assistant"], None)) + _prompt = _format_add_colon_single(system_message, _messages, _sep) + return ChatFormatterResponse(prompt=_prompt, stop=_stop) + + +@register_chat_format("snoozy") +def format_snoozy( + messages: List[llama_types.ChatCompletionRequestMessage], + **kwargs: Any, +) -> ChatFormatterResponse: + system_template = "### Instruction:\n{system_message}" + default_system_message = "The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response." 
+ _system_message = _get_system_message(messages) + _system_message = ( + _system_message if _system_message != "" else default_system_message + ) + system_message = system_template.format(system_message=_system_message) + _roles = dict(user="### Prompt", assistant="### Response") + _sep = "\n" + _stop = "###" + system_message = _system_message + _messages = _map_roles(messages, _roles) + _messages.append((_roles["assistant"], None)) + _prompt = _format_add_colon_single(system_message, _messages, _sep) + return ChatFormatterResponse(prompt=_prompt, stop=_stop) + + +@register_chat_format("phind") +def format_phind( + messages: List[llama_types.ChatCompletionRequestMessage], + **kwargs: Any, +) -> ChatFormatterResponse: + _roles = dict(user="### User Message", assistant="### Assistant") + _sep = "\n\n" + _system_message = "### System Prompt\nYou are an intelligent programming assistant." + _messages = _map_roles(messages, _roles) + _messages.append((_roles["assistant"], None)) + _prompt = _format_add_colon_single(_system_message, _messages, _sep) + return ChatFormatterResponse(prompt=_prompt) + + +@register_chat_format("open-orca") +def format_open_orca( + messages: List[llama_types.ChatCompletionRequestMessage], + **kwargs: Any, +) -> ChatFormatterResponse: + system_template = "{system_message}" + system_message = ( + "You are a helpful assistant. Please answer truthfully and write out your " + ) + "thinking step by step to be sure you get the right answer. If you make a mistake or encounter " + "an error in your thinking, say so out loud and attempt to correct it. If you don't know or " + "aren't sure about something, say so clearly. You will act as a professional logician, mathematician, " + "and physicist. You will also act as the most appropriate type of expert to answer any particular " + "question or solve the relevant problem; state which expert type your are, if so. Also think of " + "any particular named expert that would be ideal to answer the relevant question or solve the " + "relevant problem; name and act as them, if appropriate." + roles = ("User", "Assistant") + sep = "<|end_of_turn|>\n" + # stop_token_ids=[32000, 32001], # "<|end_of_turn|>" + stop_str = "User" + system_message = system_template.format(system_message=system_message) + _messages = _map_roles(messages, dict(zip(roles, roles))) + _messages.append((roles[1], None)) + _prompt = _format_add_colon_space_single(system_message, _messages, sep) + return ChatFormatterResponse(prompt=_prompt, stop=stop_str)
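
Reviewer note: below is a minimal usage sketch of the API as it stands after PATCH 4/4, showing a custom formatter registered through register_chat_format and selected via the new chat_format constructor argument. The format name "chatml-ish", the helper format_chatml_ish, the model path, and the example messages are illustrative assumptions; only register_chat_format, ChatFormatterResponse, get_chat_format resolution inside create_chat_completion, and the chat_format parameter come from the patches themselves.

    from llama_cpp import Llama
    from llama_cpp.llama_chat_format import ChatFormatterResponse, register_chat_format


    @register_chat_format("chatml-ish")  # hypothetical name, not one of the formats shipped in this patch
    def format_chatml_ish(messages, **kwargs) -> ChatFormatterResponse:
        # Render each message as "<|role|>\ncontent\n" and leave the assistant turn open,
        # mirroring how the built-in formatters build a prompt string plus a stop sequence.
        prompt = ""
        for message in messages:
            prompt += f"<|{message['role']}|>\n{message['content'] or ''}\n"
        prompt += "<|assistant|>\n"
        return ChatFormatterResponse(prompt=prompt, stop="<|user|>")


    # Built-in formats are selected by name; "llama-2" is the default in Llama.__init__.
    llm = Llama(
        model_path="./models/llama-2-7b-chat.Q4_K_M.gguf",  # hypothetical path
        chat_format="chatml-ish",
    )
    response = llm.create_chat_completion(
        messages=[
            {"role": "system", "content": "You are a concise assistant."},
            {"role": "user", "content": "Name three uses for a chat template."},
        ],
        max_tokens=128,
    )
    print(response["choices"][0]["message"]["content"])

Because get_chat_format is only resolved inside create_chat_completion, a custom format registered at import time (as above) is picked up without touching the Llama constructor logic; the ChatFormatterResponse.stop value is merged with any caller-supplied stop sequences before the underlying create_completion call.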