import diskcache
import ctypes

-from . import llama_cpp
from .llama_types import *
from .llama_grammar import LlamaGrammar
+import llama_cpp.llama_cpp as llama_cpp
import llama_cpp.llama_chat_format as llama_chat_format

import numpy as np
@@ -752,6 +752,7 @@ def __init__(
        numa: bool = False,
        # Chat Format Params
        chat_format: str = "llama-2",
+        chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
        # Misc
        verbose: bool = True,
        # Extra Params
@@ -784,6 +785,7 @@ def __init__(
            lora_path: Path to a LoRA file to apply to the model.
            numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init)
            chat_format: String specifying the chat format to use when calling create_chat_completion.
+            chat_handler: Optional chat handler to use when calling create_chat_completion.
            verbose: Print verbose output to stderr.

        Raises:
@@ -910,6 +912,7 @@ def __init__(
            print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)

        self.chat_format = chat_format
+        self.chat_handler = chat_handler

        self._n_vocab = self.n_vocab()
        self._n_ctx = self.n_ctx()
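The new chat_handler constructor argument lets callers bypass the format-based lookup. A minimal sketch of how it might be used (the model path and handler name are hypothetical; the only call contract shown in this commit is the keyword-argument invocation inside create_chat_completion, so any callable accepting those keywords should satisfy the LlamaChatCompletionHandler type):

import sys

from llama_cpp import Llama
import llama_cpp.llama_chat_format as llama_chat_format


def logging_handler(llama, messages, **kwargs):
    # Hypothetical handler: log the request, then delegate to the handler
    # registered for the "llama-2" chat format.
    print(f"chat request with {len(messages)} messages", file=sys.stderr)
    base = llama_chat_format.get_chat_completion_handler("llama-2")
    return base(llama=llama, messages=messages, **kwargs)


llm = Llama(model_path="./model.gguf", chat_handler=logging_handler)  # hypothetical path

When chat_handler is left as None, create_chat_completion falls back to the handler registered for chat_format, as before.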
@@ -1231,7 +1234,7 @@ def create_embedding(
        else:
            inputs = input

-        data: List[EmbeddingData] = []
+        data: List[Embedding] = []
        total_tokens = 0
        for index, input in enumerate(inputs):
            tokens = self.tokenize(input.encode("utf-8"), special=True)
@@ -1276,7 +1279,7 @@ def embed(self, input: str) -> List[float]:

    def _create_completion(
        self,
-        prompt: str,
+        prompt: Union[str, List[int]],
        suffix: Optional[str] = None,
        max_tokens: int = 16,
        temperature: float = 0.8,
@@ -1297,7 +1300,9 @@ def _create_completion(
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        grammar: Optional[LlamaGrammar] = None,
-    ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
+    ) -> Union[
+        Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse]
+    ]:
        assert self._ctx is not None
        assert suffix is None or suffix.__class__ is str

@@ -1309,7 +1314,7 @@ def _create_completion(
            self.tokenize(prompt.encode("utf-8"), special=True)
            if prompt != ""
            else [self.token_bos()]
-        )
+        ) if isinstance(prompt, str) else prompt
        text: bytes = b""
        returned_tokens: int = 0
        stop = (
@@ -1322,7 +1327,7 @@ def _create_completion(

        if len(prompt_tokens) >= self._n_ctx:
            raise ValueError(
-                f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self._ctx)}"
+                f"Requested tokens ({len(prompt_tokens)}) exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
            )

        if max_tokens <= 0:
@@ -1732,7 +1737,7 @@ def _create_completion(

    def create_completion(
        self,
-        prompt: str,
+        prompt: Union[str, List[int]],
        suffix: Optional[str] = None,
        max_tokens: int = 128,
        temperature: float = 0.8,
@@ -1753,7 +1758,7 @@ def create_completion(
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        grammar: Optional[LlamaGrammar] = None,
-    ) -> Union[Completion, Iterator[CompletionChunk]]:
+    ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
        """Generate text from a prompt.

        Args:
@@ -1800,7 +1805,7 @@ def create_completion(
            grammar=grammar,
        )
        if stream:
-            chunks: Iterator[CompletionChunk] = completion_or_chunks
+            chunks: Iterator[CreateCompletionStreamResponse] = completion_or_chunks
            return chunks
        completion: Completion = next(completion_or_chunks)  # type: ignore
        return completion
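With prompt widened to Union[str, List[int]], create_completion (and _create_completion) can now take a pre-tokenized prompt directly. A small sketch, assuming a local GGUF model at a hypothetical path:

from llama_cpp import Llama

llm = Llama(model_path="./model.gguf")  # hypothetical path

# Token ids are passed straight through instead of being tokenized again.
token_prompt = llm.tokenize(b"Q: Name the planets in the solar system. A: ", special=True)
output = llm.create_completion(prompt=token_prompt, max_tokens=32, stop=["Q:"])
print(output["choices"][0]["text"])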
@@ -1828,7 +1833,7 @@ def __call__(
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        grammar: Optional[LlamaGrammar] = None,
-    ) -> Union[Completion, Iterator[CompletionChunk]]:
+    ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]:
        """Generate text from a prompt.

        Args:
@@ -1879,7 +1884,9 @@ def create_chat_completion(
        self,
        messages: List[ChatCompletionRequestMessage],
        functions: Optional[List[ChatCompletionFunction]] = None,
-        function_call: Optional[Union[str, ChatCompletionFunctionCall]] = None,
+        function_call: Optional[ChatCompletionRequestFunctionCall] = None,
+        tools: Optional[List[ChatCompletionTool]] = None,
+        tool_choice: Optional[ChatCompletionToolChoiceOption] = None,
        temperature: float = 0.2,
        top_p: float = 0.95,
        top_k: int = 40,
@@ -1896,7 +1903,9 @@ def create_chat_completion(
        model: Optional[str] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        grammar: Optional[LlamaGrammar] = None,
-    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+    ) -> Union[
+        CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
+    ]:
        """Generate a chat completion from a list of messages.

        Args:
@@ -1912,12 +1921,16 @@ def create_chat_completion(
        Returns:
            Generated chat completion or a stream of chat completion chunks.
        """
-        handler = llama_chat_format.get_chat_completion_handler(self.chat_format)
+        handler = self.chat_handler or llama_chat_format.get_chat_completion_handler(
+            self.chat_format
+        )
        return handler(
-            self,
+            llama=self,
            messages=messages,
            functions=functions,
            function_call=function_call,
+            tools=tools,
+            tool_choice=tool_choice,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
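create_chat_completion now also forwards OpenAI-style tools and tool_choice to the handler. A usage sketch (the tool schema below is illustrative; whether the selected chat handler actually acts on tools depends on that handler, since this change only threads the arguments through):

from llama_cpp import Llama

llm = Llama(model_path="./model.gguf", chat_format="llama-2")  # hypothetical path

response = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the weather in Berlin?"},
    ],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_weather",  # hypothetical tool
                "description": "Look up the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
    tool_choice={"type": "function", "function": {"name": "get_weather"}},
)
print(response["choices"][0]["message"])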
@@ -1974,6 +1987,7 @@ def __getstate__(self):
            numa=self.numa,
            # Chat Format Params
            chat_format=self.chat_format,
+            chat_handler=self.chat_handler,
            # Misc
            verbose=self.verbose,
        )
@@ -2015,6 +2029,7 @@ def __setstate__(self, state):
            numa=state["numa"],
            # Chat Format Params
            chat_format=state["chat_format"],
+            chat_handler=state["chat_handler"],
            # Misc
            verbose=state["verbose"],
        )