@@ -44,21 +44,21 @@ enum console_state {
44
44
static console_state con_st = CONSOLE_STATE_DEFAULT;
45
45
static bool con_use_color = false ;
46
46
47
- void set_console_state (console_state new_st)
47
+ void set_console_state (FILE *stream, console_state new_st)
48
48
{
49
49
if (!con_use_color) return ;
50
50
// only emit color code if state changed
51
51
if (new_st != con_st) {
52
52
con_st = new_st;
53
53
switch (con_st) {
54
54
case CONSOLE_STATE_DEFAULT:
55
- printf ( ANSI_COLOR_RESET);
55
+ fprintf (stream, ANSI_COLOR_RESET);
56
56
return ;
57
57
case CONSOLE_STATE_PROMPT:
58
- printf ( ANSI_COLOR_YELLOW);
58
+ fprintf (stream, ANSI_COLOR_YELLOW);
59
59
return ;
60
60
case CONSOLE_STATE_USER_INPUT:
61
- printf ( ANSI_BOLD ANSI_COLOR_GREEN);
61
+ fprintf (stream, ANSI_BOLD ANSI_COLOR_GREEN);
62
62
return ;
63
63
}
64
64
}
@@ -68,7 +68,7 @@ static bool is_interacting = false;
68
68
69
69
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
70
70
void sigint_handler (int signo) {
71
- set_console_state (CONSOLE_STATE_DEFAULT);
71
+ set_console_state (stdout, CONSOLE_STATE_DEFAULT);
72
72
printf (" \n " ); // this also force flush stdout.
73
73
if (signo == SIGINT) {
74
74
if (!is_interacting) {
@@ -80,13 +80,17 @@ void sigint_handler(int signo) {
80
80
}
81
81
#endif
82
82
83
- int run (llama_context * ctx, gpt_params params) {
83
+ int run (llama_context * ctx,
84
+ gpt_params params,
85
+ std::istream & instream,
86
+ FILE *outstream,
87
+ FILE *errstream) {
84
88
85
89
if (params.seed <= 0 ) {
86
90
params.seed = time (NULL );
87
91
}
88
92
89
- fprintf (stderr , " %s: seed = %d\n " , __func__, params.seed );
93
+ fprintf (errstream , " %s: seed = %d\n " , __func__, params.seed );
90
94
91
95
std::mt19937 rng (params.seed );
92
96
if (params.random_prompt ) {
@@ -134,13 +138,13 @@ int run(llama_context * ctx, gpt_params params) {
134
138
params.interactive = true ;
135
139
}
136
140
137
- fprintf (stderr , " \n " );
138
- fprintf (stderr , " %s: prompt: '%s'\n " , __func__, params.prompt .c_str ());
139
- fprintf (stderr , " %s: number of tokens in prompt = %zu\n " , __func__, embd_inp.size ());
141
+ fprintf (errstream , " \n " );
142
+ fprintf (errstream , " %s: prompt: '%s'\n " , __func__, params.prompt .c_str ());
143
+ fprintf (errstream , " %s: number of tokens in prompt = %zu\n " , __func__, embd_inp.size ());
140
144
for (int i = 0 ; i < (int ) embd_inp.size (); i++) {
141
- fprintf (stderr , " %6d -> '%s'\n " , embd_inp[i], llama_token_to_str (ctx, embd_inp[i]));
145
+ fprintf (errstream , " %6d -> '%s'\n " , embd_inp[i], llama_token_to_str (ctx, embd_inp[i]));
142
146
}
143
- fprintf (stderr , " \n " );
147
+ fprintf (errstream , " \n " );
144
148
if (params.interactive ) {
145
149
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
146
150
struct sigaction sigint_action;
@@ -152,16 +156,16 @@ int run(llama_context * ctx, gpt_params params) {
152
156
signal (SIGINT, sigint_handler);
153
157
#endif
154
158
155
- fprintf (stderr , " %s: interactive mode on.\n " , __func__);
159
+ fprintf (errstream , " %s: interactive mode on.\n " , __func__);
156
160
157
161
if (params.antiprompt .size ()) {
158
162
for (auto antiprompt : params.antiprompt ) {
159
- fprintf (stderr , " Reverse prompt: '%s'\n " , antiprompt.c_str ());
163
+ fprintf (errstream , " Reverse prompt: '%s'\n " , antiprompt.c_str ());
160
164
}
161
165
}
162
166
}
163
- fprintf (stderr , " sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n " , params.temp , params.top_k , params.top_p , params.repeat_last_n , params.repeat_penalty );
164
- fprintf (stderr , " \n\n " );
167
+ fprintf (errstream , " sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n " , params.temp , params.top_k , params.top_p , params.repeat_last_n , params.repeat_penalty );
168
+ fprintf (errstream , " \n\n " );
165
169
166
170
std::vector<llama_token> embd;
167
171
@@ -170,7 +174,7 @@ int run(llama_context * ctx, gpt_params params) {
170
174
std::fill (last_n_tokens.begin (), last_n_tokens.end (), 0 );
171
175
172
176
if (params.interactive ) {
173
- fprintf (stderr , " == Running in interactive mode. ==\n "
177
+ fprintf (errstream , " == Running in interactive mode. ==\n "
174
178
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
175
179
" - Press Ctrl+C to interject at any time.\n "
176
180
#endif
@@ -195,13 +199,13 @@ int run(llama_context * ctx, gpt_params params) {
195
199
}
196
200
#endif
197
201
// the first thing we will do is to output the prompt, so set color accordingly
198
- set_console_state (CONSOLE_STATE_PROMPT);
202
+ set_console_state (outstream, CONSOLE_STATE_PROMPT);
199
203
200
204
while (remaining_tokens > 0 || params.interactive ) {
201
205
// predict
202
206
if (embd.size () > 0 ) {
203
207
if (llama_eval (ctx, embd.data (), embd.size (), n_past, params.n_threads )) {
204
- fprintf (stderr , " %s : failed to eval\n " , __func__);
208
+ fprintf (errstream , " %s : failed to eval\n " , __func__);
205
209
return 1 ;
206
210
}
207
211
}
@@ -259,13 +263,13 @@ int run(llama_context * ctx, gpt_params params) {
259
263
// display text
260
264
if (!input_noecho) {
261
265
for (auto id : embd) {
262
- printf ( " %s" , llama_token_to_str (ctx, id));
266
+ fprintf (outstream, " %s" , llama_token_to_str (ctx, id));
263
267
}
264
- fflush (stdout );
268
+ fflush (outstream );
265
269
}
266
270
// reset color to default if we there is no pending user input
267
271
if (!input_noecho && (int )embd_inp.size () == input_consumed) {
268
- set_console_state (CONSOLE_STATE_DEFAULT);
272
+ set_console_state (outstream, CONSOLE_STATE_DEFAULT);
269
273
}
270
274
271
275
// in interactive mode, and not currently processing queued inputs;
@@ -286,20 +290,20 @@ int run(llama_context * ctx, gpt_params params) {
286
290
}
287
291
if (is_interacting) {
288
292
// potentially set color to indicate we are taking user input
289
- set_console_state (CONSOLE_STATE_USER_INPUT);
293
+ set_console_state (outstream, CONSOLE_STATE_USER_INPUT);
290
294
291
295
if (params.instruct ) {
292
296
input_consumed = embd_inp.size ();
293
297
embd_inp.insert (embd_inp.end (), inp_pfx.begin (), inp_pfx.end ());
294
298
295
- printf ( " \n > " );
299
+ fprintf (outstream, " \n > " );
296
300
}
297
301
298
302
std::string buffer;
299
303
std::string line;
300
304
bool another_line = true ;
301
305
do {
302
- std::getline (std::cin , line);
306
+ std::getline (instream , line);
303
307
if (line.empty () || line.back () != ' \\ ' ) {
304
308
another_line = false ;
305
309
} else {
@@ -309,7 +313,7 @@ int run(llama_context * ctx, gpt_params params) {
309
313
} while (another_line);
310
314
311
315
// done taking input, reset color
312
- set_console_state (CONSOLE_STATE_DEFAULT);
316
+ set_console_state (outstream, CONSOLE_STATE_DEFAULT);
313
317
314
318
auto line_inp = ::llama_tokenize (ctx, buffer, false );
315
319
embd_inp.insert (embd_inp.end (), line_inp.begin (), line_inp.end ());
@@ -330,7 +334,7 @@ int run(llama_context * ctx, gpt_params params) {
330
334
if (params.interactive ) {
331
335
is_interacting = true ;
332
336
} else {
333
- fprintf (stderr , " [end of text]\n " );
337
+ fprintf (errstream , " [end of text]\n " );
334
338
break ;
335
339
}
336
340
}
@@ -350,7 +354,7 @@ int run(llama_context * ctx, gpt_params params) {
350
354
351
355
llama_free (ctx);
352
356
353
- set_console_state (CONSOLE_STATE_DEFAULT);
357
+ set_console_state (outstream, CONSOLE_STATE_DEFAULT);
354
358
355
359
return 0 ;
356
360
}
0 commit comments