
Commit 9ed33b3

Remove direct access to std streams from llama_main

The goal is to allow running llama_main while connected to other streams, such as TCP sockets.

Signed-off-by: Thiago Padilha <[email protected]>

1 parent: 8b9a9dc

File tree

3 files changed: +38 -30 lines

  llama.cpp
  llama.h
  main.cpp

llama.cpp

31 additions & 28 deletions

@@ -736,13 +736,16 @@ int llama_main(
         gpt_vocab vocab,
         llama_model model,
         int64_t t_load_us,
-        int64_t t_main_start_us) {
+        int64_t t_main_start_us,
+        std::istream & instream,
+        FILE *outstream,
+        FILE *errstream) {
 
     if (params.seed < 0) {
         params.seed = time(NULL);
     }
 
-    fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
+    fprintf(errstream, "%s: seed = %d\n", __func__, params.seed);
 
     std::mt19937 rng(params.seed);
     if (params.random_prompt) {
@@ -788,13 +791,13 @@ int llama_main(
         params.interactive = true;
     }
 
-    fprintf(stderr, "\n");
-    fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
-    fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
+    fprintf(errstream, "\n");
+    fprintf(errstream, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+    fprintf(errstream, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
     for (int i = 0; i < (int) embd_inp.size(); i++) {
-        fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
+        fprintf(errstream, "%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
     }
-    fprintf(stderr, "\n");
+    fprintf(errstream, "\n");
     if (params.interactive) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
         struct sigaction sigint_action;
@@ -806,22 +809,22 @@ int llama_main(
         signal(SIGINT, sigint_handler);
 #endif
 
-        fprintf(stderr, "%s: interactive mode on.\n", __func__);
+        fprintf(errstream, "%s: interactive mode on.\n", __func__);
 
         if(antipromptv_inp.size()) {
             for (size_t apindex = 0; apindex < antipromptv_inp.size(); ++apindex) {
                 auto antiprompt_inp = antipromptv_inp.at(apindex);
-                fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.at(apindex).c_str());
-                fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
+                fprintf(errstream, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.at(apindex).c_str());
+                fprintf(errstream, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
                 for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
-                    fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
+                    fprintf(errstream, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
                 }
-                fprintf(stderr, "\n");
+                fprintf(errstream, "\n");
             }
         }
     }
-    fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
-    fprintf(stderr, "\n\n");
+    fprintf(errstream, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
+    fprintf(errstream, "\n\n");
 
     std::vector<gpt_vocab::id> embd;
 
@@ -834,7 +837,7 @@ int llama_main(
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
 
     if (params.interactive) {
-        fprintf(stderr, "== Running in interactive mode. ==\n"
+        fprintf(errstream, "== Running in interactive mode. ==\n"
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
                " - Press Ctrl+C to interject at any time.\n"
 #endif
@@ -850,7 +853,7 @@ int llama_main(
 
     // set the color for the prompt which will be output initially
     if (params.use_color) {
-        printf(ANSI_COLOR_YELLOW);
+        fprintf(outstream, ANSI_COLOR_YELLOW);
     }
 
     while (remaining_tokens > 0 || params.interactive) {
@@ -859,7 +862,7 @@ int llama_main(
             const int64_t t_start_us = ggml_time_us();
 
             if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
-                fprintf(stderr, "Failed to predict\n");
+                fprintf(errstream, "Failed to predict\n");
                 return 1;
             }
 
@@ -920,9 +923,9 @@ int llama_main(
         // display text
         if (!input_noecho) {
             for (auto id : embd) {
-                printf("%s", vocab.id_to_token[id].c_str());
+                fprintf(outstream, "%s", vocab.id_to_token[id].c_str());
             }
-            fflush(stdout);
+            fflush(outstream);
         }
         // reset color to default if we there is no pending user input
         if (!input_noecho && params.use_color && (int)embd_inp.size() == input_consumed) {
@@ -954,7 +957,7 @@ int llama_main(
                 std::string line;
                 bool another_line = true;
                 do {
-                    std::getline(std::cin, line);
+                    std::getline(instream, line);
                     if (line.empty() || line.back() != '\\') {
                         another_line = false;
                     } else {
@@ -983,7 +986,7 @@ int llama_main(
             if (params.interactive) {
                 is_interacting = true;
             } else {
-                fprintf(stderr, " [end of text]\n");
+                fprintf(errstream, " [end of text]\n");
                 break;
             }
         }
@@ -1003,18 +1006,18 @@ int llama_main(
     {
         const int64_t t_main_end_us = ggml_time_us();
 
-        fprintf(stderr, "\n\n");
-        fprintf(stderr, "%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
-        fprintf(stderr, "%s:     load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
-        fprintf(stderr, "%s:   sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
-        fprintf(stderr, "%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
-        fprintf(stderr, "%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+        fprintf(errstream, "\n\n");
+        fprintf(errstream, "%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
+        fprintf(errstream, "%s:     load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
+        fprintf(errstream, "%s:   sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
+        fprintf(errstream, "%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
+        fprintf(errstream, "%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
     }
 
     ggml_free(model.ctx);
 
     if (params.use_color) {
-        printf(ANSI_COLOR_RESET);
+        fprintf(outstream, ANSI_COLOR_RESET);
     }
 
     return 0;
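The commit message names TCP sockets as the motivating use case. Below is a minimal, entirely hypothetical sketch (not part of this commit) of how the new parameters could be wired to an accepted POSIX socket: `serve_connection`, `fd_inbuf`, and `client_fd` are illustrative names. The output sides can reuse fdopen(3), while the input side needs a small streambuf because the standard library has no portable istream-over-fd:

```cpp
// Hypothetical wiring sketch (not part of this commit): drive llama_main
// over an accepted TCP connection on a POSIX system.
#include <cstdio>
#include <istream>
#include <streambuf>
#include <unistd.h>

#include "llama.h"

// Minimal read-only streambuf over a file descriptor: underflow() refills
// the buffer via read(2), so std::getline(instream, ...) inside llama_main
// pulls interactive input straight from the socket.
class fd_inbuf : public std::streambuf {
    int  fd_;
    char buf_[4096];
protected:
    int_type underflow() override {
        ssize_t n = read(fd_, buf_, sizeof(buf_));
        if (n <= 0) return traits_type::eof();
        setg(buf_, buf_, buf_ + n);
        return traits_type::to_int_type(*gptr());
    }
public:
    explicit fd_inbuf(int fd) : fd_(fd) {}
};

int serve_connection(int client_fd,
                     gpt_params params, gpt_vocab vocab, llama_model model,
                     int64_t t_load_us, int64_t t_main_start_us) {
    fd_inbuf     inbuf(client_fd);
    std::istream instream(&inbuf);

    // fdopen(3) wraps the same socket for the fprintf()-based output path;
    // dup(2) so that fclose() below does not also close client_fd.
    FILE *outstream = fdopen(dup(client_fd), "w");
    if (outstream == NULL) return 1;

    // Diagnostics stay on the server's own stderr rather than the socket.
    int ret = llama_main(params, vocab, model, t_load_us, t_main_start_us,
                         instream, outstream, stderr);
    fclose(outstream);
    return ret;
}
```

One caveat: a FILE * is block-buffered by default when it is not a terminal, so a server along these lines would rely on the `fflush(outstream)` call the diff adds after each batch of tokens to keep the socket responsive.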

llama.h

4 additions & 1 deletion

@@ -64,5 +64,8 @@ int llama_main(
         gpt_vocab vocab,
         llama_model model,
         int64_t t_load_us,
-        int64_t t_main_start_us);
+        int64_t t_main_start_us,
+        std::istream & instream,
+        FILE *outstream,
+        FILE *errstream);
 bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx, ggml_type memory_type);
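Since the declaration now names std::istream and FILE, any translation unit including llama.h needs <istream> (or <iostream>) and <cstdio> in scope; that is presumably why main.cpp below gains an #include <iostream>.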

main.cpp

3 additions & 1 deletion

@@ -2,6 +2,8 @@
 #include "utils.h"
 #include "llama.h"
 
+#include <iostream>
+
 const char * llama_print_system_info(void) {
     static std::string s;
 
@@ -63,5 +65,5 @@ int main(int argc, char ** argv) {
             params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
 
-    return llama_main(params, vocab, model, t_main_start_us, t_load_us);
+    return llama_main(params, vocab, model, t_main_start_us, t_load_us, std::cin, stdout, stderr);
 }
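The updated call preserves the old behavior by passing the process's standard streams. As a hypothetical illustration (not in this commit) of what the indirection enables, the same call site could instead run a scripted, non-interactive session and capture the generated text in a file; `scripted` and `generation.log` are made-up names:

```cpp
// Hypothetical alternative call, mirroring the argument order used by
// main.cpp above; assumes params, vocab, model, and the timing values
// were set up earlier in main() exactly as before.
// (Requires <sstream> for std::istringstream.)
std::istringstream scripted("tell me a joke\n");
FILE *log = fopen("generation.log", "w");
int ret = llama_main(params, vocab, model, t_main_start_us, t_load_us,
                     scripted, log != NULL ? log : stdout, stderr);
if (log != NULL) fclose(log);
```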
