Skip to content

Commit 19fa30a

Browse files
committed
Remove direct access to std streams from "run"
The goal is to allow running "run" while connected to other streams, such as TCP sockets. Signed-off-by: Thiago Padilha <[email protected]>
1 parent f8f9b2e commit 19fa30a

File tree

3 files changed

+40
-30
lines changed

3 files changed

+40
-30
lines changed

main.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include "run.h"
22
#include "ggml.h"
33

4+
#include <iostream>
5+
46

57
std::vector<double> softmax(const std::vector<float>& logits) {
68
std::vector<double> probs(logits.size());
@@ -123,5 +125,5 @@ int main(int argc, char ** argv) {
123125
exit(0);
124126
}
125127

126-
return run(ctx, params);
128+
return run(ctx, params, std::cin, stdout, stderr);
127129
}

run.cpp

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -44,21 +44,21 @@ enum console_state {
4444
static console_state con_st = CONSOLE_STATE_DEFAULT;
4545
static bool con_use_color = false;
4646

47-
void set_console_state(console_state new_st)
47+
void set_console_state(FILE *stream, console_state new_st)
4848
{
4949
if (!con_use_color) return;
5050
// only emit color code if state changed
5151
if (new_st != con_st) {
5252
con_st = new_st;
5353
switch(con_st) {
5454
case CONSOLE_STATE_DEFAULT:
55-
printf(ANSI_COLOR_RESET);
55+
fprintf(stream, ANSI_COLOR_RESET);
5656
return;
5757
case CONSOLE_STATE_PROMPT:
58-
printf(ANSI_COLOR_YELLOW);
58+
fprintf(stream, ANSI_COLOR_YELLOW);
5959
return;
6060
case CONSOLE_STATE_USER_INPUT:
61-
printf(ANSI_BOLD ANSI_COLOR_GREEN);
61+
fprintf(stream, ANSI_BOLD ANSI_COLOR_GREEN);
6262
return;
6363
}
6464
}
@@ -68,7 +68,7 @@ static bool is_interacting = false;
6868

6969
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
7070
void sigint_handler(int signo) {
71-
set_console_state(CONSOLE_STATE_DEFAULT);
71+
set_console_state(stdout, CONSOLE_STATE_DEFAULT);
7272
printf("\n"); // this also force flush stdout.
7373
if (signo == SIGINT) {
7474
if (!is_interacting) {
@@ -80,13 +80,17 @@ void sigint_handler(int signo) {
8080
}
8181
#endif
8282

83-
int run(llama_context * ctx, gpt_params params) {
83+
int run(llama_context * ctx,
84+
gpt_params params,
85+
std::istream & instream,
86+
FILE *outstream,
87+
FILE *errstream) {
8488

8589
if (params.seed <= 0) {
8690
params.seed = time(NULL);
8791
}
8892

89-
fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
93+
fprintf(errstream, "%s: seed = %d\n", __func__, params.seed);
9094

9195
std::mt19937 rng(params.seed);
9296
if (params.random_prompt) {
@@ -134,13 +138,13 @@ int run(llama_context * ctx, gpt_params params) {
134138
params.interactive = true;
135139
}
136140

137-
fprintf(stderr, "\n");
138-
fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
139-
fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
141+
fprintf(errstream, "\n");
142+
fprintf(errstream, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
143+
fprintf(errstream, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
140144
for (int i = 0; i < (int) embd_inp.size(); i++) {
141-
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
145+
fprintf(errstream, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
142146
}
143-
fprintf(stderr, "\n");
147+
fprintf(errstream, "\n");
144148
if (params.interactive) {
145149
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
146150
struct sigaction sigint_action;
@@ -152,16 +156,16 @@ int run(llama_context * ctx, gpt_params params) {
152156
signal(SIGINT, sigint_handler);
153157
#endif
154158

155-
fprintf(stderr, "%s: interactive mode on.\n", __func__);
159+
fprintf(errstream, "%s: interactive mode on.\n", __func__);
156160

157161
if(params.antiprompt.size()) {
158162
for (auto antiprompt : params.antiprompt) {
159-
fprintf(stderr, "Reverse prompt: '%s'\n", antiprompt.c_str());
163+
fprintf(errstream, "Reverse prompt: '%s'\n", antiprompt.c_str());
160164
}
161165
}
162166
}
163-
fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
164-
fprintf(stderr, "\n\n");
167+
fprintf(errstream, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
168+
fprintf(errstream, "\n\n");
165169

166170
std::vector<llama_token> embd;
167171

@@ -170,7 +174,7 @@ int run(llama_context * ctx, gpt_params params) {
170174
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
171175

172176
if (params.interactive) {
173-
fprintf(stderr, "== Running in interactive mode. ==\n"
177+
fprintf(errstream, "== Running in interactive mode. ==\n"
174178
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
175179
" - Press Ctrl+C to interject at any time.\n"
176180
#endif
@@ -195,13 +199,13 @@ int run(llama_context * ctx, gpt_params params) {
195199
}
196200
#endif
197201
// the first thing we will do is to output the prompt, so set color accordingly
198-
set_console_state(CONSOLE_STATE_PROMPT);
202+
set_console_state(outstream, CONSOLE_STATE_PROMPT);
199203

200204
while (remaining_tokens > 0 || params.interactive) {
201205
// predict
202206
if (embd.size() > 0) {
203207
if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
204-
fprintf(stderr, "%s : failed to eval\n", __func__);
208+
fprintf(errstream, "%s : failed to eval\n", __func__);
205209
return 1;
206210
}
207211
}
@@ -259,13 +263,13 @@ int run(llama_context * ctx, gpt_params params) {
259263
// display text
260264
if (!input_noecho) {
261265
for (auto id : embd) {
262-
printf("%s", llama_token_to_str(ctx, id));
266+
fprintf(outstream, "%s", llama_token_to_str(ctx, id));
263267
}
264-
fflush(stdout);
268+
fflush(outstream);
265269
}
266270
// reset color to default if we there is no pending user input
267271
if (!input_noecho && (int)embd_inp.size() == input_consumed) {
268-
set_console_state(CONSOLE_STATE_DEFAULT);
272+
set_console_state(outstream, CONSOLE_STATE_DEFAULT);
269273
}
270274

271275
// in interactive mode, and not currently processing queued inputs;
@@ -286,20 +290,20 @@ int run(llama_context * ctx, gpt_params params) {
286290
}
287291
if (is_interacting) {
288292
// potentially set color to indicate we are taking user input
289-
set_console_state(CONSOLE_STATE_USER_INPUT);
293+
set_console_state(outstream, CONSOLE_STATE_USER_INPUT);
290294

291295
if (params.instruct) {
292296
input_consumed = embd_inp.size();
293297
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
294298

295-
printf("\n> ");
299+
fprintf(outstream, "\n> ");
296300
}
297301

298302
std::string buffer;
299303
std::string line;
300304
bool another_line = true;
301305
do {
302-
std::getline(std::cin, line);
306+
std::getline(instream, line);
303307
if (line.empty() || line.back() != '\\') {
304308
another_line = false;
305309
} else {
@@ -309,7 +313,7 @@ int run(llama_context * ctx, gpt_params params) {
309313
} while (another_line);
310314

311315
// done taking input, reset color
312-
set_console_state(CONSOLE_STATE_DEFAULT);
316+
set_console_state(outstream, CONSOLE_STATE_DEFAULT);
313317

314318
auto line_inp = ::llama_tokenize(ctx, buffer, false);
315319
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
@@ -330,7 +334,7 @@ int run(llama_context * ctx, gpt_params params) {
330334
if (params.interactive) {
331335
is_interacting = true;
332336
} else {
333-
fprintf(stderr, " [end of text]\n");
337+
fprintf(errstream, " [end of text]\n");
334338
break;
335339
}
336340
}
@@ -350,7 +354,7 @@ int run(llama_context * ctx, gpt_params params) {
350354

351355
llama_free(ctx);
352356

353-
set_console_state(CONSOLE_STATE_DEFAULT);
357+
set_console_state(outstream, CONSOLE_STATE_DEFAULT);
354358

355359
return 0;
356360
}

run.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,8 @@
33
#include "llama.h"
44
#include "utils.h"
55

6-
int run(llama_context * ctx, gpt_params params);
6+
int run(llama_context * ctx,
7+
gpt_params params,
8+
std::istream & instream,
9+
FILE *outstream,
10+
FILE *errstream);

0 commit comments

Comments
 (0)