
Commit 3dbb3fd

Add support for stream parameter. Closes #1
1 parent 30fc0f3 commit 3dbb3fd

File tree

2 files changed: +129 additions, -33 deletions


examples/high_level_api_streaming.py

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+import json
+import argparse
+
+from llama_cpp import Llama
+
+parser = argparse.ArgumentParser()
+parser.add_argument("-m", "--model", type=str, default=".//models/...")
+args = parser.parse_args()
+
+llm = Llama(model_path=args.model)
+
+stream = llm(
+    "Question: What are the names of the planets in the solar system? Answer: ",
+    max_tokens=48,
+    stop=["Q:", "\n"],
+    stream=True,
+)
+
+for output in stream:
+    print(json.dumps(output, indent=2))
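
Note: the example above prints each raw chunk as JSON. Since every streamed chunk carries its text delta in choices[0]["text"] (see the llama_cpp/llama.py changes below), a caller can also render the output incrementally. A minimal sketch, reusing the same llm instance and prompt as the example above:

# Sketch: print only the incremental text of each streamed chunk,
# reusing the `llm` instance created in the example above.
for output in llm(
    "Question: What are the names of the planets in the solar system? Answer: ",
    max_tokens=48,
    stop=["Q:", "\n"],
    stream=True,
):
    print(output["choices"][0]["text"], end="", flush=True)
print()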

llama_cpp/llama.py

Lines changed: 109 additions & 33 deletions
@@ -88,7 +88,7 @@ def tokenize(self, text: bytes) -> List[int]:
             True,
         )
         if n_tokens < 0:
-            raise RuntimeError(f"Failed to tokenize: text=\"{text}\" n_tokens={n_tokens}")
+            raise RuntimeError(f'Failed to tokenize: text="{text}" n_tokens={n_tokens}')
         return list(tokens[:n_tokens])

     def detokenize(self, tokens: List[int]) -> bytes:
@@ -105,7 +105,6 @@ def detokenize(self, tokens: List[int]) -> bytes:
             output += llama_cpp.llama_token_to_str(self.ctx, token)
         return output

-
     def _eval(self, tokens: List[int], n_past):
         rc = llama_cpp.llama_eval(
             self.ctx,
@@ -137,12 +136,12 @@ def _generate(self, past_tokens, max_tokens, top_p, top_k, temp, repeat_penalty)
                 top_p=top_p,
                 top_k=top_k,
                 temp=temp,
-                repeat_penalty=repeat_penalty
+                repeat_penalty=repeat_penalty,
             )
             yield token
             self._eval([token], len(past_tokens) + i)

-    def __call__(
+    def _call(
         self,
         prompt: str,
         suffix: Optional[str] = None,
@@ -154,34 +153,11 @@ def __call__(
         stop: List[str] = [],
         repeat_penalty: float = 1.1,
         top_k: int = 40,
+        stream: bool = False,
     ):
-        """Generate text from a prompt.
-
-        Args:
-            prompt: The prompt to generate text from.
-            suffix: A suffix to append to the generated text. If None, no suffix is appended.
-            max_tokens: The maximum number of tokens to generate.
-            temperature: The temperature to use for sampling.
-            top_p: The top-p value to use for sampling.
-            logprobs: The number of logprobs to return. If None, no logprobs are returned.
-            echo: Whether to echo the prompt.
-            stop: A list of strings to stop generation when encountered.
-            repeat_penalty: The penalty to apply to repeated tokens.
-            top_k: The top-k value to use for sampling.
-
-        Raises:
-            ValueError: If the requested tokens exceed the context window.
-            RuntimeError: If the prompt fails to tokenize or the model fails to evaluate the prompt.
-
-        Returns:
-            Response object containing the generated text.
-        """
         completion_id = f"cmpl-{str(uuid.uuid4())}"
-        created= int(time.time())
-        text = b""
+        created = int(time.time())
         completion_tokens = []
-        last_n_tokens = deque([0] * self.last_n, maxlen=self.last_n)
-
         prompt_tokens = self.tokenize(prompt.encode("utf-8"))

         if len(prompt_tokens) + max_tokens > llama_cpp.llama_n_ctx(self.ctx):
@@ -198,24 +174,71 @@ def __call__(
         stop = [s.encode("utf-8") for s in stop]

         finish_reason = None
-        for token in self._generate(prompt_tokens, max_tokens, top_p, top_k, temperature, repeat_penalty):
+        for token in self._generate(
+            prompt_tokens, max_tokens, top_p, top_k, temperature, repeat_penalty
+        ):
             if token == llama_cpp.llama_token_eos():
                 finish_reason = "stop"
                 break
-            text += self.detokenize([token])
-            last_n_tokens.append(token)
             completion_tokens.append(token)

+            text = self.detokenize(completion_tokens)
             any_stop = [s for s in stop if s in text]
             if len(any_stop) > 0:
                 first_stop = any_stop[0]
                 text = text[: text.index(first_stop)]
                 finish_reason = "stop"
                 break

+            if stream:
+                start = len(self.detokenize(completion_tokens[:-1]))
+                longest = 0
+                for s in stop:
+                    for i in range(len(s), 0, -1):
+                        if s[-i:] == text[-i:]:
+                            if i > longest:
+                                longest = i
+                            break
+                yield {
+                    "id": completion_id,
+                    "object": "text_completion",
+                    "created": created,
+                    "model": self.model_path,
+                    "choices": [
+                        {
+                            "text": text[start : len(text) - longest].decode("utf-8"),
+                            "index": 0,
+                            "logprobs": None,
+                            "finish_reason": None,
+                        }
+                    ],
+                }
+
         if finish_reason is None:
             finish_reason = "length"

+        if stream:
+            if finish_reason == "stop":
+                start = len(self.detokenize(completion_tokens[:-1]))
+                text = text[start:].decode("utf-8")
+            else:
+                text = ""
+            yield {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": self.model_path,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": 0,
+                        "logprobs": None,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+            }
+            return
+
         text = text.decode("utf-8")

         if echo:
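
Note: the `longest` computation in the streaming branch above decides how many trailing bytes to withhold from a chunk: it finds the longest suffix that the generated text currently shares with any stop sequence, presumably so bytes that may belong to an emerging stop marker are not emitted prematurely. A standalone sketch of that computation (the function name is illustrative, not part of the diff):

# Standalone sketch of the hold-back length used in the streaming branch:
# the longest i such that the last i bytes of `text` equal the last i bytes
# of some stop sequence.
def hold_back(text: bytes, stop: list) -> int:
    longest = 0
    for s in stop:
        for i in range(len(s), 0, -1):
            if s[-i:] == text[-i:]:
                if i > longest:
                    longest = i
                break
    return longest

print(hold_back(b"Answer:", [b"Q:"]))  # 1: the trailing ":" is withheld from this chunk
print(hold_back(b"Answer", [b"Q:"]))   # 0: nothing is withheld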
@@ -229,7 +252,7 @@ def __call__(
                 self.ctx,
             )[:logprobs]

-        return {
+        yield {
             "id": completion_id,
             "object": "text_completion",
             "created": created,
@@ -249,5 +272,58 @@ def __call__(
             },
         }

+    def __call__(
+        self,
+        prompt: str,
+        suffix: Optional[str] = None,
+        max_tokens: int = 16,
+        temperature: float = 0.8,
+        top_p: float = 0.95,
+        logprobs: Optional[int] = None,
+        echo: bool = False,
+        stop: List[str] = [],
+        repeat_penalty: float = 1.1,
+        top_k: int = 40,
+        stream: bool = False,
+    ):
+        """Generate text from a prompt.
+
+        Args:
+            prompt: The prompt to generate text from.
+            suffix: A suffix to append to the generated text. If None, no suffix is appended.
+            max_tokens: The maximum number of tokens to generate.
+            temperature: The temperature to use for sampling.
+            top_p: The top-p value to use for sampling.
+            logprobs: The number of logprobs to return. If None, no logprobs are returned.
+            echo: Whether to echo the prompt.
+            stop: A list of strings to stop generation when encountered.
+            repeat_penalty: The penalty to apply to repeated tokens.
+            top_k: The top-k value to use for sampling.
+            stream: Whether to stream the results.
+
+        Raises:
+            ValueError: If the requested tokens exceed the context window.
+            RuntimeError: If the prompt fails to tokenize or the model fails to evaluate the prompt.
+
+        Returns:
+            Response object containing the generated text.
+        """
+        call = self._call(
+            prompt=prompt,
+            suffix=suffix,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            logprobs=logprobs,
+            echo=echo,
+            stop=stop,
+            repeat_penalty=repeat_penalty,
+            top_k=top_k,
+            stream=stream,
+        )
+        if stream:
+            return call
+        return next(call)
+
     def __del__(self):
         llama_cpp.llama_free(self.ctx)
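
Note: after this change `_call` is always a generator (the earlier `return {...}` became `yield {...}`), and the public `__call__` either hands that generator straight back to the caller (stream=True) or drains its single completion dict with next(call) (stream=False), so existing non-streaming callers keep the old return shape. A minimal usage sketch; the model path and prompt here are placeholders, not from the diff:

from llama_cpp import Llama

llm = Llama(model_path="./models/your-model.bin")  # placeholder path

# stream=False (default): __call__ consumes the internal generator with
# next() and returns a single completion dict, as before this commit.
result = llm("Question: 1 + 1 = ? Answer: ", max_tokens=8, stop=["Q:", "\n"])
print(result["choices"][0]["text"])

# stream=True: __call__ returns the generator unconsumed, so the caller
# iterates over partial-completion dicts instead.
for chunk in llm("Question: 1 + 1 = ? Answer: ", max_tokens=8, stop=["Q:", "\n"], stream=True):
    print(chunk["choices"][0]["text"], end="")
print()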
