Skip to content

Commit 910aedf

Browse files
committed
Add author mode and other related QOL improvements
1 parent 5f93949 commit 910aedf

File tree

3 files changed

+190
-56
lines changed

3 files changed

+190
-56
lines changed

examples/common.cpp

+150-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,17 @@ extern "C" __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int
2020
const wchar_t * lpWideCharStr, int cchWideChar,
2121
char * lpMultiByteStr, int cbMultiByte,
2222
const char * lpDefaultChar, bool * lpUsedDefaultChar);
23+
#define ENABLE_LINE_INPUT 0x0002
24+
#define ENABLE_ECHO_INPUT 0x0004
2325
#define CP_UTF8 65001
26+
#define CONSOLE_CHAR_TYPE wchar_t
27+
#define CONSOLE_GET_CHAR() getwchar()
28+
#define CONSOLE_EOF WEOF
29+
#else
30+
#include <unistd.h>
31+
#define CONSOLE_CHAR_TYPE char
32+
#define CONSOLE_GET_CHAR() getchar()
33+
#define CONSOLE_EOF EOF
2434
#endif
2535

2636
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
@@ -160,6 +170,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
160170
params.interactive = true;
161171
} else if (arg == "--interactive-first") {
162172
params.interactive_start = true;
173+
} else if (arg == "--author-mode") {
174+
params.author_mode = true;
163175
} else if (arg == "-ins" || arg == "--instruct") {
164176
params.instruct = true;
165177
} else if (arg == "--color") {
@@ -222,6 +234,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
222234
fprintf(stderr, " -i, --interactive run in interactive mode\n");
223235
fprintf(stderr, " --interactive-first run in interactive mode and wait for input right away\n");
224236
fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n");
237+
fprintf(stderr, " --author-mode allows you to write or paste multiple lines without ending each in '\\'\n");
225238
fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n");
226239
fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n");
227240
fprintf(stderr, " specified more than once for multiple prompts).\n");
@@ -293,7 +306,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
293306
}
294307

295308
/* Keep track of current color of output, and emit ANSI code if it changes. */
296-
void set_console_color(console_state & con_st, console_color_t color) {
309+
void console_set_color(console_state & con_st, console_color_t color) {
297310
if (con_st.use_color && con_st.color != color) {
298311
switch(color) {
299312
case CONSOLE_COLOR_DEFAULT:
@@ -310,8 +323,9 @@ void set_console_color(console_state & con_st, console_color_t color) {
310323
}
311324
}
312325

326+
void console_init(console_state & con_st) {
313327
#if defined (_WIN32)
314-
void win32_console_init(bool enable_color) {
328+
// Windows-specific console initialization
315329
unsigned long dwMode = 0;
316330
void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11)
317331
if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) {
@@ -322,7 +336,7 @@ void win32_console_init(bool enable_color) {
322336
}
323337
if (hConOut) {
324338
// Enable ANSI colors on Windows 10+
325-
if (enable_color && !(dwMode & 0x4)) {
339+
if (con_st.use_color && !(dwMode & 0x4)) {
326340
SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
327341
}
328342
// Set console output codepage to UTF8
@@ -332,9 +346,46 @@ void win32_console_init(bool enable_color) {
332346
if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) {
333347
// Set console input codepage to UTF16
334348
_setmode(_fileno(stdin), _O_WTEXT);
349+
350+
// Turn off ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT)
351+
dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT);
352+
SetConsoleMode(hConIn, dwMode);
353+
}
354+
#else
355+
// POSIX-specific console initialization
356+
struct termios new_termios;
357+
tcgetattr(STDIN_FILENO, &con_st.prev_state);
358+
new_termios = con_st.prev_state;
359+
new_termios.c_lflag &= ~(ICANON | ECHO);
360+
new_termios.c_cc[VMIN] = 1;
361+
new_termios.c_cc[VTIME] = 0;
362+
tcsetattr(STDIN_FILENO, TCSANOW, &new_termios);
363+
#endif
364+
}
365+
366+
// Puts the console back in order when the program is done with it:
// on POSIX the terminal attributes stored in con_st.prev_state are
// re-applied to stdin, and on every platform the output color is switched
// back to CONSOLE_COLOR_DEFAULT.
void console_cleanup(console_state & con_st) {
#ifndef _WIN32
    // Only non-Windows builds carry a saved termios state to re-apply.
    tcsetattr(STDIN_FILENO, TCSANOW, &con_st.prev_state);
#endif

    // Leave the terminal in its default color.
    console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
}
375+
376+
// Drops the final UTF-8 encoded character (one to four bytes) from `line`.
// An empty string is left untouched.
void remove_last_utf8_char(std::string & line) {
    if (line.empty()) {
        return;
    }

    // Starting at the last byte, step backwards over at most three
    // continuation bytes (bit pattern 10xxxxxx) to land on the lead byte of
    // the final character, never moving before index 0.
    size_t start = line.length() - 1;
    size_t steps = 0;
    while (steps < 3 && start > 0 && (line[start] & 0xC0) == 0x80) {
        --start;
        ++steps;
    }

    // Everything from the lead byte to the end belongs to that character.
    line.erase(start);
}
337387

388+
#if defined (_WIN32)
338389
// Convert a wide Unicode string to an UTF8 string
339390
void win32_utf8_encode(const std::wstring & wstr, std::string & str) {
340391
int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wstr[0], (int)wstr.size(), NULL, 0, NULL, NULL);
@@ -343,3 +394,99 @@ void win32_utf8_encode(const std::wstring & wstr, std::string & str) {
343394
str = strTo;
344395
}
345396
#endif
397+
398+
// Reads one line of user input from the console one character at a time,
// echoing what is typed in the user-input color and handling backspace and
// ANSI escape sequences itself.
//
// Parameters:
//   con_st - console state; `author_mode` supplies the default return value
//            and the color fields are updated through console_set_color().
//   line   - cleared on entry; receives the typed text (UTF-8).
//
// Line termination rules:
//   * ends with '\\': the backslash is replaced by '\n' and the result is the
//                     opposite of `author_mode`.
//   * ends with '/':  the slash is removed, no newline is appended and the
//                     result is false.
//   * EOF / Ctrl+D:   no newline is appended and the result is false.
//   * otherwise:      '\n' is appended and the result is `author_mode`.
//
// Returns true when the caller should keep reading further lines as part of
// the same input, false when the input is complete.
bool console_readline(console_state & con_st, std::string & line) {
    line.clear();
    bool is_special_char = false;
    bool end_of_stream = false;

    console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);

    while (true) {
        fflush(stdout); // Ensure all output is displayed before waiting for input

        // Keep the reader's native return type (int for getchar, wint_t for
        // getwchar) instead of narrowing to CONSOLE_CHAR_TYPE. Narrowing to
        // `char` makes EOF indistinguishable from byte 0xFF (or never equal
        // to it where char is unsigned) and turns every UTF-8 byte >= 0x80
        // into a negative value that the "< 32" control-character filter
        // below would silently drop.
        const auto input_char = CONSOLE_GET_CHAR();

        if (input_char == '\r' || input_char == '\n') {
            break;
        }

        if (input_char == CONSOLE_EOF || input_char == 0x04 /* Ctrl+D*/) {
            end_of_stream = true;
            break;
        }

        if (is_special_char) {
            // More input followed the highlighted trailing '\\' or '/', so it
            // is an ordinary character after all: repaint it in the input color.
            console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
            putchar('\b');
            putchar(line.back());
            is_special_char = false;
        }

        if (input_char == '\033') { // Escape sequence
            auto code = CONSOLE_GET_CHAR();
            if (code == '[') {
                // Discard the rest of the escape sequence: it ends at the
                // first letter or '~'.
                while ((code = CONSOLE_GET_CHAR()) != CONSOLE_EOF) {
                    if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') {
                        break;
                    }
                }
            }
        } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace
            if (!line.empty()) {
                fputs("\b \b", stdout); // Move cursor back, print a space, and move cursor back again
                remove_last_utf8_char(line);
            }
        } else if (input_char < 32) {
            // Ignore control characters
        } else {
#if defined(_WIN32)
            // Input arrives as UTF-16 code units; store and echo it as UTF-8.
            std::string utf8_char;
            win32_utf8_encode(std::wstring(1, static_cast<wchar_t>(input_char)), utf8_char);
            line += utf8_char;
            fputs(utf8_char.c_str(), stdout);
#else
            // Input arrives as raw bytes (possibly part of a multi-byte UTF-8
            // sequence); store and echo each byte unchanged.
            line += static_cast<char>(input_char);
            putchar(input_char);
#endif
        }

        if (!line.empty() && (line.back() == '\\' || line.back() == '/')) {
            // Highlight a trailing '\\' or '/' to show it will act as a
            // line-ending command if Enter is pressed next.
            console_set_color(con_st, CONSOLE_COLOR_PROMPT);
            putchar('\b');
            putchar(line.back());
            is_special_char = true;
        }
    }

    bool has_more = con_st.author_mode;
    if (is_special_char) {
        fputs("\b \b", stdout); // Move cursor back, print a space, and move cursor back again

        // The trailing command character is never part of the returned text.
        char last = line.back();
        line.pop_back();
        if (last == '\\') {
            line += '\n';
            putchar('\n');
            has_more = !has_more;
        } else {
            // llama doesn't seem to process a single space
            if (line.length() == 1 && line.back() == ' ') {
                line.clear();
                putchar('\b');
            }
            has_more = false;
        }
    } else {
        if (end_of_stream) {
            has_more = false;
        } else {
            line += '\n';
            putchar('\n');
        }
    }

    fflush(stdout);
    return has_more;
}

examples/common.h

+13-6
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
#include <random>
1010
#include <thread>
1111

12+
#if !defined (_WIN32)
13+
#include <termios.h>
14+
#endif
15+
1216
//
1317
// CLI argument parsing
1418
//
@@ -46,6 +50,7 @@ struct gpt_params {
4650
bool interactive_start = false; // wait for user input immediately
4751

4852
bool instruct = false; // instruction mode (used for Alpaca models)
53+
bool author_mode = false; // reverse the usage of `\`
4954
bool ignore_eos = false; // do not stop generating after eos
5055
bool perplexity = false; // compute perplexity over the prompt
5156
bool use_mmap = true; // use mmap for faster loads
@@ -86,13 +91,15 @@ enum console_color_t {
8691
};
8792

8893
struct console_state {
94+
bool author_mode = false;
8995
bool use_color = false;
9096
console_color_t color = CONSOLE_COLOR_DEFAULT;
97+
#if !defined (_WIN32)
98+
termios prev_state;
99+
#endif
91100
};
92101

93-
void set_console_color(console_state & con_st, console_color_t color);
94-
95-
#if defined (_WIN32)
96-
void win32_console_init(bool enable_color);
97-
void win32_utf8_encode(const std::wstring & wstr, std::string & str);
98-
#endif
102+
void console_init(console_state & con_st);
103+
void console_cleanup(console_state & con_st);
104+
void console_set_color(console_state & con_st, console_color_t color);
105+
bool console_readline(console_state & con_st, std::string & line);

0 commit comments

Comments
 (0)