Skip to content

Commit 19ba9d3

Browse files
committed
Use numpy arrays for logits_processors and stopping_criteria. Closes ggml-org#491
1 parent 5eab1db commit 19ba9d3

File tree

2 files changed

+24
-16
lines changed

2 files changed

+24
-16
lines changed

llama_cpp/llama.py

Lines changed: 18 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
import numpy as np
2828
import numpy.typing as npt
2929

30+
3031
class BaseLlamaCache(ABC):
3132
"""Base cache class for a llama.cpp model."""
3233

@@ -179,21 +180,27 @@ def __init__(
179180
self.llama_state_size = llama_state_size
180181

181182

182-
LogitsProcessor = Callable[[List[int], List[float]], List[float]]
183+
LogitsProcessor = Callable[
184+
[npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
185+
]
183186

184187

185188
class LogitsProcessorList(List[LogitsProcessor]):
186-
def __call__(self, input_ids: List[int], scores: List[float]) -> List[float]:
189+
def __call__(
190+
self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
191+
) -> npt.NDArray[np.single]:
187192
for processor in self:
188193
scores = processor(input_ids, scores)
189194
return scores
190195

191196

192-
StoppingCriteria = Callable[[List[int], List[float]], bool]
197+
StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]
193198

194199

195200
class StoppingCriteriaList(List[StoppingCriteria]):
196-
def __call__(self, input_ids: List[int], logits: List[float]) -> bool:
201+
def __call__(
202+
self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
203+
) -> bool:
197204
return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
198205

199206

@@ -274,9 +281,11 @@ def __init__(
274281
self._c_tensor_split = None
275282

276283
if self.tensor_split is not None:
277-
#Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
284+
# Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
278285
FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
279-
self._c_tensor_split = FloatArray(*tensor_split) # keep a reference to the array so it is not gc'd
286+
self._c_tensor_split = FloatArray(
287+
*tensor_split
288+
) # keep a reference to the array so it is not gc'd
280289
self.params.tensor_split = self._c_tensor_split
281290

282291
self.params.rope_freq_base = rope_freq_base
@@ -503,11 +512,7 @@ def _sample(
503512
logits: npt.NDArray[np.single] = self._scores[-1, :]
504513

505514
if logits_processor is not None:
506-
logits = np.array(
507-
logits_processor(self._input_ids.tolist(), logits.tolist()),
508-
dtype=np.single,
509-
)
510-
self._scores[-1, :] = logits
515+
logits[:] = logits_processor(self._input_ids, logits)
511516

512517
nl_logit = logits[self._token_nl]
513518
candidates = self._candidates
@@ -725,7 +730,7 @@ def generate(
725730
logits_processor=logits_processor,
726731
)
727732
if stopping_criteria is not None and stopping_criteria(
728-
self._input_ids.tolist(), self._scores[-1, :].tolist()
733+
self._input_ids, self._scores[-1, :]
729734
):
730735
return
731736
tokens_or_none = yield token
@@ -1014,7 +1019,7 @@ def _create_completion(
10141019
break
10151020

10161021
if stopping_criteria is not None and stopping_criteria(
1017-
self._input_ids.tolist(), self._scores[-1, :].tolist()
1022+
self._input_ids, self._scores[-1, :]
10181023
):
10191024
text = self.detokenize(completion_tokens)
10201025
finish_reason = "stop"

llama_cpp/server/app.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,9 @@
1616
from pydantic_settings import BaseSettings
1717
from sse_starlette.sse import EventSourceResponse
1818

19+
import numpy as np
20+
import numpy.typing as npt
21+
1922

2023
class Settings(BaseSettings):
2124
model: str = Field(
@@ -336,9 +339,9 @@ def make_logit_bias_processor(
336339
to_bias[input_id] = score
337340

338341
def logit_bias_processor(
339-
input_ids: List[int],
340-
scores: List[float],
341-
) -> List[float]:
342+
input_ids: npt.NDArray[np.intc],
343+
scores: npt.NDArray[np.single],
344+
) -> npt.NDArray[np.single]:
342345
new_scores = [None] * len(scores)
343346
for input_id, score in enumerate(scores):
344347
new_scores[input_id] = score + to_bias.get(input_id, 0.0)

0 commit comments

Comments
 (0)