Skip to content

Commit cc81afe

Browse files
committed
feat: Add stopping_criteria to ChatFormatter, allow stopping on arbitrary token ids, fixes llama3 instruct
1 parent d17c188 commit cc81afe

File tree

2 files changed: 25 additions (+), 2 deletions (−)

llama_cpp/llama.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -426,7 +426,10 @@ def __init__(
 426  426           print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
 427  427
 428  428           self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
 429       -            template=template, eos_token=eos_token, bos_token=bos_token
      429  +            template=template,
      430  +            eos_token=eos_token,
      431  +            bos_token=bos_token,
      432  +            stop_token_ids=[eos_token_id],
 430  433           ).to_chat_handler()
 431  434
 432  435       if self.chat_format is None and self.chat_handler is None:

llama_cpp/llama_chat_format.py

Lines changed: 21 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,9 @@
 10   10
 11   11   import jinja2
 12   12
      13   + import numpy as np
      14   + import numpy.typing as npt
      15   +
 13   16   import llama_cpp.llama as llama
 14   17   import llama_cpp.llama_types as llama_types
 15   18   import llama_cpp.llama_grammar as llama_grammar
@@ -150,6 +153,7 @@ class ChatFormatterResponse:
 150  153
 151  154       prompt: str
 152  155       stop: Optional[Union[str, List[str]]] = None
      156  +    stopping_criteria: Optional[llama.StoppingCriteriaList] = None
 153  157
 154  158
 155  159   class ChatFormatter(Protocol):
@@ -173,12 +177,14 @@ def __init__(
 173  177           eos_token: str,
 174  178           bos_token: str,
 175  179           add_generation_prompt: bool = True,
      180  +        stop_token_ids: Optional[List[int]] = None,
 176  181       ):
 177  182           """A chat formatter that uses jinja2 templates to format the prompt."""
 178  183           self.template = template
 179  184           self.eos_token = eos_token
 180  185           self.bos_token = bos_token
 181  186           self.add_generation_prompt = add_generation_prompt
      187  +        self.stop_token_ids = set(stop_token_ids) if stop_token_ids is not None else None
 182  188
 183  189           self._environment = jinja2.Environment(
 184  190               loader=jinja2.BaseLoader(),
@@ -211,7 +217,16 @@ def raise_exception(message: str):
 211  217               tool_choice=tool_choice,
 212  218           )
 213  219
 214       -        return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])
      220  +        stopping_criteria = None
      221  +        if self.stop_token_ids is not None:
      222  +            def stop_on_last_token(
      223  +                tokens: npt.NDArray[np.intc],
      224  +                logits: npt.NDArray[np.single]
      225  +            ) -> bool:
      226  +                return tokens[-1] in self.stop_token_ids
      227  +            stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token])
      228  +
      229  +        return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria)
 215  230
 216  231       def to_chat_handler(self) -> LlamaChatCompletionHandler:
 217  232           return chat_formatter_to_chat_completion_handler(self)
@@ -533,6 +548,10 @@ def chat_completion_handler(
 533  548           rstop = result.stop if isinstance(result.stop, list) else [result.stop]
 534  549           stop = stop + rstop
 535  550
      551  +        stopping_criteria = None
      552  +        if result.stopping_criteria is not None:
      553  +            stopping_criteria = result.stopping_criteria
      554  +
 536  555       if response_format is not None and response_format["type"] == "json_object":
 537  556           grammar = _grammar_for_response_format(response_format, verbose=llama.verbose)
 538  557
@@ -598,6 +617,7 @@ def chat_completion_handler(
 598  617               mirostat_eta=mirostat_eta,
 599  618               model=model,
 600  619               logits_processor=logits_processor,
      620  +            stopping_criteria=stopping_criteria,
 601  621               grammar=grammar,
 602  622               logit_bias=logit_bias,
 603  623           )

0 commit comments

Comments (0)