diff --git a/CMakeLists.txt b/CMakeLists.txt index d952afb4ff72b..d95d93f99c27c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -239,9 +239,15 @@ target_link_libraries(llama PRIVATE utils ggml ${LLAMA_EXTRA_LIBS}) # Executables # -add_executable(main main.cpp) +add_executable(main + main.cpp + run.cpp) target_link_libraries(main PRIVATE llama ggml utils) +if(NOT WIN32) + target_sources(main PRIVATE tcp_server.cpp) +endif() + add_executable(quantize quantize.cpp) target_link_libraries(quantize PRIVATE llama ggml utils) diff --git a/Makefile b/Makefile index edb0c64c82361..59400a8033f34 100644 --- a/Makefile +++ b/Makefile @@ -226,11 +226,17 @@ llama.o: llama.cpp llama.h utils.o: utils.cpp utils.h $(CXX) $(CXXFLAGS) -c utils.cpp -o utils.o +run.o: run.cpp run.h + $(CXX) $(CXXFLAGS) -c run.cpp -o run.o + +tcp_server.o: tcp_server.cpp tcp_server.h + $(CXX) $(CXXFLAGS) -c tcp_server.cpp -o tcp_server.o + clean: rm -f *.o main quantize -main: main.cpp ggml.o llama.o utils.o - $(CXX) $(CXXFLAGS) main.cpp ggml.o llama.o utils.o -o main $(LDFLAGS) +main: main.cpp ggml.o llama.o utils.o run.o tcp_server.o + $(CXX) $(CXXFLAGS) main.cpp ggml.o llama.o utils.o run.o tcp_server.o -o main $(LDFLAGS) @echo "\x1b[36mrun ./main -h for help\x1b[0m" quantize: quantize.cpp ggml.o llama.o utils.o diff --git a/chat_tcp_client.sh b/chat_tcp_client.sh new file mode 100755 index 0000000000000..f154ae57dc4a6 --- /dev/null +++ b/chat_tcp_client.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash + +PORT=${PORT:-8080} +PROMPT="${PROMPT:-"Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision. + +User:Hello, Bob. +Bob:Hello. How may I help you today? +User:Please tell me the largest city in Europe. +Bob:Sure. The largest city in Europe is Moscow, the capital of Russia. 
+User:"}" +RPROMPT="${RPROMPT:-"User:"}" +N_PREDICT="${N_PREDICT:-"4096"}" +REPEAT_PENALTY="${REPEAT_PENALTY:-"1.0"}" +N_THREADS="${N_THREADS:-"4"}" + +# Open connection to the chat server +exec 3<>/dev/tcp/127.0.0.1/${PORT} + +# Pass the arguments. The protocol is really simple: +# 1. Pass the number of arguments followed by a linefeed +# 2. Pass the arguments, with each being followed by a NUL byte ("\x00") +( +echo -en "12\n" +echo -en "-t\x00" +echo -en "$N_THREADS\x00" +echo -en "-n\x00" +echo -en "$N_PREDICT\x00" +echo -en "--repeat_penalty\x00" +echo -en "$REPEAT_PENALTY\x00" +echo -en "--color\x00" +echo -en "-i\x00" +echo -en "-r\x00" +echo -en "$RPROMPT\x00" +echo -en "-p\x00" +echo -en "$PROMPT\x00" +) >&3 + +trap exit TERM + +# When we have passed the arguments, start printing socket data to the screen. +# This is done in a background job because we also want to send data when +# running in interactive mode. +cat <&3 && echo "(disconnected, press \"enter\" twice to exit)" & +cat >&3 +wait diff --git a/chat_tcp_server.sh b/chat_tcp_server.sh new file mode 100755 index 0000000000000..79320906d7b0b --- /dev/null +++ b/chat_tcp_server.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +PORT=${PORT:-8080} +MODEL=${MODEL:-models/7B/ggml-model-q4_0.bin} + +./main -l ${PORT} -m $MODEL diff --git a/main.cpp b/main.cpp index 4569ef2a11fbb..975714f9382f7 100644 --- a/main.cpp +++ b/main.cpp @@ -1,69 +1,9 @@ -#include "utils.h" +#include "run.h" #include "ggml.h" -#include "llama.h" +#include "tcp_server.h" -#include -#include -#include -#include -#include -#include #include -#include -#include -#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) -#include -#include -#elif defined (_WIN32) -#include -#endif - -#if defined (_WIN32) -#pragma comment(lib,"kernel32.lib") -extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle); -extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode); -extern 
"C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode); -#endif - -#define ANSI_COLOR_RED "\x1b[31m" -#define ANSI_COLOR_GREEN "\x1b[32m" -#define ANSI_COLOR_YELLOW "\x1b[33m" -#define ANSI_COLOR_BLUE "\x1b[34m" -#define ANSI_COLOR_MAGENTA "\x1b[35m" -#define ANSI_COLOR_CYAN "\x1b[36m" -#define ANSI_COLOR_RESET "\x1b[0m" -#define ANSI_BOLD "\x1b[1m" - -/* Keep track of current color of output, and emit ANSI code if it changes. */ -enum console_state { - CONSOLE_STATE_DEFAULT=0, - CONSOLE_STATE_PROMPT, - CONSOLE_STATE_USER_INPUT -}; - -static console_state con_st = CONSOLE_STATE_DEFAULT; -static bool con_use_color = false; - -void set_console_state(console_state new_st) -{ - if (!con_use_color) return; - // only emit color code if state changed - if (new_st != con_st) { - con_st = new_st; - switch(con_st) { - case CONSOLE_STATE_DEFAULT: - printf(ANSI_COLOR_RESET); - return; - case CONSOLE_STATE_PROMPT: - printf(ANSI_COLOR_YELLOW); - return; - case CONSOLE_STATE_USER_INPUT: - printf(ANSI_BOLD ANSI_COLOR_GREEN); - return; - } - } -} std::vector softmax(const std::vector& logits) { std::vector probs(logits.size()); @@ -138,22 +78,6 @@ void perplexity(llama_context * ctx, const gpt_params & params) { printf("\n"); } -static bool is_interacting = false; - -#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) -void sigint_handler(int signo) { - set_console_state(CONSOLE_STATE_DEFAULT); - printf("\n"); // this also force flush stdout. 
- if (signo == SIGINT) { - if (!is_interacting) { - is_interacting=true; - } else { - _exit(130); - } - } -} -#endif - int main(int argc, char ** argv) { // has to be called once at the start of the program to init ggml stuff ggml_time_init(); @@ -170,24 +94,6 @@ int main(int argc, char ** argv) { "expect poor results\n", __func__, params.n_ctx); } - if (params.seed <= 0) { - params.seed = time(NULL); - } - - fprintf(stderr, "%s: seed = %d\n", __func__, params.seed); - - std::mt19937 rng(params.seed); - if (params.random_prompt) { - params.prompt = gpt_random_prompt(rng); - } - - // save choice to use color for later - // (note for later: this is a slightly awkward choice) - con_use_color = params.use_color; - -// params.prompt = R"(// this function checks if the number n is prime -//bool is_prime(int n) {)"; - llama_context * ctx; // load the model @@ -215,266 +121,16 @@ int main(int argc, char ** argv) { params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); } - // determine the required inference memory per token: - // TODO: better way to do that - { - const std::vector tmp = { 0, 1, 2, 3 }; - llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads); - } - if (params.perplexity) { perplexity(ctx, params); exit(0); } - int n_past = 0; - - // Add a space in front of the first character to match OG llama tokenizer behavior - params.prompt.insert(0, 1, ' '); - - // tokenize the prompt - auto embd_inp = ::llama_tokenize(ctx, params.prompt, true); - - const int n_ctx = llama_n_ctx(ctx); - - params.n_predict = std::min(params.n_predict, n_ctx - (int) embd_inp.size()); - - // prefix & suffix for instruct mode - const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true); - const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false); - - // in instruct mode, we inject a prefix and a suffix to each input by the user - if (params.instruct) { - params.interactive = true; - params.antiprompt.push_back("### 
Instruction:\n\n"); +#ifndef _WIN32 + if (params.listen_port != "") { + return listen_tcp(ctx, params); } - - // enable interactive mode if reverse prompt is specified - if (params.antiprompt.size() != 0) { - params.interactive = true; - } - - if (params.interactive_start) { - params.interactive = true; - } - - fprintf(stderr, "\n"); - fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str()); - fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); - for (int i = 0; i < (int) embd_inp.size(); i++) { - fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i])); - } - fprintf(stderr, "\n"); - if (params.interactive) { -#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) - struct sigaction sigint_action; - sigint_action.sa_handler = sigint_handler; - sigemptyset (&sigint_action.sa_mask); - sigint_action.sa_flags = 0; - sigaction(SIGINT, &sigint_action, NULL); -#elif defined (_WIN32) - signal(SIGINT, sigint_handler); #endif - fprintf(stderr, "%s: interactive mode on.\n", __func__); - - if(params.antiprompt.size()) { - for (auto antiprompt : params.antiprompt) { - fprintf(stderr, "Reverse prompt: '%s'\n", antiprompt.c_str()); - } - } - } - fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); - fprintf(stderr, "\n\n"); - - std::vector embd; - - int last_n_size = params.repeat_last_n; - std::vector last_n_tokens(last_n_size); - std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); - - if (params.interactive) { - fprintf(stderr, "== Running in interactive mode. 
==\n" -#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) - " - Press Ctrl+C to interject at any time.\n" -#endif - " - Press Return to return control to LLaMa.\n" - " - If you want to submit another line, end your input in '\\'.\n\n"); - is_interacting = params.interactive_start; - } - - int input_consumed = 0; - bool input_noecho = false; - - int remaining_tokens = params.n_predict; - -#if defined (_WIN32) - if (params.use_color) { - // Enable ANSI colors on Windows 10+ - unsigned long dwMode = 0; - void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11) - if (hConOut && hConOut != (void*)-1 && GetConsoleMode(hConOut, &dwMode) && !(dwMode & 0x4)) { - SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4) - } - } -#endif - // the first thing we will do is to output the prompt, so set color accordingly - set_console_state(CONSOLE_STATE_PROMPT); - - while (remaining_tokens > 0 || params.interactive) { - // predict - if (embd.size() > 0) { - if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) { - fprintf(stderr, "%s : failed to eval\n", __func__); - return 1; - } - } - - n_past += embd.size(); - embd.clear(); - - if ((int) embd_inp.size() <= input_consumed) { - // out of user input, sample next token - const float top_k = params.top_k; - const float top_p = params.top_p; - const float temp = params.temp; - const float repeat_penalty = params.repeat_penalty; - - llama_token id = 0; - - { - auto logits = llama_get_logits(ctx); - - if (params.ignore_eos) { - // set the logit of the eos token to zero to avoid sampling it - //logits[logits.size() - n_vocab + EOS_TOKEN_ID] = 0; - // TODO: this does not work of params.logits_all == true - assert(params.perplexity == false); - logits[llama_token_eos()] = 0; - } - - id = llama_sample_top_p_top_k(ctx, last_n_tokens.data(), last_n_tokens.size(), top_k, top_p, temp, repeat_penalty); - - 
last_n_tokens.erase(last_n_tokens.begin()); - last_n_tokens.push_back(id); - } - - // add it to the context - embd.push_back(id); - - // echo this to console - input_noecho = false; - - // decrement remaining sampling budget - --remaining_tokens; - } else { - // some user input remains from prompt or interaction, forward it to processing - while ((int) embd_inp.size() > input_consumed) { - embd.push_back(embd_inp[input_consumed]); - last_n_tokens.erase(last_n_tokens.begin()); - last_n_tokens.push_back(embd_inp[input_consumed]); - ++input_consumed; - if ((int) embd.size() >= params.n_batch) { - break; - } - } - } - - // display text - if (!input_noecho) { - for (auto id : embd) { - printf("%s", llama_token_to_str(ctx, id)); - } - fflush(stdout); - } - // reset color to default if we there is no pending user input - if (!input_noecho && (int)embd_inp.size() == input_consumed) { - set_console_state(CONSOLE_STATE_DEFAULT); - } - - // in interactive mode, and not currently processing queued inputs; - // check if we should prompt the user for more - if (params.interactive && (int) embd_inp.size() <= input_consumed) { - // check for reverse prompt - std::string last_output; - for (auto id : last_n_tokens) { - last_output += llama_token_to_str(ctx, id); - } - - // Check if each of the reverse prompts appears at the end of the output. 
- for (std::string antiprompt : params.antiprompt) { - if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) { - is_interacting = true; - break; - } - } - if (is_interacting) { - // potentially set color to indicate we are taking user input - set_console_state(CONSOLE_STATE_USER_INPUT); - - if (params.instruct) { - input_consumed = embd_inp.size(); - embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); - - printf("\n> "); - } - - std::string buffer; - std::string line; - bool another_line = true; - do { - std::getline(std::cin, line); - if (line.empty() || line.back() != '\\') { - another_line = false; - } else { - line.pop_back(); // Remove the continue character - } - buffer += line + '\n'; // Append the line to the result - } while (another_line); - - // done taking input, reset color - set_console_state(CONSOLE_STATE_DEFAULT); - - auto line_inp = ::llama_tokenize(ctx, buffer, false); - embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); - - if (params.instruct) { - embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); - } - - remaining_tokens -= line_inp.size(); - - input_noecho = true; // do not echo this again - } - is_interacting = false; - } - - // end of text token - if (embd.back() == llama_token_eos()) { - if (params.interactive) { - is_interacting = true; - } else { - fprintf(stderr, " [end of text]\n"); - break; - } - } - - // In interactive mode, respect the maximum number of tokens and drop back to user input when reached. 
- if (params.interactive && remaining_tokens <= 0) { - remaining_tokens = params.n_predict; - is_interacting = true; - } - } - -#if defined (_WIN32) - signal(SIGINT, SIG_DFL); -#endif - - llama_print_timings(ctx); - - llama_free(ctx); - - set_console_state(CONSOLE_STATE_DEFAULT); - - return 0; + return run(ctx, params, std::cin, stdout, stderr); } diff --git a/run.cpp b/run.cpp new file mode 100644 index 0000000000000..ab430eb9291d8 --- /dev/null +++ b/run.cpp @@ -0,0 +1,364 @@ +#include "utils.h" +#include "llama.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) +#include +#include +#elif defined (_WIN32) +#include +#endif + +#if defined (_WIN32) +#pragma comment(lib,"kernel32.lib") +extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle); +extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode); +extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode); +#endif + +#define ANSI_COLOR_RED "\x1b[31m" +#define ANSI_COLOR_GREEN "\x1b[32m" +#define ANSI_COLOR_YELLOW "\x1b[33m" +#define ANSI_COLOR_BLUE "\x1b[34m" +#define ANSI_COLOR_MAGENTA "\x1b[35m" +#define ANSI_COLOR_CYAN "\x1b[36m" +#define ANSI_COLOR_RESET "\x1b[0m" +#define ANSI_BOLD "\x1b[1m" + +/* Keep track of current color of output, and emit ANSI code if it changes. 
*/ +enum console_state { + CONSOLE_STATE_DEFAULT=0, + CONSOLE_STATE_PROMPT, + CONSOLE_STATE_USER_INPUT +}; + +static console_state con_st = CONSOLE_STATE_DEFAULT; +static bool con_use_color = false; + +void set_console_state(FILE *stream, console_state new_st) +{ + if (!con_use_color) return; + // only emit color code if state changed + if (new_st != con_st) { + con_st = new_st; + switch(con_st) { + case CONSOLE_STATE_DEFAULT: + fprintf(stream, ANSI_COLOR_RESET); + return; + case CONSOLE_STATE_PROMPT: + fprintf(stream, ANSI_COLOR_YELLOW); + return; + case CONSOLE_STATE_USER_INPUT: + fprintf(stream, ANSI_BOLD ANSI_COLOR_GREEN); + return; + } + } +} + +static bool is_interacting = false; + +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) +void sigint_handler(int signo) { + set_console_state(stdout, CONSOLE_STATE_DEFAULT); + printf("\n"); // this also force flush stdout. + if (signo == SIGINT) { + if (!is_interacting) { + is_interacting=true; + } else { + _exit(130); + } + } +} +#endif + +int run(llama_context * ctx, + gpt_params params, + std::istream & instream, + FILE *outstream, + FILE *errstream) { + + if (params.seed <= 0) { + params.seed = time(NULL); + } + + fprintf(errstream, "%s: seed = %d\n", __func__, params.seed); + + std::mt19937 rng(params.seed); + if (params.random_prompt) { + params.prompt = gpt_random_prompt(rng); + } + + // save choice to use color for later + // (note for later: this is a slightly awkward choice) + con_use_color = params.use_color; + +// params.prompt = R"(// this function checks if the number n is prime +//bool is_prime(int n) {)"; + + // determine the required inference memory per token: + // TODO: better way to do that + { + const std::vector tmp = { 0, 1, 2, 3 }; + llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads); + } + + int n_past = 0; + + // Add a space in front of the first character to match OG llama tokenizer behavior + params.prompt.insert(0, 1, ' '); + + // tokenize 
the prompt + auto embd_inp = ::llama_tokenize(ctx, params.prompt, true); + + const int n_ctx = llama_n_ctx(ctx); + + params.n_predict = std::min(params.n_predict, n_ctx - (int) embd_inp.size()); + + // prefix & suffix for instruct mode + const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true); + const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false); + + // in instruct mode, we inject a prefix and a suffix to each input by the user + if (params.instruct) { + params.interactive = true; + params.antiprompt.push_back("### Instruction:\n\n"); + } + + // enable interactive mode if reverse prompt is specified + if (params.antiprompt.size() != 0) { + params.interactive = true; + } + + if (params.interactive_start) { + params.interactive = true; + } + + fprintf(errstream, "\n"); + fprintf(errstream, "%s: prompt: '%s'\n", __func__, params.prompt.c_str()); + fprintf(errstream, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); + for (int i = 0; i < (int) embd_inp.size(); i++) { + fprintf(errstream, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i])); + } + fprintf(errstream, "\n"); + if (params.interactive) { +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) + struct sigaction sigint_action; + sigint_action.sa_handler = sigint_handler; + sigemptyset (&sigint_action.sa_mask); + sigint_action.sa_flags = 0; + sigaction(SIGINT, &sigint_action, NULL); +#elif defined (_WIN32) + signal(SIGINT, sigint_handler); +#endif + + fprintf(errstream, "%s: interactive mode on.\n", __func__); + + if(params.antiprompt.size()) { + for (auto antiprompt : params.antiprompt) { + fprintf(errstream, "Reverse prompt: '%s'\n", antiprompt.c_str()); + } + } + } + fprintf(errstream, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); + fprintf(errstream, "\n\n"); + + 
std::vector embd; + + int last_n_size = params.repeat_last_n; + std::vector last_n_tokens(last_n_size); + std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); + + if (params.interactive) { + fprintf(errstream, "== Running in interactive mode. ==\n" +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) + " - Press Ctrl+C to interject at any time.\n" +#endif + " - Press Return to return control to LLaMa.\n" + " - If you want to submit another line, end your input in '\\'.\n\n"); + is_interacting = params.interactive_start; + } + + int input_consumed = 0; + bool input_noecho = false; + + int remaining_tokens = params.n_predict; + +#if defined (_WIN32) + if (params.use_color) { + // Enable ANSI colors on Windows 10+ + unsigned long dwMode = 0; + void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11) + if (hConOut && hConOut != (void*)-1 && GetConsoleMode(hConOut, &dwMode) && !(dwMode & 0x4)) { + SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4) + } + } +#endif + // the first thing we will do is to output the prompt, so set color accordingly + set_console_state(outstream, CONSOLE_STATE_PROMPT); + + while (remaining_tokens > 0 || params.interactive) { + // predict + if (embd.size() > 0) { + if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) { + fprintf(errstream, "%s : failed to eval\n", __func__); + return 1; + } + } + + n_past += embd.size(); + embd.clear(); + + if ((int) embd_inp.size() <= input_consumed) { + // out of user input, sample next token + const float top_k = params.top_k; + const float top_p = params.top_p; + const float temp = params.temp; + const float repeat_penalty = params.repeat_penalty; + + llama_token id = 0; + + { + auto logits = llama_get_logits(ctx); + + if (params.ignore_eos) { + // set the logit of the eos token to zero to avoid sampling it + //logits[logits.size() - n_vocab + EOS_TOKEN_ID] = 0; + // TODO: this does not work 
of params.logits_all == true + assert(params.perplexity == false); + logits[llama_token_eos()] = 0; + } + + id = llama_sample_top_p_top_k(ctx, last_n_tokens.data(), last_n_tokens.size(), top_k, top_p, temp, repeat_penalty); + + last_n_tokens.erase(last_n_tokens.begin()); + last_n_tokens.push_back(id); + } + + // add it to the context + embd.push_back(id); + + // echo this to console + input_noecho = false; + + // decrement remaining sampling budget + --remaining_tokens; + } else { + // some user input remains from prompt or interaction, forward it to processing + while ((int) embd_inp.size() > input_consumed) { + embd.push_back(embd_inp[input_consumed]); + last_n_tokens.erase(last_n_tokens.begin()); + last_n_tokens.push_back(embd_inp[input_consumed]); + ++input_consumed; + if ((int) embd.size() >= params.n_batch) { + break; + } + } + } + + // display text + if (!input_noecho) { + for (auto id : embd) { + fprintf(outstream, "%s", llama_token_to_str(ctx, id)); + } + fflush(outstream); + } + // reset color to default if we there is no pending user input + if (!input_noecho && (int)embd_inp.size() == input_consumed) { + set_console_state(outstream, CONSOLE_STATE_DEFAULT); + } + + // in interactive mode, and not currently processing queued inputs; + // check if we should prompt the user for more + if (params.interactive && (int) embd_inp.size() <= input_consumed) { + // check for reverse prompt + std::string last_output; + for (auto id : last_n_tokens) { + last_output += llama_token_to_str(ctx, id); + } + + // Check if each of the reverse prompts appears at the end of the output. 
+ for (std::string antiprompt : params.antiprompt) { + if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) { + is_interacting = true; + break; + } + } + if (is_interacting) { + // potentially set color to indicate we are taking user input + set_console_state(outstream, CONSOLE_STATE_USER_INPUT); + + if (params.instruct) { + input_consumed = embd_inp.size(); + embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); + + fprintf(outstream, "\n> "); + } + + std::string buffer; + std::string line; + bool another_line = true; + do { + std::getline(instream, line); + if (line.empty() || line.back() != '\\') { + another_line = false; + } else { + line.pop_back(); // Remove the continue character + } + buffer += line + '\n'; // Append the line to the result + } while (another_line); + + // done taking input, reset color + set_console_state(outstream, CONSOLE_STATE_DEFAULT); + + auto line_inp = ::llama_tokenize(ctx, buffer, false); + embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); + + if (params.instruct) { + embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); + } + + remaining_tokens -= line_inp.size(); + + input_noecho = true; // do not echo this again + } + is_interacting = false; + } + + // end of text token + if (embd.back() == llama_token_eos()) { + if (params.interactive) { + is_interacting = true; + } else { + fprintf(errstream, " [end of text]\n"); + break; + } + } + + // In interactive mode, respect the maximum number of tokens and drop back to user input when reached. 
+ if (params.interactive && remaining_tokens <= 0) { + remaining_tokens = params.n_predict; + is_interacting = true; + } + } + +#if defined (_WIN32) + signal(SIGINT, SIG_DFL); +#endif + + llama_print_timings(ctx); + + llama_free(ctx); + + set_console_state(outstream, CONSOLE_STATE_DEFAULT); + + return 0; +} diff --git a/run.h b/run.h new file mode 100644 index 0000000000000..39c8e9f063dc1 --- /dev/null +++ b/run.h @@ -0,0 +1,10 @@ +#pragma once + +#include "llama.h" +#include "utils.h" + +int run(llama_context * ctx, + gpt_params params, + std::istream & instream, + FILE *outstream, + FILE *errstream); diff --git a/tcp_server.cpp b/tcp_server.cpp new file mode 100644 index 0000000000000..9077c1807de1a --- /dev/null +++ b/tcp_server.cpp @@ -0,0 +1,245 @@ +#include "tcp_server.h" +#include "llama.h" +#include "utils.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +class PosixStream : public std::istream { + public: + PosixStream(int fd) : std::istream(&buf), buf(fd) {} + ~PosixStream() { close(buf.get_fd()); } + + private: + class PosixStreamBuf : public std::streambuf { + public: + PosixStreamBuf(int fd) : fd(fd) {} + int get_fd() const { return fd; } + + protected: + virtual int_type underflow() { + if (gptr() < egptr()) { + return traits_type::to_int_type(*gptr()); + } + + ssize_t num_read = ::read(fd, buffer, BUFFER_SIZE); + if (num_read <= 0) { + return traits_type::eof(); + } + + setg(buffer, buffer, buffer + num_read); + return traits_type::to_int_type(*gptr()); + } + + private: + static const int BUFFER_SIZE = 1024; + int fd; + char buffer[BUFFER_SIZE]; + }; + + PosixStreamBuf buf; +}; + +void die(const char *msg, ...) 
+{
+    va_list ap;
+
+    va_start(ap, msg);
+    vfprintf(stderr, msg, ap);
+    va_end(ap);
+    fputc('\n', stderr);
+    exit(1);
+}
+
+// Read one NUL-terminated argument from `instream`.
+//
+// `*param_buf` / `*param_buf_size` describe a scratch buffer that is reused
+// across calls and grown with realloc() in 1024-byte steps; both are written
+// back before returning. The result is a strdup()'d copy that the caller must
+// free(). Terminates the process through die() on EOF or allocation failure.
+static char *read_argument(uint8_t **param_buf, size_t *param_buf_size, FILE *instream) {
+    uint8_t *buf = *param_buf;
+    size_t bufsize = *param_buf_size;
+    size_t bufpos = 0;
+    for (;;) {
+        if (bufpos == bufsize) {
+            bufsize += 1024;
+            buf = (uint8_t *)realloc(buf, bufsize);
+            if (!buf) {
+                die("failed to allocate memory");
+            }
+        }
+
+        int c = fgetc(instream);
+        if (c == EOF) {
+            die("unexpected EOF client socket");
+        }
+        buf[bufpos++] = (uint8_t)c;
+        if (c == 0) {
+            // done reading argument
+            break;
+        }
+    }
+    *param_buf = buf;
+    *param_buf_size = bufsize;
+
+    char *arg = strdup((char *)buf);
+    if (!arg) {
+        die("failed to allocate memory");
+    }
+    return arg;
+}
+
+// Fill argv[1] .. argv[argc - 1] with arguments read from `instream`;
+// argv[0] is left untouched. Returns the index one past the last slot filled.
+static int read_arguments(int argc, char **argv, FILE *instream) {
+    int i = 1;
+    size_t param_buf_size = 0;
+    uint8_t *param_buf = nullptr;
+
+    for (i = 1; i < argc; i++) {
+        argv[i] = read_argument(&param_buf, &param_buf_size, instream);
+    }
+
+    free(param_buf);
+    return i;
+}
+
+// Serve a single client connection on `sock_fd`.
+//
+// Wire format: a decimal argument count terminated by a linefeed, followed by
+// that many NUL-terminated arguments. The arguments are parsed into `params`
+// with gpt_params_parse(), then run() is called with the socket as its input,
+// output and error stream. Returns 1 on a protocol or parse error.
+static int serve_model(llama_context * ctx,
+        gpt_params params,
+        int sock_fd)
+{
+    // sanity limit on the client-supplied argument count
+    const int max_client_args = 1024;
+    // placeholder program name so argv[0] is never a null pointer
+    // (presumably gpt_print_usage prints argv[0] - confirm in utils.cpp)
+    static char prog_name[] = "main";
+
+    int argc;
+    char **argv;
+    FILE *instream = fdopen(sock_fd, "r");
+    FILE *outstream = fdopen(sock_fd, "w");
+    if (!instream || !outstream) {
+        return 1;
+    }
+    // unbuffered reads; presumably so stdio does not swallow bytes that the
+    // PosixStream created below has to read straight from the descriptor
+    setvbuf(instream, NULL, _IONBF, 0);
+
+    // start by reading the parameter count; it is untrusted, so reject values
+    // that would make the allocation below negative, overflowing or huge
+    if (fscanf(instream, "%d\n", &argc) != 1 || argc < 0 || argc > max_client_args) {
+        fprintf(outstream, "Error: First line must be argument count\n");
+        fflush(outstream);
+        return 1;
+    }
+
+    argc += 1; // add one extra argument to emulate the program command line
+    // one extra slot so argv is NULL-terminated like a real main() argv
+    argv = (char **)malloc((argc + 1) * sizeof *argv);
+    if (!argv) {
+        die("failed to allocate memory");
+    }
+    argv[0] = prog_name;
+    argv[argc] = nullptr;
+    if (read_arguments(argc, argv, instream) != argc) {
+        fprintf(outstream, "Error: Failed to read arguments\n");
+        fflush(outstream);
+        return 1;
+    }
+
+    if (gpt_params_parse(argc, argv, params) == false) {
+        fprintf(outstream, "Error: Failed to parse parameters\n");
+        fflush(outstream);
+        return 1;
+    }
+
+    // argv[0] is static storage, so only the strdup()'d entries are freed
+    for (int i = 1; i < argc; i++) {
+        free(argv[i]);
+    }
+    free(argv);
+
+    PosixStream tcp_instream(sock_fd);
+
+    return run(ctx, params, tcp_instream, outstream, outstream);
+}
+
+// Run the TCP server: bind to 127.0.0.1 on params.listen_port, then accept
+// connections in a loop, forking one child per client. Each child handles its
+// connection with serve_model() and returns that result; the parent leaves
+// the accept loop only when accept() fails.
+int listen_tcp(llama_context * ctx, gpt_params params) {
+    int listen_fd;
+    int status;
+    pid_t child;
+    struct addrinfo hints;
+    struct addrinfo *servinfo, *p;
+    int yes = 1;
+
+    memset(&hints, 0, sizeof hints);
+    hints.ai_family = AF_INET;
+    hints.ai_socktype = SOCK_STREAM;
+    hints.ai_flags = AI_PASSIVE;
+
+    // This should only ever listen on a loopback address. Access from outside
+    // should be proxied via socat or similar software
+    status = getaddrinfo("127.0.0.1", params.listen_port.c_str(), &hints, &servinfo);
+    if (status) {
+        die("getaddrinfo error: %s", gai_strerror(status));
+    }
+
+    // bind to the first addrinfo we can from the getaddrinfo results
+    for (p = servinfo; p != NULL; p = p->ai_next) {
+        listen_fd = socket(p->ai_family, p->ai_socktype, p->ai_protocol);
+        if (listen_fd == -1) {
+            perror("server: socket");
+            continue;
+        }
+
+        // SO_REUSEADDR and SO_REUSEPORT are distinct option names, not flags
+        // that can be OR-ed together, so each needs its own setsockopt() call
+        if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof yes)) {
+            die("setsockopt error: %s", strerror(errno));
+        }
+#ifdef SO_REUSEPORT
+        if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEPORT, &yes, sizeof yes)) {
+            die("setsockopt error: %s", strerror(errno));
+        }
+#endif
+
+        if (bind(listen_fd, p->ai_addr, p->ai_addrlen) == 0) {
+            struct sockaddr_in addr_in;
+            socklen_t addr_in_len = sizeof(addr_in);
+            memset(&addr_in, 0, addr_in_len);
+            getsockname(listen_fd, (struct sockaddr*)&addr_in, &addr_in_len);
+
+            printf("Listening on %s:%d\n", inet_ntoa(addr_in.sin_addr), ntohs(addr_in.sin_port));
+            break;
+        }
+
+        // close() may overwrite errno, so keep bind()'s error for the reports
+        const int bind_errno = errno;
+        close(listen_fd);
+        errno = bind_errno;
+        perror("server: bind");
+    }
+
+    freeaddrinfo(servinfo);
+
+    if (p == NULL) {
+        die("failed to bind: %s", strerror(errno));
+    }
+
+    if (listen(listen_fd, 20)) {
+        die("listen error: %s", strerror(errno));
+    }
+    // Don't track child processes, so ignore SIGCHLD to prevent zombies
+    signal(SIGCHLD, SIG_IGN);
+
+    for (;;) {
+        struct sockaddr_in client_addr;
+        // accept() treats the length as in/out: it must hold the buffer size
+        // on entry, otherwise the peer address is truncated to nothing
+        socklen_t client_addr_len = sizeof(client_addr);
+        memset(&client_addr, 0, sizeof(client_addr));
+
+        int sock_fd = accept(listen_fd,
+                (struct sockaddr *)&client_addr,
+                &client_addr_len);
+        if 
(sock_fd < 0) { + fprintf(stderr, "accept error: %s\n", strerror(errno)); + break; + } + + child = fork(); + if (child == 0) { + // close the listen_fd since we won't use it in the child + close(listen_fd); + int ret = serve_model(ctx, params, sock_fd); + close(sock_fd); + return ret; + } else { + // close the client since we won't use it in the server + close(sock_fd); + sock_fd = 0; + } + } + close(listen_fd); + + // ignore SIGTERM since we'll send it to the group + signal(SIGTERM, SIG_IGN); + // tell children to exit + kill(0, SIGTERM); + // wait for children to terminate + wait(&status); + return 0; +} diff --git a/tcp_server.h b/tcp_server.h new file mode 100644 index 0000000000000..38d6ecc810026 --- /dev/null +++ b/tcp_server.h @@ -0,0 +1,7 @@ +#pragma once + +#include "utils.h" +#include "llama.h" +#include "run.h" + +int listen_tcp(llama_context * ctx, gpt_params params); diff --git a/utils.cpp b/utils.cpp index 1d5309c3a4ca3..78baf924c4b87 100644 --- a/utils.cpp +++ b/utils.cpp @@ -77,6 +77,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.ignore_eos = true; } else if (arg == "--n_parts") { params.n_parts = std::stoi(argv[++i]); +#ifndef _WIN32 + } else if (arg == "-l" || arg == "--listen") { + params.listen_port = argv[++i]; +#endif } else if (arg == "-h" || arg == "--help") { gpt_print_usage(argc, argv, params); exit(0); @@ -125,6 +129,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " --perplexity compute perplexity over the prompt\n"); fprintf(stderr, " -m FNAME, --model FNAME\n"); fprintf(stderr, " model path (default: %s)\n", params.model.c_str()); +#ifndef _WIN32 + fprintf(stderr, " -l PORT, --listen PORT\n"); + fprintf(stderr, " Run in TCP mode, listening on PORT\n"); +#endif fprintf(stderr, "\n"); } diff --git a/utils.h b/utils.h index b0de556c95370..487892b1258c2 100644 --- a/utils.h +++ b/utils.h @@ -42,6 +42,10 @@ struct gpt_params { bool instruct = false; // 
instruction mode (used for Alpaca models) bool ignore_eos = false; // do not stop generating after eos bool perplexity = false; // compute perplexity over the prompt + +#ifndef _WIN32 + std::string listen_port = ""; // TCP port for when running in server mode +#endif }; bool gpt_params_parse(int argc, char ** argv, gpt_params & params);