Skip to content

Commit 0133d3d

Browse files
Azeirahhodlen
authored andcommitted
server : implement credentialed CORS (ggml-org#4514)
* Implement credentialed CORS according to MDN * Fix syntax error * Move validate_api_key up so it is defined before its first usage
1 parent 8c7a827 commit 0133d3d

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

examples/server/server.cpp

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2822,9 +2822,15 @@ int main(int argc, char **argv)
28222822

28232823
std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
28242824

2825-
svr.set_default_headers({{"Server", "llama.cpp"},
2826-
{"Access-Control-Allow-Origin", "*"},
2827-
{"Access-Control-Allow-Headers", "content-type"}});
2825+
svr.set_default_headers({{"Server", "llama.cpp"}});
2826+
2827+
// CORS preflight
2828+
svr.Options(R"(.*)", [](const httplib::Request &req, httplib::Response &res) {
2829+
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
2830+
res.set_header("Access-Control-Allow-Credentials", "true");
2831+
res.set_header("Access-Control-Allow-Methods", "POST");
2832+
res.set_header("Access-Control-Allow-Headers", "*");
2833+
});
28282834

28292835
svr.Get("/health", [&](const httplib::Request&, httplib::Response& res) {
28302836
server_state current_state = state.load();
@@ -2987,9 +2993,9 @@ int main(int argc, char **argv)
29872993
return false;
29882994
});
29892995

2990-
svr.Get("/props", [&llama](const httplib::Request & /*req*/, httplib::Response &res)
2996+
svr.Get("/props", [&llama](const httplib::Request & req, httplib::Response &res)
29912997
{
2992-
res.set_header("Access-Control-Allow-Origin", "*");
2998+
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
29932999
json data = {
29943000
{ "user_name", llama.name_user.c_str() },
29953001
{ "assistant_name", llama.name_assistant.c_str() }
@@ -2999,6 +3005,7 @@ int main(int argc, char **argv)
29993005

30003006
svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
30013007
{
3008+
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
30023009
if (!validate_api_key(req, res)) {
30033010
return;
30043011
}
@@ -3066,8 +3073,9 @@ int main(int argc, char **argv)
30663073
}
30673074
});
30683075

3069-
svr.Get("/v1/models", [&params](const httplib::Request&, httplib::Response& res)
3076+
svr.Get("/v1/models", [&params](const httplib::Request& req, httplib::Response& res)
30703077
{
3078+
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
30713079
std::time_t t = std::time(0);
30723080

30733081
json models = {
@@ -3085,9 +3093,11 @@ int main(int argc, char **argv)
30853093
res.set_content(models.dump(), "application/json; charset=utf-8");
30863094
});
30873095

3096+
30883097
// TODO: add mount point without "/v1" prefix -- how?
30893098
svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
30903099
{
3100+
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
30913101
if (!validate_api_key(req, res)) {
30923102
return;
30933103
}
@@ -3161,6 +3171,7 @@ int main(int argc, char **argv)
31613171

31623172
svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
31633173
{
3174+
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
31643175
if (!validate_api_key(req, res)) {
31653176
return;
31663177
}
@@ -3233,6 +3244,7 @@ int main(int argc, char **argv)
32333244

32343245
svr.Post("/tokenize", [&llama](const httplib::Request &req, httplib::Response &res)
32353246
{
3247+
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
32363248
const json body = json::parse(req.body);
32373249
std::vector<llama_token> tokens;
32383250
if (body.count("content") != 0)
@@ -3245,6 +3257,7 @@ int main(int argc, char **argv)
32453257

32463258
svr.Post("/detokenize", [&llama](const httplib::Request &req, httplib::Response &res)
32473259
{
3260+
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
32483261
const json body = json::parse(req.body);
32493262
std::string content;
32503263
if (body.count("tokens") != 0)
@@ -3259,6 +3272,7 @@ int main(int argc, char **argv)
32593272

32603273
svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res)
32613274
{
3275+
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
32623276
const json body = json::parse(req.body);
32633277
json prompt;
32643278
if (body.count("content") != 0)

0 commit comments

Comments
 (0)