Skip to content

Commit 30fc0f3

Browse files
committed
Extract generate method
1 parent 1c823f6 commit 30fc0f3

File tree

1 file changed

+22
-12
lines changed

1 file changed

+22
-12
lines changed

llama_cpp/llama.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,20 @@ def _sample(self, last_n_tokens, top_p, top_k, temp, repeat_penalty):
128128
repeat_penalty=repeat_penalty,
129129
)
130130

131+
def _generate(self, past_tokens, max_tokens, top_p, top_k, temp, repeat_penalty):
132+
last_n_tokens = deque([0] * self.last_n, maxlen=self.last_n)
133+
last_n_tokens.extend(past_tokens)
134+
for i in range(max_tokens):
135+
token = self._sample(
136+
last_n_tokens,
137+
top_p=top_p,
138+
top_k=top_k,
139+
temp=temp,
140+
repeat_penalty=repeat_penalty
141+
)
142+
yield token
143+
self._eval([token], len(past_tokens) + i)
144+
131145
def __call__(
132146
self,
133147
prompt: str,
@@ -162,8 +176,9 @@ def __call__(
162176
Returns:
163177
Response object containing the generated text.
164178
"""
179+
completion_id = f"cmpl-{str(uuid.uuid4())}"
180+
created= int(time.time())
165181
text = b""
166-
finish_reason = "length"
167182
completion_tokens = []
168183
last_n_tokens = deque([0] * self.last_n, maxlen=self.last_n)
169184

@@ -182,14 +197,8 @@ def __call__(
182197
if stop is not None:
183198
stop = [s.encode("utf-8") for s in stop]
184199

185-
for i in range(max_tokens):
186-
token = self._sample(
187-
last_n_tokens,
188-
top_p=top_p,
189-
top_k=top_k,
190-
temp=temperature,
191-
repeat_penalty=repeat_penalty
192-
)
200+
finish_reason = None
201+
for token in self._generate(prompt_tokens, max_tokens, top_p, top_k, temperature, repeat_penalty):
193202
if token == llama_cpp.llama_token_eos():
194203
finish_reason = "stop"
195204
break
@@ -204,7 +213,8 @@ def __call__(
204213
finish_reason = "stop"
205214
break
206215

207-
self._eval([token], len(prompt_tokens) + len(completion_tokens))
216+
if finish_reason is None:
217+
finish_reason = "length"
208218

209219
text = text.decode("utf-8")
210220

@@ -220,9 +230,9 @@ def __call__(
220230
)[:logprobs]
221231

222232
return {
223-
"id": f"cmpl-{str(uuid.uuid4())}", # Likely to change
233+
"id": completion_id,
224234
"object": "text_completion",
225-
"created": int(time.time()),
235+
"created": created,
226236
"model": self.model_path,
227237
"choices": [
228238
{

0 commit comments

Comments
 (0)