@@ -736,13 +736,16 @@ int llama_main(
     gpt_vocab vocab,
     llama_model model,
     int64_t t_load_us,
-    int64_t t_main_start_us) {
+    int64_t t_main_start_us,
+    std::istream & instream,
+    FILE *outstream,
+    FILE *errstream) {
 
     if (params.seed < 0) {
         params.seed = time(NULL);
     }
 
-    fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
+    fprintf(errstream, "%s: seed = %d\n", __func__, params.seed);
 
     std::mt19937 rng(params.seed);
     if (params.random_prompt) {
@@ -788,13 +791,13 @@ int llama_main(
         params.interactive = true;
     }
 
-    fprintf(stderr, "\n");
-    fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
-    fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
+    fprintf(errstream, "\n");
+    fprintf(errstream, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+    fprintf(errstream, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
     for (int i = 0; i < (int) embd_inp.size(); i++) {
-        fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
+        fprintf(errstream, "%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
     }
-    fprintf(stderr, "\n");
+    fprintf(errstream, "\n");
     if (params.interactive) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
         struct sigaction sigint_action;
@@ -806,22 +809,22 @@ int llama_main(
         signal(SIGINT, sigint_handler);
 #endif
 
-        fprintf(stderr, "%s: interactive mode on.\n", __func__);
+        fprintf(errstream, "%s: interactive mode on.\n", __func__);
 
         if (antipromptv_inp.size()) {
             for (size_t apindex = 0; apindex < antipromptv_inp.size(); ++apindex) {
                 auto antiprompt_inp = antipromptv_inp.at(apindex);
-                fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.at(apindex).c_str());
-                fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
+                fprintf(errstream, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.at(apindex).c_str());
+                fprintf(errstream, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
                 for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
-                    fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
+                    fprintf(errstream, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
                 }
-                fprintf(stderr, "\n");
+                fprintf(errstream, "\n");
             }
         }
     }
-    fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
-    fprintf(stderr, "\n\n");
+    fprintf(errstream, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
+    fprintf(errstream, "\n\n");
 
     std::vector<gpt_vocab::id> embd;
 
@@ -834,7 +837,7 @@ int llama_main(
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
 
     if (params.interactive) {
-        fprintf(stderr, "== Running in interactive mode. ==\n"
+        fprintf(errstream, "== Running in interactive mode. ==\n"
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
                " - Press Ctrl+C to interject at any time.\n"
 #endif
@@ -850,7 +853,7 @@ int llama_main(
 
     // set the color for the prompt which will be output initially
     if (params.use_color) {
-        printf(ANSI_COLOR_YELLOW);
+        fprintf(outstream, ANSI_COLOR_YELLOW);
     }
 
     while (remaining_tokens > 0 || params.interactive) {
@@ -859,7 +862,7 @@ int llama_main(
             const int64_t t_start_us = ggml_time_us();
 
             if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
-                fprintf(stderr, "Failed to predict\n");
+                fprintf(errstream, "Failed to predict\n");
                 return 1;
             }
 
@@ -920,9 +923,9 @@ int llama_main(
         // display text
         if (!input_noecho) {
             for (auto id : embd) {
-                printf("%s", vocab.id_to_token[id].c_str());
+                fprintf(outstream, "%s", vocab.id_to_token[id].c_str());
             }
-            fflush(stdout);
+            fflush(outstream);
         }
         // reset color to default if we there is no pending user input
         if (!input_noecho && params.use_color && (int)embd_inp.size() == input_consumed) {
@@ -954,7 +957,7 @@ int llama_main(
                 std::string line;
                 bool another_line = true;
                 do {
-                    std::getline(std::cin, line);
+                    std::getline(instream, line);
                     if (line.empty() || line.back() != '\\') {
                         another_line = false;
                     } else {
@@ -983,7 +986,7 @@ int llama_main(
             if (params.interactive) {
                 is_interacting = true;
             } else {
-                fprintf(stderr, "[end of text]\n");
+                fprintf(errstream, "[end of text]\n");
                 break;
             }
         }
@@ -1003,18 +1006,18 @@ int llama_main(
     {
         const int64_t t_main_end_us = ggml_time_us();
 
-        fprintf(stderr, "\n\n");
-        fprintf(stderr, "%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
-        fprintf(stderr, "%s:     load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
-        fprintf(stderr, "%s:   sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
-        fprintf(stderr, "%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
-        fprintf(stderr, "%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+        fprintf(errstream, "\n\n");
+        fprintf(errstream, "%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
+        fprintf(errstream, "%s:     load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
+        fprintf(errstream, "%s:   sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
+        fprintf(errstream, "%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
+        fprintf(errstream, "%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
     }
 
     ggml_free(model.ctx);
 
     if (params.use_color) {
-        printf(ANSI_COLOR_RESET);
+        fprintf(outstream, ANSI_COLOR_RESET);
     }
 
     return 0;
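
Note: the diff threads an explicit std::istream plus two FILE* handles through llama_main instead of the hard-coded std::cin/stdout/stderr; presumably FILE* was kept for output so the existing fprintf/fflush calls stay unchanged, while a std::istream for input mirrors the std::getline call in the interactive loop. Below is a minimal sketch of how a caller might use the new signature. Passing std::cin, stdout, and stderr reproduces the old behaviour; any other stream pair lets an embedding application capture the generated text. The "llama.h" include, the helper names, and the already-loaded params/vocab/model and timing values are illustrative assumptions, not part of this diff.

// Hypothetical callers of the new llama_main signature.
// Assumes "llama.h" declares llama_main, gpt_params, gpt_vocab and llama_model.
#include <cstdio>
#include <iostream>
#include <sstream>
#include "llama.h"

// Same behaviour as before this change: read stdin, print to stdout,
// send diagnostics to stderr.
int run_on_standard_streams(gpt_params params, gpt_vocab vocab, llama_model model,
                            int64_t t_load_us, int64_t t_main_start_us) {
    return llama_main(params, vocab, model, t_load_us, t_main_start_us,
                      std::cin, stdout, stderr);
}

// Redirected: feed interactive input from an in-memory stream and write the
// generated text to a file, keeping diagnostics on stderr.
int run_redirected(gpt_params params, gpt_vocab vocab, llama_model model,
                   int64_t t_load_us, int64_t t_main_start_us) {
    std::istringstream prompt_in("continue this line \\\nand this one\n");
    FILE * out = std::fopen("generation.txt", "w");
    if (out == NULL) {
        return 1;
    }
    const int ret = llama_main(params, vocab, model, t_load_us, t_main_start_us,
                               prompt_in, out, stderr);
    std::fclose(out);
    return ret;
}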