27 | 27 | import numpy as np
28 | 28 | import numpy.typing as npt
29 | 29 |
| 30 | +
30 | 31 | class BaseLlamaCache(ABC):
31 | 32 |     """Base cache class for a llama.cpp model."""
32 | 33 |
@@ -179,21 +180,27 @@ def __init__(
179 | 180 |         self.llama_state_size = llama_state_size
180 | 181 |
181 | 182 |
182 | | -LogitsProcessor = Callable[[List[int], List[float]], List[float]]
| 183 | +LogitsProcessor = Callable[
| 184 | +    [npt.NDArray[np.intc], npt.NDArray[np.single]], npt.NDArray[np.single]
| 185 | +]
183 | 186 |
184 | 187 |
185 | 188 | class LogitsProcessorList(List[LogitsProcessor]):
186 | | -    def __call__(self, input_ids: List[int], scores: List[float]) -> List[float]:
| 189 | +    def __call__(
| 190 | +        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
| 191 | +    ) -> npt.NDArray[np.single]:
187 | 192 |         for processor in self:
188 | 193 |             scores = processor(input_ids, scores)
189 | 194 |         return scores
190 | 195 |
191 | 196 |
192 | | -StoppingCriteria = Callable[[List[int], List[float]], bool]
| 197 | +StoppingCriteria = Callable[[npt.NDArray[np.intc], npt.NDArray[np.single]], bool]
193 | 198 |
194 | 199 |
195 | 200 | class StoppingCriteriaList(List[StoppingCriteria]):
196 | | -    def __call__(self, input_ids: List[int], logits: List[float]) -> bool:
| 201 | +    def __call__(
| 202 | +        self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
| 203 | +    ) -> bool:
197 | 204 |         return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])
198 | 205 |
199 | 206 |
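With `LogitsProcessor` and `StoppingCriteria` now typed over NumPy arrays, user-supplied callbacks receive `npt.NDArray[np.intc]` token ids and `npt.NDArray[np.single]` scores instead of Python lists. A minimal caller-side sketch of the new signatures; the temperature value, the 128-token cutoff, and the package-level imports are assumptions for illustration, not part of this diff:

```python
import numpy as np
import numpy.typing as npt

# Assumption: these classes are re-exported at package level.
from llama_cpp import LogitsProcessorList, StoppingCriteriaList


def scale_logits(
    input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
) -> npt.NDArray[np.single]:
    # Illustrative processor: apply a fixed temperature of 0.8 to the raw scores.
    return (scores / 0.8).astype(np.single)


def stop_after_128_tokens(
    input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
) -> bool:
    # Illustrative criterion: stop once the context holds 128 tokens.
    return input_ids.shape[0] >= 128


processors = LogitsProcessorList([scale_logits])
criteria = StoppingCriteriaList([stop_after_128_tokens])
```

Since `LogitsProcessorList.__call__` simply threads `scores` through each processor, a processor may either return a fresh array (as `astype` does here) or mutate the array it was given and return it.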
@@ -274,9 +281,11 @@ def __init__(
274 | 281 |         self._c_tensor_split = None
275 | 282 |
276 | 283 |         if self.tensor_split is not None:
277 | | -            #Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
| 284 | +            # Type conversion and expand the list to the length of LLAMA_MAX_DEVICES
278 | 285 |             FloatArray = ctypes.c_float * llama_cpp.LLAMA_MAX_DEVICES.value
279 | | -            self._c_tensor_split = FloatArray(*tensor_split) # keep a reference to the array so it is not gc'd
| 286 | +            self._c_tensor_split = FloatArray(
| 287 | +                *tensor_split
| 288 | +            )  # keep a reference to the array so it is not gc'd
280 | 289 |             self.params.tensor_split = self._c_tensor_split
281 | 290 |
282 | 291 |         self.params.rope_freq_base = rope_freq_base
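The reformatted `tensor_split` block is the usual ctypes pattern: build a fixed-length `c_float` array type sized to `LLAMA_MAX_DEVICES`, initialize it from the Python list (unspecified entries are zero-filled), and keep a Python reference so the buffer outlives the assignment into the C params struct. A standalone sketch of that pattern, with a made-up device count of 2 standing in for `llama_cpp.LLAMA_MAX_DEVICES.value`:

```python
import ctypes

MAX_DEVICES = 2            # stand-in for llama_cpp.LLAMA_MAX_DEVICES.value
tensor_split = [0.6, 0.4]  # illustrative per-device proportions

FloatArray = ctypes.c_float * MAX_DEVICES   # fixed-length array type
c_tensor_split = FloatArray(*tensor_split)  # hold on to this reference: the C side
                                            # only sees a pointer into this buffer
print(list(c_tensor_split))                 # ~[0.6, 0.4], rounded to float32
```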
@@ -503,11 +512,7 @@ def _sample(
503 | 512 |         logits: npt.NDArray[np.single] = self._scores[-1, :]
504 | 513 |
505 | 514 |         if logits_processor is not None:
506 | | -            logits = np.array(
507 | | -                logits_processor(self._input_ids.tolist(), logits.tolist()),
508 | | -                dtype=np.single,
509 | | -            )
510 | | -            self._scores[-1, :] = logits
| 515 | +            logits[:] = logits_processor(self._input_ids, logits)
511 | 516 |
512 | 517 |         nl_logit = logits[self._token_nl]
513 | 518 |         candidates = self._candidates
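The simplification in `_sample` relies on `logits` being a NumPy view of `self._scores[-1, :]`: the slice assignment `logits[:] = ...` writes the processed values straight back into the scores buffer, which is what the removed `np.array(...)` round-trip and the explicit `self._scores[-1, :] = logits` used to accomplish. A small sketch of the view semantics, with made-up shapes:

```python
import numpy as np

scores = np.zeros((3, 5), dtype=np.single)  # stand-in for self._scores
logits = scores[-1, :]                      # a view into the last row, not a copy
logits[:] = np.arange(5, dtype=np.single)   # in-place write through the view
print(scores[-1])                           # [0. 1. 2. 3. 4.] -- the row itself changed
```

Note that a plain rebinding `logits = ...` (without the `[:]`) would only reassign the local name and leave `self._scores` untouched, so the slice assignment is what preserves the old behavior.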
@@ -725,7 +730,7 @@ def generate(
725 | 730 |                 logits_processor=logits_processor,
726 | 731 |             )
727 | 732 |             if stopping_criteria is not None and stopping_criteria(
728 | | -                self._input_ids.tolist(), self._scores[-1, :].tolist()
| 733 | +                self._input_ids, self._scores[-1, :]
729 | 734 |             ):
730 | 735 |                 return
731 | 736 |             tokens_or_none = yield token
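In `generate`, the criteria are now called with the raw `self._input_ids` array and the live scores row, so the per-token list conversions disappear from the hot loop. A hedged end-to-end usage sketch; the model path, prompt, sampling values, and the exact keyword names of `generate` are assumptions here, not something this diff shows:

```python
import numpy as np
import numpy.typing as npt
from llama_cpp import Llama, StoppingCriteriaList  # assumption: package-level exports


def stop_at_32_tokens(
    input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
) -> bool:
    # Illustrative criterion: stop once the context holds 32 tokens.
    return input_ids.shape[0] >= 32


llm = Llama(model_path="./model.bin")  # placeholder path
for token in llm.generate(
    llm.tokenize(b"Q: Name the planets in the solar system. A:"),
    top_k=40,
    top_p=0.95,
    temp=0.8,
    repeat_penalty=1.1,
    stopping_criteria=StoppingCriteriaList([stop_at_32_tokens]),
):
    # Per the hunk above, generate() returns once the criterion fires, ending this loop.
    print(llm.detokenize([token]).decode("utf-8", errors="ignore"), end="", flush=True)
```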
@@ -1014,7 +1019,7 @@ def _create_completion(
1014 | 1019 |                 break
1015 | 1020 |
1016 | 1021 |         if stopping_criteria is not None and stopping_criteria(
1017 | | -            self._input_ids.tolist(), self._scores[-1, :].tolist()
| 1022 | +            self._input_ids, self._scores[-1, :]
1018 | 1023 |         ):
1019 | 1024 |             text = self.detokenize(completion_tokens)
1020 | 1025 |             finish_reason = "stop"