Skip to content

server: SSL Support #5926

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ ifdef LLAMA_SERVER_VERBOSE
MK_CPPFLAGS += -DSERVER_VERBOSE=$(LLAMA_SERVER_VERBOSE)
endif

ifdef LLAMA_SERVER_SSL
MK_CPPFLAGS += -DCPPHTTPLIB_OPENSSL_SUPPORT
MK_LDFLAGS += -lssl -lcrypto
endif

ifdef LLAMA_CODE_COVERAGE
MK_CXXFLAGS += -fprofile-arcs -ftest-coverage -dumpbase ''
Expand Down
6 changes: 6 additions & 0 deletions examples/server/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
set(TARGET server)
option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
option(LLAMA_SERVER_SSL "Build SSL support for the server" OFF)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h)
install(TARGETS ${TARGET} RUNTIME)
target_compile_definitions(${TARGET} PRIVATE
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
)
target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
if (LLAMA_SERVER_SSL)
find_package(OpenSSL REQUIRED)
target_link_libraries(${TARGET} PRIVATE OpenSSL::SSL OpenSSL::Crypto)
target_compile_definitions(${TARGET} PRIVATE CPPHTTPLIB_OPENSSL_SUPPORT)
endif()
if (WIN32)
TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
endif()
Expand Down
26 changes: 26 additions & 0 deletions examples/server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ see https://github.com/ggerganov/llama.cpp/issues/1437
- `--log-disable`: Output logs to stdout only, default: enabled.
- `--log-format FORMAT`: Define the log output to FORMAT: json or text (default: json)

**If compiled with `LLAMA_SERVER_SSL=ON`**
- `--ssl-key-file FNAME`: path to file a PEM-encoded SSL private key
- `--ssl-cert-file FNAME`: path to file a PEM-encoded SSL certificate

## Build

server is build alongside everything else from the root of the project
Expand All @@ -75,6 +79,28 @@ server is build alongside everything else from the root of the project
cmake --build . --config Release
```

## Build with SSL

server can also be built with SSL support using OpenSSL 3

- Using `make`:

```bash
# NOTE: For non-system openssl, use the following:
# CXXFLAGS="-I /path/to/openssl/include"
# LDFLAGS="-L /path/to/openssl/lib"
make LLAMA_SERVER_SSL=true server
```

- Using `CMake`:

```bash
mkdir build
cd build
cmake .. -DLLAMA_SERVER_SSL=ON
make server
```

## Quick Start

To get started right away, run the following command, making sure to use the correct path for the model you have:
Expand Down
108 changes: 74 additions & 34 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <mutex>
#include <thread>
#include <signal.h>
#include <memory>

using json = nlohmann::json;

Expand Down Expand Up @@ -118,6 +119,11 @@ struct server_params {

std::vector<std::string> api_keys;

#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
std::string ssl_key_file = "";
std::string ssl_cert_file = "";
#endif

bool slots_endpoint = true;
bool metrics_endpoint = false;
};
Expand Down Expand Up @@ -2095,6 +2101,10 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
printf(" --ssl-key-file FNAME path to file a PEM-encoded SSL private key\n");
printf(" --ssl-cert-file FNAME path to file a PEM-encoded SSL certificate\n");
#endif
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
printf(" --embeddings enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
Expand Down Expand Up @@ -2173,7 +2183,24 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
}
}
key_file.close();
} else if (arg == "--timeout" || arg == "-to") {

}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kind of hate that this breaks the brace formatting of the rest of the if/else if chain, but I figured it was better than duplicating the else if (arg == "--timeout"....

#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
else if (arg == "--ssl-key-file") {
if (++i >= argc) {
invalid_param = true;
break;
}
sparams.ssl_key_file = argv[i];
} else if (arg == "--ssl-cert-file") {
if (++i >= argc) {
invalid_param = true;
break;
}
sparams.ssl_cert_file = argv[i];
}
#endif
else if (arg == "--timeout" || arg == "-to") {
if (++i >= argc) {
invalid_param = true;
break;
Expand Down Expand Up @@ -2611,21 +2638,34 @@ int main(int argc, char ** argv) {
{"system_info", llama_print_system_info()},
});

httplib::Server svr;
std::unique_ptr<httplib::Server> svr;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
if (sparams.ssl_key_file != "" && sparams.ssl_cert_file != "") {
LOG_INFO("Running with SSL", {{"key", sparams.ssl_key_file}, {"cert", sparams.ssl_cert_file}});
svr.reset(
new httplib::SSLServer(sparams.ssl_cert_file.c_str(), sparams.ssl_key_file.c_str())
);
} else {
LOG_INFO("Running without SSL", {});
svr.reset(new httplib::Server());
}
#else
svr.reset(new httplib::Server());
#endif

std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};

svr.set_default_headers({{"Server", "llama.cpp"}});
svr->set_default_headers({{"Server", "llama.cpp"}});

// CORS preflight
svr.Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
res.set_header("Access-Control-Allow-Credentials", "true");
res.set_header("Access-Control-Allow-Methods", "POST");
res.set_header("Access-Control-Allow-Headers", "*");
});

svr.Get("/health", [&](const httplib::Request & req, httplib::Response & res) {
svr->Get("/health", [&](const httplib::Request & req, httplib::Response & res) {
server_state current_state = state.load();
switch (current_state) {
case SERVER_STATE_READY:
Expand Down Expand Up @@ -2681,7 +2721,7 @@ int main(int argc, char ** argv) {
});

if (sparams.slots_endpoint) {
svr.Get("/slots", [&](const httplib::Request &, httplib::Response & res) {
svr->Get("/slots", [&](const httplib::Request &, httplib::Response & res) {
// request slots data using task queue
server_task task;
task.id = ctx_server.queue_tasks.get_new_id();
Expand All @@ -2702,7 +2742,7 @@ int main(int argc, char ** argv) {
}

if (sparams.metrics_endpoint) {
svr.Get("/metrics", [&](const httplib::Request &, httplib::Response & res) {
svr->Get("/metrics", [&](const httplib::Request &, httplib::Response & res) {
// request slots data using task queue
server_task task;
task.id = ctx_server.queue_tasks.get_new_id();
Expand Down Expand Up @@ -2787,9 +2827,9 @@ int main(int argc, char ** argv) {
});
}

svr.set_logger(log_server_request);
svr->set_logger(log_server_request);

svr.set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
const char fmt[] = "500 Internal Server Error\n%s";

char buf[BUFSIZ];
Expand All @@ -2805,7 +2845,7 @@ int main(int argc, char ** argv) {
res.status = 500;
});

svr.set_error_handler([](const httplib::Request &, httplib::Response & res) {
svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
if (res.status == 401) {
res.set_content("Unauthorized", "text/plain; charset=utf-8");
}
Expand All @@ -2818,16 +2858,16 @@ int main(int argc, char ** argv) {
});

// set timeouts and change hostname and port
svr.set_read_timeout (sparams.read_timeout);
svr.set_write_timeout(sparams.write_timeout);
svr->set_read_timeout (sparams.read_timeout);
svr->set_write_timeout(sparams.write_timeout);

if (!svr.bind_to_port(sparams.hostname, sparams.port)) {
if (!svr->bind_to_port(sparams.hostname, sparams.port)) {
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port);
return 1;
}

// Set the base directory for serving static files
svr.set_base_dir(sparams.public_path);
svr->set_base_dir(sparams.public_path);

std::unordered_map<std::string, std::string> log_data;

Expand Down Expand Up @@ -2888,30 +2928,30 @@ int main(int argc, char ** argv) {
};

// this is only called if no index.html is found in the public --path
svr.Get("/", [](const httplib::Request &, httplib::Response & res) {
svr->Get("/", [](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(&index_html), index_html_len, "text/html; charset=utf-8");
return false;
});

// this is only called if no index.js is found in the public --path
svr.Get("/index.js", [](const httplib::Request &, httplib::Response & res) {
svr->Get("/index.js", [](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char *>(&index_js), index_js_len, "text/javascript; charset=utf-8");
return false;
});

// this is only called if no index.html is found in the public --path
svr.Get("/completion.js", [](const httplib::Request &, httplib::Response & res) {
svr->Get("/completion.js", [](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(&completion_js), completion_js_len, "application/javascript; charset=utf-8");
return false;
});

// this is only called if no index.html is found in the public --path
svr.Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response & res) {
svr->Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript; charset=utf-8");
return false;
});

svr.Get("/props", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
svr->Get("/props", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = {
{ "user_name", ctx_server.name_user.c_str() },
Expand Down Expand Up @@ -3003,11 +3043,11 @@ int main(int argc, char ** argv) {
}
};

svr.Post("/completion", completions); // legacy
svr.Post("/completions", completions);
svr.Post("/v1/completions", completions);
svr->Post("/completion", completions); // legacy
svr->Post("/completions", completions);
svr->Post("/v1/completions", completions);

svr.Get("/v1/models", [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
svr->Get("/v1/models", [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));

json models = {
Expand Down Expand Up @@ -3102,10 +3142,10 @@ int main(int argc, char ** argv) {
}
};

svr.Post("/chat/completions", chat_completions);
svr.Post("/v1/chat/completions", chat_completions);
svr->Post("/chat/completions", chat_completions);
svr->Post("/v1/chat/completions", chat_completions);

svr.Post("/infill", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
svr->Post("/infill", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) {
return;
Expand Down Expand Up @@ -3169,11 +3209,11 @@ int main(int argc, char ** argv) {
}
});

svr.Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
return res.set_content("", "application/json; charset=utf-8");
});

svr.Post("/tokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
svr->Post("/tokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);

Expand All @@ -3185,7 +3225,7 @@ int main(int argc, char ** argv) {
return res.set_content(data.dump(), "application/json; charset=utf-8");
});

svr.Post("/detokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
svr->Post("/detokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);

Expand All @@ -3199,7 +3239,7 @@ int main(int argc, char ** argv) {
return res.set_content(data.dump(), "application/json; charset=utf-8");
});

svr.Post("/embedding", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
svr->Post("/embedding", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!params.embedding) {
res.status = 501;
Expand Down Expand Up @@ -3230,7 +3270,7 @@ int main(int argc, char ** argv) {
return res.set_content(result.data.dump(), "application/json; charset=utf-8");
});

svr.Post("/v1/embeddings", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
svr->Post("/v1/embeddings", [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!params.embedding) {
res.status = 501;
Expand Down Expand Up @@ -3301,13 +3341,13 @@ int main(int argc, char ** argv) {
sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
}
log_data["n_threads_http"] = std::to_string(sparams.n_threads_http);
svr.new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); };
svr->new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); };

LOG_INFO("HTTP server listening", log_data);

// run the HTTP server in a thread - see comment below
std::thread t([&]() {
if (!svr.listen_after_bind()) {
if (!svr->listen_after_bind()) {
state.store(SERVER_STATE_ERROR);
return 1;
}
Expand Down Expand Up @@ -3348,7 +3388,7 @@ int main(int argc, char ** argv) {

ctx_server.queue_tasks.start_loop();

svr.stop();
svr->stop();
t.join();

llama_backend_free();
Expand Down