@@ -20,7 +20,17 @@ extern "C" __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int
20
20
const wchar_t * lpWideCharStr, int cchWideChar,
21
21
char * lpMultiByteStr, int cbMultiByte,
22
22
const char * lpDefaultChar, bool * lpUsedDefaultChar);
23
+ #define ENABLE_LINE_INPUT 0x0002
24
+ #define ENABLE_ECHO_INPUT 0x0004
23
25
#define CP_UTF8 65001
26
+ #define CONSOLE_CHAR_TYPE wchar_t
27
+ #define CONSOLE_GET_CHAR () getwchar()
28
+ #define CONSOLE_EOF WEOF
29
+ #else
30
+ #include < unistd.h>
31
+ #define CONSOLE_CHAR_TYPE char
32
+ #define CONSOLE_GET_CHAR () getchar()
33
+ #define CONSOLE_EOF EOF
24
34
#endif
25
35
26
36
bool gpt_params_parse (int argc, char ** argv, gpt_params & params) {
@@ -160,6 +170,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
160
170
params.interactive = true ;
161
171
} else if (arg == " --interactive-first" ) {
162
172
params.interactive_start = true ;
173
+ } else if (arg == " --author-mode" ) {
174
+ params.author_mode = true ;
163
175
} else if (arg == " -ins" || arg == " --instruct" ) {
164
176
params.instruct = true ;
165
177
} else if (arg == " --color" ) {
@@ -222,6 +234,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
222
234
fprintf (stderr, " -i, --interactive run in interactive mode\n " );
223
235
fprintf (stderr, " --interactive-first run in interactive mode and wait for input right away\n " );
224
236
fprintf (stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n " );
237
+ fprintf (stderr, " --author-mode allows you to write or paste multiple lines without ending each in '\\ '\n " );
225
238
fprintf (stderr, " -r PROMPT, --reverse-prompt PROMPT\n " );
226
239
fprintf (stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n " );
227
240
fprintf (stderr, " specified more than once for multiple prompts).\n " );
@@ -293,7 +306,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
293
306
}
294
307
295
308
/* Keep track of current color of output, and emit ANSI code if it changes. */
296
- void set_console_color (console_state & con_st, console_color_t color) {
309
+ void console_set_color (console_state & con_st, console_color_t color) {
297
310
if (con_st.use_color && con_st.color != color) {
298
311
switch (color) {
299
312
case CONSOLE_COLOR_DEFAULT:
@@ -310,8 +323,9 @@ void set_console_color(console_state & con_st, console_color_t color) {
310
323
}
311
324
}
312
325
326
+ void console_init (console_state & con_st) {
313
327
#if defined (_WIN32)
314
- void win32_console_init ( bool enable_color) {
328
+ // Windows-specific console initialization
315
329
unsigned long dwMode = 0 ;
316
330
void * hConOut = GetStdHandle ((unsigned long )-11 ); // STD_OUTPUT_HANDLE (-11)
317
331
if (!hConOut || hConOut == (void *)-1 || !GetConsoleMode (hConOut, &dwMode)) {
@@ -322,7 +336,7 @@ void win32_console_init(bool enable_color) {
322
336
}
323
337
if (hConOut) {
324
338
// Enable ANSI colors on Windows 10+
325
- if (enable_color && !(dwMode & 0x4 )) {
339
+ if (con_st. use_color && !(dwMode & 0x4 )) {
326
340
SetConsoleMode (hConOut, dwMode | 0x4 ); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
327
341
}
328
342
// Set console output codepage to UTF8
@@ -332,9 +346,46 @@ void win32_console_init(bool enable_color) {
332
346
if (hConIn && hConIn != (void *)-1 && GetConsoleMode (hConIn, &dwMode)) {
333
347
// Set console input codepage to UTF16
334
348
_setmode (_fileno (stdin), _O_WTEXT);
349
+
350
+ // Turn off ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT)
351
+ dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT);
352
+ SetConsoleMode (hConIn, dwMode);
353
+ }
354
+ #else
355
+ // POSIX-specific console initialization
356
+ struct termios new_termios;
357
+ tcgetattr (STDIN_FILENO, &con_st.prev_state );
358
+ new_termios = con_st.prev_state ;
359
+ new_termios.c_lflag &= ~(ICANON | ECHO);
360
+ new_termios.c_cc [VMIN] = 1 ;
361
+ new_termios.c_cc [VTIME] = 0 ;
362
+ tcsetattr (STDIN_FILENO, TCSANOW, &new_termios);
363
+ #endif
364
+ }
365
+
366
+ void console_cleanup (console_state & con_st) {
367
+ #if !defined(_WIN32)
368
+ // Restore the terminal settings on POSIX systems
369
+ tcsetattr (STDIN_FILENO, TCSANOW, &con_st.prev_state );
370
+ #endif
371
+
372
+ // Reset console color
373
+ console_set_color (con_st, CONSOLE_COLOR_DEFAULT);
374
+ }
375
+
376
+ // Helper function to remove the last UTF-8 character from a string
377
+ void remove_last_utf8_char (std::string & line) {
378
+ if (line.empty ()) return ;
379
+ size_t pos = line.length () - 1 ;
380
+
381
+ // Find the start of the last UTF-8 character (checking up to 4 bytes back)
382
+ for (size_t i = 0 ; i < 3 && pos > 0 ; ++i, --pos) {
383
+ if ((line[pos] & 0xC0 ) != 0x80 ) break ; // Found the start of the character
335
384
}
385
+ line.erase (pos);
336
386
}
337
387
388
+ #if defined (_WIN32)
338
389
// Convert a wide Unicode string to an UTF8 string
339
390
void win32_utf8_encode (const std::wstring & wstr, std::string & str) {
340
391
int size_needed = WideCharToMultiByte (CP_UTF8, 0 , &wstr[0 ], (int )wstr.size (), NULL , 0 , NULL , NULL );
@@ -343,3 +394,99 @@ void win32_utf8_encode(const std::wstring & wstr, std::string & str) {
343
394
str = strTo;
344
395
}
345
396
#endif
397
+
398
+ bool console_readline (console_state & con_st, std::string & line) {
399
+ line.clear ();
400
+ bool is_special_char = false ;
401
+ bool end_of_stream = false ;
402
+
403
+ console_set_color (con_st, CONSOLE_COLOR_USER_INPUT);
404
+
405
+ CONSOLE_CHAR_TYPE input_char;
406
+ while (true ) {
407
+ fflush (stdout); // Ensure all output is displayed before waiting for input
408
+ input_char = CONSOLE_GET_CHAR ();
409
+
410
+ if (input_char == ' \r ' || input_char == ' \n ' ) {
411
+ break ;
412
+ }
413
+
414
+ if (input_char == CONSOLE_EOF || input_char == 0x04 /* Ctrl+D*/ ) {
415
+ end_of_stream = true ;
416
+ break ;
417
+ }
418
+
419
+ if (is_special_char) {
420
+ console_set_color (con_st, CONSOLE_COLOR_USER_INPUT);
421
+ putchar (' \b ' );
422
+ putchar (line.back ());
423
+ is_special_char = false ;
424
+ }
425
+
426
+ if (input_char == ' \033 ' ) { // Escape sequence
427
+ CONSOLE_CHAR_TYPE code = CONSOLE_GET_CHAR ();
428
+ if (code == ' [' ) {
429
+ // Discard the rest of the escape sequence
430
+ while ((code = CONSOLE_GET_CHAR ()) != CONSOLE_EOF) {
431
+ if ((code >= ' A' && code <= ' Z' ) || (code >= ' a' && code <= ' z' ) || code == ' ~' ) {
432
+ break ;
433
+ }
434
+ }
435
+ }
436
+ } else if (input_char == 0x08 || input_char == 0x7F ) { // Backspace
437
+ if (!line.empty ()) {
438
+ fputs (" \b \b " , stdout); // Move cursor back, print a space, and move cursor back again
439
+ remove_last_utf8_char (line);
440
+ }
441
+ } else if (input_char < 32 ) {
442
+ // Ignore control characters
443
+ } else {
444
+ #if defined(_WIN32)
445
+ std::string utf8_char;
446
+ win32_utf8_encode (std::wstring (1 , input_char), utf8_char);
447
+ line += utf8_char;
448
+ fputs (utf8_char.c_str (), stdout);
449
+ #else
450
+ line += input_char;
451
+ putchar (input_char);
452
+ #endif
453
+ }
454
+
455
+ if (!line.empty () && (line.back () == ' \\ ' || line.back () == ' /' )) {
456
+ console_set_color (con_st, CONSOLE_COLOR_PROMPT);
457
+ putchar (' \b ' );
458
+ putchar (line.back ());
459
+ is_special_char = true ;
460
+ }
461
+ }
462
+
463
+ bool has_more = con_st.author_mode ;
464
+ if (is_special_char) {
465
+ fputs (" \b \b " , stdout); // Move cursor back, print a space, and move cursor back again
466
+
467
+ char last = line.back ();
468
+ line.pop_back ();
469
+ if (last == ' \\ ' ) {
470
+ line += ' \n ' ;
471
+ putchar (' \n ' );
472
+ has_more = !has_more;
473
+ } else {
474
+ // llama doesn't seem to process a single space
475
+ if (line.length () == 1 && line.back () == ' ' ) {
476
+ line.clear ();
477
+ putchar (' \b ' );
478
+ }
479
+ has_more = false ;
480
+ }
481
+ } else {
482
+ if (end_of_stream) {
483
+ has_more = false ;
484
+ } else {
485
+ line += ' \n ' ;
486
+ putchar (' \n ' );
487
+ }
488
+ }
489
+
490
+ fflush (stdout);
491
+ return has_more;
492
+ }
0 commit comments