@@ -44,8 +44,20 @@ enum console_state {
44
44
static console_state con_st = CONSOLE_STATE_DEFAULT;
45
45
static bool con_use_color = false ;
46
46
47
- void set_console_state (console_state new_st)
48
- {
47
+ void enable_console_colors () {
48
+ #if defined (_WIN32)
49
+ if (params.use_color ) {
50
+ // Enable ANSI colors on Windows 10+
51
+ unsigned long dwMode = 0 ;
52
+ void * hConOut = GetStdHandle ((unsigned long )-11 ); // STD_OUTPUT_HANDLE (-11)
53
+ if (hConOut && hConOut != (void *)-1 && GetConsoleMode (hConOut, &dwMode) && !(dwMode & 0x4 )) {
54
+ SetConsoleMode (hConOut, dwMode | 0x4 ); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
55
+ }
56
+ }
57
+ #endif
58
+ }
59
+
60
+ void set_console_state (console_state new_st) {
49
61
if (!con_use_color) return ;
50
62
// only emit color code if state changed
51
63
if (new_st != con_st) {
@@ -96,6 +108,14 @@ int main(int argc, char ** argv) {
96
108
return 0 ;
97
109
}
98
110
111
+ if (params.embedding ) {
112
+ printf (" \n ************\n " );
113
+ printf (" %s: please use the 'embedding' tool for embedding calculations\n " , __func__);
114
+ printf (" ************\n\n " );
115
+
116
+ return 0 ;
117
+ }
118
+
99
119
if (params.n_ctx > 2048 ) {
100
120
fprintf (stderr, " %s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"
101
121
" expect poor results\n " , __func__, params.n_ctx );
@@ -165,8 +185,6 @@ int main(int argc, char ** argv) {
165
185
return 0 ;
166
186
}
167
187
168
- int n_past = 0 ;
169
-
170
188
// Add a space in front of the first character to match OG llama tokenizer behavior
171
189
params.prompt .insert (0 , 1 , ' ' );
172
190
@@ -175,7 +193,13 @@ int main(int argc, char ** argv) {
175
193
176
194
const int n_ctx = llama_n_ctx (ctx);
177
195
178
- params.n_predict = std::min (params.n_predict , n_ctx - (int ) embd_inp.size ());
196
+ if ((int ) embd_inp.size () > n_ctx - 4 ) {
197
+ fprintf (stderr, " %s: error: prompt is too long (%d tokens, max %d)\n " , __func__, (int ) embd_inp.size (), n_ctx - 4 );
198
+ return 1 ;
199
+ }
200
+
201
+ params.n_keep = std::min (params.n_keep , (int ) embd_inp.size ());
202
+ // params.n_predict = std::min(params.n_predict, n_ctx - (int) embd_inp.size());
179
203
180
204
// prefix & suffix for instruct mode
181
205
const auto inp_pfx = ::llama_tokenize (ctx, " \n\n ### Instruction:\n\n " , true );
@@ -206,6 +230,13 @@ int main(int argc, char ** argv) {
206
230
for (int i = 0 ; i < (int ) embd_inp.size (); i++) {
207
231
fprintf (stderr, " %6d -> '%s'\n " , embd_inp[i], llama_token_to_str (ctx, embd_inp[i]));
208
232
}
233
+ if (params.n_keep > 0 ) {
234
+ fprintf (stderr, " %s: static prompt based on n_keep: '" , __func__);
235
+ for (int i = 0 ; i < params.n_keep ; i++) {
236
+ fprintf (stderr, " %s" , llama_token_to_str (ctx, embd_inp[i]));
237
+ }
238
+ fprintf (stderr, " '\n " );
239
+ }
209
240
fprintf (stderr, " \n " );
210
241
}
211
242
@@ -222,7 +253,7 @@ int main(int argc, char ** argv) {
222
253
223
254
fprintf (stderr, " %s: interactive mode on.\n " , __func__);
224
255
225
- if (params.antiprompt .size ()) {
256
+ if (params.antiprompt .size ()) {
226
257
for (auto antiprompt : params.antiprompt ) {
227
258
fprintf (stderr, " Reverse prompt: '%s'\n " , antiprompt.c_str ());
228
259
}
@@ -232,14 +263,12 @@ int main(int argc, char ** argv) {
232
263
fprintf (stderr, " Input prefix: '%s'\n " , params.input_prefix .c_str ());
233
264
}
234
265
}
235
- fprintf (stderr, " sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n " , params.temp , params.top_k , params.top_p , params.repeat_last_n , params.repeat_penalty );
266
+ fprintf (stderr, " sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n " , params.temp , params.top_k , params.top_p , params.repeat_last_n , params.repeat_penalty );
267
+ fprintf (stderr, " generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n " , n_ctx, params.n_batch , params.n_predict , params.n_keep );
236
268
fprintf (stderr, " \n\n " );
237
269
238
- std::vector<llama_token> embd;
239
-
240
-
241
- int last_n_size = params.repeat_last_n ;
242
- std::vector<llama_token> last_n_tokens (last_n_size);
270
+ // TODO: replace with ring-buffer
271
+ std::vector<llama_token> last_n_tokens (n_ctx);
243
272
std::fill (last_n_tokens.begin (), last_n_tokens.end (), 0 );
244
273
245
274
if (params.interactive ) {
@@ -252,27 +281,42 @@ int main(int argc, char ** argv) {
252
281
is_interacting = params.interactive_start || params.instruct ;
253
282
}
254
283
255
- int input_consumed = 0 ;
256
284
bool input_noecho = false ;
257
285
258
- int remaining_tokens = params.n_predict ;
286
+ int n_past = 0 ;
287
+ int n_remain = params.n_predict ;
288
+ int n_consumed = 0 ;
259
289
260
- #if defined (_WIN32)
261
- if (params.use_color ) {
262
- // Enable ANSI colors on Windows 10+
263
- unsigned long dwMode = 0 ;
264
- void * hConOut = GetStdHandle ((unsigned long )-11 ); // STD_OUTPUT_HANDLE (-11)
265
- if (hConOut && hConOut != (void *)-1 && GetConsoleMode (hConOut, &dwMode) && !(dwMode & 0x4 )) {
266
- SetConsoleMode (hConOut, dwMode | 0x4 ); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
267
- }
268
- }
269
- #endif
270
290
// the first thing we will do is to output the prompt, so set color accordingly
291
+ enable_console_colors ();
271
292
set_console_state (CONSOLE_STATE_PROMPT);
272
293
273
- while (remaining_tokens > 0 || params.interactive ) {
294
+ std::vector<llama_token> embd;
295
+
296
+ while (n_remain > 0 || params.interactive ) {
274
297
// predict
275
298
if (embd.size () > 0 ) {
299
+ // infinite text generation via context swapping
300
+ // if we run out of context:
301
+ // - take the n_keep first tokens from the original prompt (via n_past)
302
+ // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch
303
+ if (n_past + (int ) embd.size () > n_ctx) {
304
+ const int n_left = n_past - params.n_keep ;
305
+
306
+ n_past = params.n_keep ;
307
+
308
+ // insert n_left/2 tokens at the start of embd from last_n_tokens
309
+ embd.insert (embd.begin (), last_n_tokens.begin () + n_ctx - n_left/2 - embd.size (), last_n_tokens.end () - embd.size ());
310
+
311
+ // printf("\n---\n");
312
+ // printf("resetting: '");
313
+ // for (int i = 0; i < (int) embd.size(); i++) {
314
+ // printf("%s", llama_token_to_str(ctx, embd[i]));
315
+ // }
316
+ // printf("'\n");
317
+ // printf("\n---\n");
318
+ }
319
+
276
320
if (llama_eval (ctx, embd.data (), embd.size (), n_past, params.n_threads )) {
277
321
fprintf (stderr, " %s : failed to eval\n " , __func__);
278
322
return 1 ;
@@ -282,7 +326,7 @@ int main(int argc, char ** argv) {
282
326
n_past += embd.size ();
283
327
embd.clear ();
284
328
285
- if ((int ) embd_inp.size () <= input_consumed && !is_interacting) {
329
+ if ((int ) embd_inp.size () <= n_consumed && !is_interacting) {
286
330
// out of user input, sample next token
287
331
const float top_k = params.top_k ;
288
332
const float top_p = params.top_p ;
@@ -298,7 +342,9 @@ int main(int argc, char ** argv) {
298
342
logits[llama_token_eos ()] = 0 ;
299
343
}
300
344
301
- id = llama_sample_top_p_top_k (ctx, last_n_tokens.data (), last_n_tokens.size (), top_k, top_p, temp, repeat_penalty);
345
+ id = llama_sample_top_p_top_k (ctx,
346
+ last_n_tokens.data () + n_ctx - params.repeat_last_n ,
347
+ params.repeat_last_n , top_k, top_p, temp, repeat_penalty);
302
348
303
349
last_n_tokens.erase (last_n_tokens.begin ());
304
350
last_n_tokens.push_back (id);
@@ -321,14 +367,14 @@ int main(int argc, char ** argv) {
321
367
input_noecho = false ;
322
368
323
369
// decrement remaining sampling budget
324
- --remaining_tokens ;
370
+ --n_remain ;
325
371
} else {
326
372
// some user input remains from prompt or interaction, forward it to processing
327
- while ((int ) embd_inp.size () > input_consumed ) {
328
- embd.push_back (embd_inp[input_consumed ]);
373
+ while ((int ) embd_inp.size () > n_consumed ) {
374
+ embd.push_back (embd_inp[n_consumed ]);
329
375
last_n_tokens.erase (last_n_tokens.begin ());
330
- last_n_tokens.push_back (embd_inp[input_consumed ]);
331
- ++input_consumed ;
376
+ last_n_tokens.push_back (embd_inp[n_consumed ]);
377
+ ++n_consumed ;
332
378
if ((int ) embd.size () >= params.n_batch ) {
333
379
break ;
334
380
}
@@ -343,13 +389,13 @@ int main(int argc, char ** argv) {
343
389
fflush (stdout);
344
390
}
345
391
// reset color to default if we there is no pending user input
346
- if (!input_noecho && (int )embd_inp.size () == input_consumed ) {
392
+ if (!input_noecho && (int )embd_inp.size () == n_consumed ) {
347
393
set_console_state (CONSOLE_STATE_DEFAULT);
348
394
}
349
395
350
396
// in interactive mode, and not currently processing queued inputs;
351
397
// check if we should prompt the user for more
352
- if (params.interactive && (int ) embd_inp.size () <= input_consumed ) {
398
+ if (params.interactive && (int ) embd_inp.size () <= n_consumed ) {
353
399
// check for reverse prompt
354
400
std::string last_output;
355
401
for (auto id : last_n_tokens) {
@@ -371,7 +417,7 @@ int main(int argc, char ** argv) {
371
417
set_console_state (CONSOLE_STATE_USER_INPUT);
372
418
373
419
if (params.instruct ) {
374
- input_consumed = embd_inp.size ();
420
+ n_consumed = embd_inp.size ();
375
421
embd_inp.insert (embd_inp.end (), inp_pfx.begin (), inp_pfx.end ());
376
422
377
423
printf (" \n > " );
@@ -405,7 +451,7 @@ int main(int argc, char ** argv) {
405
451
embd_inp.insert (embd_inp.end (), inp_sfx.begin (), inp_sfx.end ());
406
452
}
407
453
408
- remaining_tokens -= line_inp.size ();
454
+ n_remain -= line_inp.size ();
409
455
410
456
input_noecho = true ; // do not echo this again
411
457
}
@@ -426,8 +472,8 @@ int main(int argc, char ** argv) {
426
472
}
427
473
428
474
// In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
429
- if (params.interactive && remaining_tokens <= 0 ) {
430
- remaining_tokens = params.n_predict ;
475
+ if (params.interactive && n_remain <= 0 ) {
476
+ n_remain = params.n_predict ;
431
477
is_interacting = true ;
432
478
}
433
479
}
0 commit comments