
Commit 8ba2278

always use seq_id=0 for generation; provide strftime_now to templates
1 parent e1af05f


2 files changed (+23 -17 lines)


llama_cpp/llama.py

Lines changed: 2 additions & 2 deletions
@@ -637,7 +637,7 @@ def eval(self, tokens: Sequence[int]):
         Args:
             tokens: The list of tokens to evaluate.
         """
-        self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
+        self._ctx.kv_cache_seq_rm(0, self.n_tokens, -1)
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
             n_past = self.n_tokens
@@ -945,7 +945,7 @@ def generate(
 
             if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]:
                 self.n_tokens = sample_idx
-                self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
+                self._ctx.kv_cache_seq_rm(0, self.n_tokens, -1)
                 break
 
             if self.draft_model is not None:
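
In llama.cpp, kv_cache_seq_rm(seq_id, p0, p1) drops cached positions [p0, p1) for the given sequence; a negative seq_id matches every sequence, and p1 = -1 means "to the end". The high-level Llama API only ever decodes into sequence 0, so both trims now name that sequence explicitly instead of clearing all sequences. Below is a minimal sketch of the pattern these calls enable, reusing a cached prompt prefix and re-evaluating only the divergent suffix; eval_with_prefix_reuse is a hypothetical helper, not library API:

    def eval_with_prefix_reuse(llama, tokens):
        # Hypothetical helper; relies on Llama internals (_ctx, _input_ids,
        # n_tokens) as they exist in llama-cpp-python.
        # Find the longest prefix already present in the KV cache.
        n_keep = 0
        for cached, new in zip(llama._input_ids, tokens):
            if cached != new:
                break
            n_keep += 1
        llama.n_tokens = n_keep
        # Drop sequence 0's cache from position n_keep onward (p1=-1 = to end).
        llama._ctx.kv_cache_seq_rm(0, llama.n_tokens, -1)
        # Only the non-shared suffix needs to be evaluated.
        llama.eval(tokens[n_keep:])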

llama_cpp/llama_chat_format.py

Lines changed: 21 additions & 15 deletions

(The substantive changes below are the datetime import, the strftime_now helper, and the extra render() argument. Every hunk that pairs a removed blank line with an added blank line is a whitespace-only cleanup.)
@@ -8,6 +8,7 @@
 import random
 import string
 
+from datetime import datetime
 from contextlib import ExitStack
 from typing import (
     Any,
@@ -214,6 +215,10 @@ def __init__(
             lstrip_blocks=True,
         ).from_string(self.template)
 
+    @staticmethod
+    def strftime_now(f: str) -> str:
+        return datetime.now().strftime(f)
+
     def __call__(
         self,
         *,
@@ -237,6 +242,7 @@ def raise_exception(message: str):
             function_call=function_call,
             tools=tools,
             tool_choice=tool_choice,
+            strftime_now=self.strftime_now,
         )
 
         stopping_criteria = None
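
Recent chat templates call strftime_now to embed the current date; Meta's Llama 3.1/3.2 templates, for instance, use strftime_now("%d %b %Y") to fill in the "Today Date" line of the system prompt. Passing the helper into render() makes such templates work under Jinja2ChatFormatter. A self-contained sketch of the mechanism, using a made-up template string rather than any real model's:

    from datetime import datetime

    import jinja2

    def strftime_now(f: str) -> str:
        # Same helper the formatter now exposes to templates.
        return datetime.now().strftime(f)

    env = jinja2.Environment(trim_blocks=True, lstrip_blocks=True)
    # Hypothetical template; real chat templates ship with the model.
    template = env.from_string(
        "Today Date: {{ strftime_now('%d %b %Y') }}\n"
        "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"
    )
    print(template.render(
        messages=[{"role": "user", "content": "What day is it?"}],
        strftime_now=strftime_now,  # the binding this commit adds
    ))

(The library itself renders through a jinja2.sandbox.ImmutableSandboxedEnvironment; plain Environment is used here only to keep the sketch short.)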
@@ -2752,10 +2758,10 @@ def _create_bitmap_from_bytes(self, image_bytes: bytes):
             (ctypes.c_uint8 * len(image_bytes)).from_buffer(bytearray(image_bytes)),
             len(image_bytes)
         )
-
+
         if bitmap is None:
             raise ValueError("Failed to create bitmap from image bytes")
-
+
         return bitmap
 
     def __call__(
@@ -2814,18 +2820,18 @@ def __call__(
             trim_blocks=True,
             lstrip_blocks=True,
         ).from_string(self.CHAT_FORMAT)
-
+
         # Get the default media marker
         media_marker = self._mtmd_cpp.mtmd_default_marker().decode('utf-8')
-
+
         # Replace image URLs with media markers in the template
         text = template.render(
             messages=messages,
             add_generation_prompt=True,
             eos_token=llama.detokenize([llama.token_eos()]),
             bos_token=llama.detokenize([llama.token_bos()]),
         )
-
+
         # Replace image URLs in text with media markers
         for image_url in image_urls:
             text = text.replace(image_url, media_marker)
@@ -2875,40 +2881,40 @@ def __call__(
         # Process each chunk
         n_past = llama_cpp.llama_pos(0)
         n_chunks = self._mtmd_cpp.mtmd_input_chunks_size(chunks)
-
+
         for i in range(n_chunks):
             chunk = self._mtmd_cpp.mtmd_input_chunks_get(chunks, i)
             if chunk is None:
                 continue
 
             chunk_type = self._mtmd_cpp.mtmd_input_chunk_get_type(chunk)
-
+
             if chunk_type == self._mtmd_cpp.MTMD_INPUT_CHUNK_TYPE_TEXT:
                 # Handle text chunk
                 n_tokens_out = ctypes.c_size_t()
                 tokens_ptr = self._mtmd_cpp.mtmd_input_chunk_get_tokens_text(
                     chunk, ctypes.byref(n_tokens_out)
                 )
-
+
                 if tokens_ptr and n_tokens_out.value > 0:
                     # Convert ctypes array to Python list
                     tokens = [tokens_ptr[j] for j in range(n_tokens_out.value)]
-
+
                     if llama.n_tokens + len(tokens) > llama.n_ctx():
                         raise ValueError(
                             f"Prompt exceeds n_ctx: {llama.n_tokens + len(tokens)} > {llama.n_ctx()}"
                         )
                     llama.eval(tokens)
-
+
             elif chunk_type in [self._mtmd_cpp.MTMD_INPUT_CHUNK_TYPE_IMAGE, self._mtmd_cpp.MTMD_INPUT_CHUNK_TYPE_AUDIO]:
                 # Handle image/audio chunk using helper
                 chunk_n_tokens = self._mtmd_cpp.mtmd_input_chunk_get_n_tokens(chunk)
-
+
                 if llama.n_tokens + chunk_n_tokens > llama.n_ctx():
                     raise ValueError(
                         f"Prompt exceeds n_ctx: {llama.n_tokens + chunk_n_tokens} > {llama.n_ctx()}"
                     )
-
+
                 new_n_past = llama_cpp.llama_pos(0)
                 result = self._mtmd_cpp.mtmd_helper_eval_chunk_single(
                     self.mtmd_ctx,
@@ -2920,10 +2926,10 @@ def __call__(
                     False,  # logits_last
                     ctypes.byref(new_n_past)
                 )
-
+
                 if result != 0:
                     raise ValueError(f"Failed to evaluate chunk: error code {result}")
-
+
                 # Update llama's token count
                 llama.n_tokens = new_n_past.value
 
@@ -3013,7 +3019,7 @@ def __call__(
             grammar=grammar,
             logit_bias=logit_bias,
         )
-
+
         if tool is not None:
             tool_name = tool["function"]["name"]
             return _convert_completion_to_chat_function(
