@@ -729,24 +729,30 @@ struct server_task_result_embd : server_task_result {
729729 int index = 0 ;
730730 std::vector<std::vector<float >> embedding;
731731
732+ // OAI-compat fields
733+ bool oaicompat = false ;
734+
732735 virtual int get_index () override {
733736 return index;
734737 }
735738
736739 virtual json to_json () override {
737- if (embedding.size () == 1 ) {
738- // to be OAI compatible
739- return json {
740- {" index" , index},
741- {" embedding" , embedding[0 ]},
742- };
743- }
740+ return oaicompat ? to_json_oaicompat () : to_json_non_oaicompat ();
741+ }
744742
743+ json to_json_non_oaicompat () {
745744 return json {
746745 {" index" , index},
747746 {" embedding" , embedding},
748747 };
749748 }
749+
750+ json to_json_oaicompat () {
751+ return json {
752+ {" index" , index},
753+ {" embedding" , embedding[0 ]},
754+ };
755+ }
750756};
751757
752758struct server_task_result_rerank : server_task_result {
@@ -2018,8 +2024,9 @@ struct server_context {
20182024
20192025 void send_embedding (const server_slot & slot, const llama_batch & batch) {
20202026 auto res = std::make_unique<server_task_result_embd>();
2021- res->id = slot.id_task ;
2022- res->index = slot.index ;
2027+ res->id = slot.id_task ;
2028+ res->index = slot.index ;
2029+ res->oaicompat = slot.params .oaicompat ;
20232030
20242031 const int n_embd = llama_n_embd (model);
20252032
@@ -3667,14 +3674,17 @@ int main(int argc, char ** argv) {
36673674 res_ok (res, data);
36683675 };
36693676
3670- const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
3677+ const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat ) {
36713678 const json body = json::parse (req.body );
3672- bool oaicompat = false ;
3679+
3680+ if (oaicompat && llama_pooling_type (ctx_server.ctx ) == LLAMA_POOLING_TYPE_NONE) {
3681+ res_error (res, format_error_response (" Pooling type 'none' is not OAI compatible. Please use a different pooling type" , ERROR_TYPE_INVALID_REQUEST));
3682+ return ;
3683+ }
36733684
36743685 // an input prompt can be a string or a list of tokens (integer)
36753686 json prompt;
36763687 if (body.count (" input" ) != 0 ) {
3677- oaicompat = true ;
36783688 prompt = body.at (" input" );
36793689 } else if (body.count (" content" ) != 0 ) {
36803690 // with "content", we only support single prompt
@@ -3691,10 +3701,15 @@ int main(int argc, char ** argv) {
36913701 std::vector<server_task> tasks;
36923702 std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts (ctx_server.ctx , prompt, /* add_special */ false , true );
36933703 for (size_t i = 0 ; i < tokenized_prompts.size (); i++) {
3694- server_task task = server_task (SERVER_TASK_TYPE_EMBEDDING);
3704+ server_task task = server_task (SERVER_TASK_TYPE_EMBEDDING);
3705+
36953706 task.id = ctx_server.queue_tasks .get_new_id ();
36963707 task.index = i;
36973708 task.prompt_tokens = std::move (tokenized_prompts[i]);
3709+
3710+ // OAI-compat
3711+ task.params .oaicompat = oaicompat;;
3712+
36983713 tasks.push_back (task);
36993714 }
37003715
@@ -3722,12 +3737,18 @@ int main(int argc, char ** argv) {
37223737 }
37233738
37243739 // write JSON response
3725- json root = oaicompat
3726- ? format_embeddings_response_oaicompat (body, responses)
3727- : responses.size () == 1 ? responses[0 ] : json (responses);
3740+ json root = oaicompat ? format_embeddings_response_oaicompat (body, responses) : json (responses);
37283741 res_ok (res, root);
37293742 };
37303743
3744+ const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
3745+ handle_embeddings_impl (req, res, false );
3746+ };
3747+
3748+ const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
3749+ handle_embeddings_impl (req, res, true );
3750+ };
3751+
37313752 const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
37323753 if (!ctx_server.params_base .reranking || ctx_server.params_base .embedding ) {
37333754 res_error (res, format_error_response (" This server does not support reranking. Start it with `--reranking` and without `--embedding`" , ERROR_TYPE_NOT_SUPPORTED));
@@ -3901,7 +3922,7 @@ int main(int argc, char ** argv) {
39013922 svr->Post (" /infill" , handle_infill);
39023923 svr->Post (" /embedding" , handle_embeddings); // legacy
39033924 svr->Post (" /embeddings" , handle_embeddings);
3904- svr->Post (" /v1/embeddings" , handle_embeddings );
3925+ svr->Post (" /v1/embeddings" , handle_embeddings_oai );
39053926 svr->Post (" /rerank" , handle_rerank);
39063927 svr->Post (" /reranking" , handle_rerank);
39073928 svr->Post (" /v1/rerank" , handle_rerank);
0 commit comments