10
10
11
11
import jinja2
12
12
13
+ import numpy as np
14
+ import numpy .typing as npt
15
+
13
16
import llama_cpp .llama as llama
14
17
import llama_cpp .llama_types as llama_types
15
18
import llama_cpp .llama_grammar as llama_grammar
@@ -150,6 +153,7 @@ class ChatFormatterResponse:
150
153
151
154
prompt : str
152
155
stop : Optional [Union [str , List [str ]]] = None
156
+ stopping_criteria : Optional [llama .StoppingCriteriaList ] = None
153
157
154
158
155
159
class ChatFormatter (Protocol ):
@@ -173,12 +177,14 @@ def __init__(
173
177
eos_token : str ,
174
178
bos_token : str ,
175
179
add_generation_prompt : bool = True ,
180
+ stop_token_ids : Optional [List [int ]] = None ,
176
181
):
177
182
"""A chat formatter that uses jinja2 templates to format the prompt."""
178
183
self .template = template
179
184
self .eos_token = eos_token
180
185
self .bos_token = bos_token
181
186
self .add_generation_prompt = add_generation_prompt
187
+ self .stop_token_ids = set (stop_token_ids ) if stop_token_ids is not None else None
182
188
183
189
self ._environment = jinja2 .Environment (
184
190
loader = jinja2 .BaseLoader (),
@@ -211,7 +217,16 @@ def raise_exception(message: str):
211
217
tool_choice = tool_choice ,
212
218
)
213
219
214
- return ChatFormatterResponse (prompt = prompt , stop = [self .eos_token ])
220
+ stopping_criteria = None
221
+ if self .stop_token_ids is not None :
222
+ def stop_on_last_token (
223
+ tokens : npt .NDArray [np .intc ],
224
+ logits : npt .NDArray [np .single ]
225
+ ) -> bool :
226
+ return tokens [- 1 ] in self .stop_token_ids
227
+ stopping_criteria = llama .StoppingCriteriaList ([stop_on_last_token ])
228
+
229
+ return ChatFormatterResponse (prompt = prompt , stop = [self .eos_token ], stopping_criteria = stopping_criteria )
215
230
216
231
def to_chat_handler (self ) -> LlamaChatCompletionHandler :
217
232
return chat_formatter_to_chat_completion_handler (self )
@@ -533,6 +548,10 @@ def chat_completion_handler(
533
548
rstop = result .stop if isinstance (result .stop , list ) else [result .stop ]
534
549
stop = stop + rstop
535
550
551
+ stopping_criteria = None
552
+ if result .stopping_criteria is not None :
553
+ stopping_criteria = result .stopping_criteria
554
+
536
555
if response_format is not None and response_format ["type" ] == "json_object" :
537
556
grammar = _grammar_for_response_format (response_format , verbose = llama .verbose )
538
557
@@ -598,6 +617,7 @@ def chat_completion_handler(
598
617
mirostat_eta = mirostat_eta ,
599
618
model = model ,
600
619
logits_processor = logits_processor ,
620
+ stopping_criteria = stopping_criteria ,
601
621
grammar = grammar ,
602
622
logit_bias = logit_bias ,
603
623
)
0 commit comments