@@ -88,7 +88,7 @@ def tokenize(self, text: bytes) -> List[int]:
             True,
         )
         if n_tokens < 0:
-            raise RuntimeError(f"Failed to tokenize: text=\"{text}\" n_tokens={n_tokens}")
+            raise RuntimeError(f'Failed to tokenize: text="{text}" n_tokens={n_tokens}')
         return list(tokens[:n_tokens])
 
     def detokenize(self, tokens: List[int]) -> bytes:
@@ -105,7 +105,6 @@ def detokenize(self, tokens: List[int]) -> bytes:
             output += llama_cpp.llama_token_to_str(self.ctx, token)
         return output
 
-
     def _eval(self, tokens: List[int], n_past):
         rc = llama_cpp.llama_eval(
             self.ctx,
@@ -137,12 +136,12 @@ def _generate(self, past_tokens, max_tokens, top_p, top_k, temp, repeat_penalty)
                 top_p=top_p,
                 top_k=top_k,
                 temp=temp,
-                repeat_penalty=repeat_penalty
+                repeat_penalty=repeat_penalty,
             )
             yield token
             self._eval([token], len(past_tokens) + i)
 
-    def __call__(
+    def _call(
         self,
         prompt: str,
         suffix: Optional[str] = None,
@@ -154,34 +153,11 @@ def __call__(
         stop: List[str] = [],
         repeat_penalty: float = 1.1,
         top_k: int = 40,
+        stream: bool = False,
     ):
-        """Generate text from a prompt.
-
-        Args:
-            prompt: The prompt to generate text from.
-            suffix: A suffix to append to the generated text. If None, no suffix is appended.
-            max_tokens: The maximum number of tokens to generate.
-            temperature: The temperature to use for sampling.
-            top_p: The top-p value to use for sampling.
-            logprobs: The number of logprobs to return. If None, no logprobs are returned.
-            echo: Whether to echo the prompt.
-            stop: A list of strings to stop generation when encountered.
-            repeat_penalty: The penalty to apply to repeated tokens.
-            top_k: The top-k value to use for sampling.
-
-        Raises:
-            ValueError: If the requested tokens exceed the context window.
-            RuntimeError: If the prompt fails to tokenize or the model fails to evaluate the prompt.
-
-        Returns:
-            Response object containing the generated text.
-        """
         completion_id = f"cmpl-{str(uuid.uuid4())}"
-        created = int(time.time())
-        text = b""
+        created = int(time.time())
         completion_tokens = []
-        last_n_tokens = deque([0] * self.last_n, maxlen=self.last_n)
-
         prompt_tokens = self.tokenize(prompt.encode("utf-8"))
 
         if len(prompt_tokens) + max_tokens > llama_cpp.llama_n_ctx(self.ctx):
@@ -198,24 +174,71 @@ def __call__(
         stop = [s.encode("utf-8") for s in stop]
 
         finish_reason = None
-        for token in self._generate(prompt_tokens, max_tokens, top_p, top_k, temperature, repeat_penalty):
+        for token in self._generate(
+            prompt_tokens, max_tokens, top_p, top_k, temperature, repeat_penalty
+        ):
             if token == llama_cpp.llama_token_eos():
                 finish_reason = "stop"
                 break
-            text += self.detokenize([token])
-            last_n_tokens.append(token)
             completion_tokens.append(token)
 
+            text = self.detokenize(completion_tokens)
             any_stop = [s for s in stop if s in text]
             if len(any_stop) > 0:
                 first_stop = any_stop[0]
                 text = text[: text.index(first_stop)]
                 finish_reason = "stop"
                 break
 
+            if stream:
+                start = len(self.detokenize(completion_tokens[:-1]))
+                longest = 0
+                for s in stop:
+                    for i in range(len(s), 0, -1):
+                        if s[-i:] == text[-i:]:
+                            if i > longest:
+                                longest = i
+                            break
+                yield {
+                    "id": completion_id,
+                    "object": "text_completion",
+                    "created": created,
+                    "model": self.model_path,
+                    "choices": [
+                        {
+                            "text": text[start : len(text) - longest].decode("utf-8"),
+                            "index": 0,
+                            "logprobs": None,
+                            "finish_reason": None,
+                        }
+                    ],
+                }
+
         if finish_reason is None:
             finish_reason = "length"
 
+        if stream:
+            if finish_reason == "stop":
+                start = len(self.detokenize(completion_tokens[:-1]))
+                text = text[start:].decode("utf-8")
+            else:
+                text = ""
+            yield {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": self.model_path,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": 0,
+                        "logprobs": None,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+            }
+            return
+
         text = text.decode("utf-8")
 
         if echo:
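
Note on the streaming branch added above: before yielding each chunk, the loop computes `longest`, the length of the longest tail of the decoded text that coincides with the tail of one of the stop strings, and withholds that many bytes from the chunk so text that may belong to an incoming stop sequence is not sent to the client early. A minimal standalone paraphrase of that hold-back computation (the helper name is ours, not part of the patch):

from typing import List

def held_back_length(text: bytes, stop: List[bytes]) -> int:
    # Largest i such that the last i bytes of `text` equal the last i bytes
    # of some stop string; the streaming branch withholds that many bytes
    # from the current chunk and yields text[start : len(text) - longest].
    longest = 0
    for s in stop:
        for i in range(len(s), 0, -1):
            if s[-i:] == text[-i:]:
                if i > longest:
                    longest = i
                break
    return longest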
@@ -229,7 +252,7 @@ def __call__(
                 self.ctx,
             )[:logprobs]
 
-        return {
+        yield {
             "id": completion_id,
             "object": "text_completion",
             "created": created,
@@ -249,5 +272,58 @@ def __call__(
             },
         }
 
+    def __call__(
+        self,
+        prompt: str,
+        suffix: Optional[str] = None,
+        max_tokens: int = 16,
+        temperature: float = 0.8,
+        top_p: float = 0.95,
+        logprobs: Optional[int] = None,
+        echo: bool = False,
+        stop: List[str] = [],
+        repeat_penalty: float = 1.1,
+        top_k: int = 40,
+        stream: bool = False,
+    ):
+        """Generate text from a prompt.
+
+        Args:
+            prompt: The prompt to generate text from.
+            suffix: A suffix to append to the generated text. If None, no suffix is appended.
+            max_tokens: The maximum number of tokens to generate.
+            temperature: The temperature to use for sampling.
+            top_p: The top-p value to use for sampling.
+            logprobs: The number of logprobs to return. If None, no logprobs are returned.
+            echo: Whether to echo the prompt.
+            stop: A list of strings to stop generation when encountered.
+            repeat_penalty: The penalty to apply to repeated tokens.
+            top_k: The top-k value to use for sampling.
+            stream: Whether to stream the results.
+
+        Raises:
+            ValueError: If the requested tokens exceed the context window.
+            RuntimeError: If the prompt fails to tokenize or the model fails to evaluate the prompt.
+
+        Returns:
+            Response object containing the generated text.
+        """
+        call = self._call(
+            prompt=prompt,
+            suffix=suffix,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            logprobs=logprobs,
+            echo=echo,
+            stop=stop,
+            repeat_penalty=repeat_penalty,
+            top_k=top_k,
+            stream=stream,
+        )
+        if stream:
+            return call
+        return next(call)
+
     def __del__(self):
         llama_cpp.llama_free(self.ctx)
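
With this change `_call` is a generator: the new `__call__` wrapper returns the single completed response via `next(call)` when `stream=False`, and returns the generator itself when `stream=True` so callers can iterate over partial chunks. A rough usage sketch, assuming the enclosing class is exposed as `Llama` and is constructed from a `model_path` (neither appears in this diff):

# Hypothetical usage; the class name and constructor arguments are assumptions,
# since only method bodies appear in the diff above.
llama = Llama(model_path="./models/ggml-model.bin")

# stream=False (default): __call__ returns one completed response dict.
result = llama("Q: Name the planets in the solar system. A:", max_tokens=48, stop=["Q:"])
print(result["choices"][0]["text"])

# stream=True: __call__ returns the generator from _call; iterate for chunks.
for chunk in llama("Q: Name the planets in the solar system. A:", max_tokens=48, stream=True):
    print(chunk["choices"][0]["text"], end="", flush=True)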