@@ -128,6 +128,20 @@ def _sample(self, last_n_tokens, top_p, top_k, temp, repeat_penalty):
128
128
repeat_penalty = repeat_penalty ,
129
129
)
130
130
131
def _generate(self, past_tokens, max_tokens, top_p, top_k, temp, repeat_penalty):
    """Sample and stream up to ``max_tokens`` new tokens, one at a time.

    The prompt tokens seed a fixed-size repeat-penalty window of
    ``self.last_n`` entries (left-padded with zeros via the deque's
    ``maxlen``).  Each sampled token is yielded first and only then fed
    back to the model with ``self._eval``, so callers can stream output
    with minimum latency.

    Args:
        past_tokens: tokens already evaluated by the model (the prompt).
        max_tokens: maximum number of new tokens to produce.
        top_p: nucleus-sampling cutoff, forwarded to ``self._sample``.
        top_k: top-k sampling cutoff, forwarded to ``self._sample``.
        temp: sampling temperature, forwarded to ``self._sample``.
        repeat_penalty: penalty applied to tokens in the window.

    Yields:
        One sampled token id per iteration.
    """
    # NOTE(review): newly sampled tokens are never appended to this
    # window, so the repeat penalty only ever sees the prompt tokens —
    # confirm that is intended.
    window = deque([0] * self.last_n, maxlen=self.last_n)
    window.extend(past_tokens)
    n_prompt = len(past_tokens)
    for step in range(max_tokens):
        tok = self._sample(
            window,
            top_p=top_p,
            top_k=top_k,
            temp=temp,
            repeat_penalty=repeat_penalty,
        )
        yield tok
        # Evaluate after yielding; the model position is the prompt
        # length plus the number of tokens emitted so far.
        self._eval([tok], n_prompt + step)
131
145
def __call__ (
132
146
self ,
133
147
prompt : str ,
@@ -162,8 +176,9 @@ def __call__(
162
176
Returns:
163
177
Response object containing the generated text.
164
178
"""
179
+ completion_id = f"cmpl-{ str (uuid .uuid4 ())} "
180
+ created = int (time .time ())
165
181
text = b""
166
- finish_reason = "length"
167
182
completion_tokens = []
168
183
last_n_tokens = deque ([0 ] * self .last_n , maxlen = self .last_n )
169
184
@@ -182,14 +197,8 @@ def __call__(
182
197
if stop is not None :
183
198
stop = [s .encode ("utf-8" ) for s in stop ]
184
199
185
- for i in range (max_tokens ):
186
- token = self ._sample (
187
- last_n_tokens ,
188
- top_p = top_p ,
189
- top_k = top_k ,
190
- temp = temperature ,
191
- repeat_penalty = repeat_penalty
192
- )
200
+ finish_reason = None
201
+ for token in self ._generate (prompt_tokens , max_tokens , top_p , top_k , temperature , repeat_penalty ):
193
202
if token == llama_cpp .llama_token_eos ():
194
203
finish_reason = "stop"
195
204
break
@@ -204,7 +213,8 @@ def __call__(
204
213
finish_reason = "stop"
205
214
break
206
215
207
- self ._eval ([token ], len (prompt_tokens ) + len (completion_tokens ))
216
+ if finish_reason is None :
217
+ finish_reason = "length"
208
218
209
219
text = text .decode ("utf-8" )
210
220
@@ -220,9 +230,9 @@ def __call__(
220
230
)[:logprobs ]
221
231
222
232
return {
223
- "id" : f"cmpl- { str ( uuid . uuid4 ()) } " , # Likely to change
233
+ "id" : completion_id ,
224
234
"object" : "text_completion" ,
225
- "created" : int ( time . time ()) ,
235
+ "created" : created ,
226
236
"model" : self .model_path ,
227
237
"choices" : [
228
238
{
0 commit comments