@@ -92,6 +92,7 @@ enum server_task_type {
 enum server_task_cmpl_type {
     SERVER_TASK_CMPL_TYPE_NORMAL,
     SERVER_TASK_CMPL_TYPE_EMBEDDING,
+    SERVER_TASK_CMPL_TYPE_RERANK,
     SERVER_TASK_CMPL_TYPE_INFILL,
 };

@@ -172,6 +173,7 @@ struct server_slot {
     std::vector<completion_token_output> generated_token_probs;

     server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+
     bool has_next_token = true;
     bool truncated      = false;
     bool stopped_eos    = false;
@@ -954,8 +956,17 @@ struct server_context {
                 slot.prompt = *prompt;
             } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
                 slot.prompt = prompt->at(0);
+            } else if (prompt->is_array() && prompt->size() > 1) {
+                // array of strings
+                for (const auto & el : *prompt) {
+                    if (!el.is_string()) {
+                        send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
+                        return false;
+                    }
+                }
+                slot.prompt = *prompt;
             } else {
-                send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST);
+                send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
                 return false;
             }
         }
@@ -1389,6 +1400,7 @@ struct server_context {

                 res.data = json {
                     {"embedding", std::vector<float>(n_embd, 0.0f)},
+                    {"index",     slot.index},
                 };

                 continue;
@@ -1407,6 +1419,44 @@ struct server_context {
         queue_results.send(res);
     }

+    void send_rank(const server_slot & slot, const llama_batch & batch) {
+        server_task_result res;
+        res.id    = slot.id_task;
+        res.error = false;
+        res.stop  = true;
+
+        for (int i = 0; i < batch.n_tokens; ++i) {
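+            // only the output-flagged token of this slot's sequence carries the
+            // pooled result (slot sequences are offset by 1; seq 0 is reserved
+            // for the system prompt)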
+            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
+                continue;
+            }
+
+            const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+            if (embd == NULL) {
+                embd = llama_get_embeddings_ith(ctx, i);
+            }
+
+            if (embd == NULL) {
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
+
+                res.data = json {
+                    {"index", slot.index},
+                    {"rank",  -1e6},
+                };
+
+                continue;
+            }
+
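+            // with rank pooling the whole sequence collapses to a single value;
+            // the first float of the pooled embedding is the relevance score of
+            // this (query, document) pair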
+            res.data = json {
+                {"index", slot.index},
+                {"rank",  embd[0]},
+            };
+        }
+
+        SLT_DBG(slot, "sending rank, res = '%s'\n", res.data.dump().c_str());
+
+        queue_results.send(res);
+    }
+
     //
     // Functions to create new task(s) and receive result(s)
     //
@@ -1442,13 +1492,23 @@ struct server_context {
         // otherwise, it's a multiple-prompt task, we break it into smaller tasks
         else if (prompt.is_array()) {
             std::vector<json> prompts = prompt;
-            for (size_t i = 0; i < prompts.size(); i++) {
-                const auto & e = prompts[i];
-                if (e.is_string() || json_is_array_of_numbers(e)) {
-                    data["index"] = i;
-                    create_task(data, true, e);
-                } else {
-                    throw std::runtime_error(error_msg);
+            if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                for (size_t i = 1; i < prompts.size(); i++) {
+                    json qd;
+                    qd.push_back(prompts[0]);
+                    qd.push_back(prompts[i]);
+                    data["index"] = i - 1;
+                    create_task(data, true, qd);
+                }
+            } else {
+                for (size_t i = 0; i < prompts.size(); i++) {
+                    const auto & e = prompts[i];
+                    if (e.is_string() || json_is_array_of_numbers(e)) {
+                        data["index"] = i;
+                        create_task(data, true, e);
+                    } else {
+                        throw std::runtime_error(error_msg);
+                    }
                 }
             }
         }
@@ -1492,7 +1552,9 @@ struct server_context {
                 return;
             }

-            size_t idx = result.data["index"];
+            const size_t idx = result.data["index"];
+            GGML_ASSERT(idx < results.size() && "index out of range");
+
             results[idx] = result;
         }
         result_handler(results);
@@ -1951,6 +2013,29 @@ struct server_context {
                     }

                     prompt_tokens = embd_inp;
+                } else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                    // require slot.prompt to be array of 2 strings
+                    if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
+                        SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
+                        slot.release();
+                        send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
+                        continue;
+                    }
+
+                    // prompt: <s>query</s><s>doc</s>
+                    prompt_tokens.clear();
+                    prompt_tokens.push_back(llama_token_bos(model));
+                    {
+                        const auto part = tokenize(slot.prompt[0], false);
+                        prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                    }
+                    prompt_tokens.push_back(llama_token_eos(model));
+                    prompt_tokens.push_back(llama_token_bos(model));
+                    {
+                        const auto part = tokenize(slot.prompt[1], false);
+                        prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                    }
+                    prompt_tokens.push_back(llama_token_eos(model));
                 } else {
                     prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
                 }
@@ -1970,7 +2055,7 @@ struct server_context {
                    continue;
                }

-               if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
+               if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                    // this prompt is too large to process - discard it
                    if (slot.n_prompt_tokens > n_ubatch) {
                        slot.release();
@@ -2048,15 +2133,18 @@ struct server_context {
                        slot.n_prompt_tokens_processed = 0;
                    }

-                   if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
+                   if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                        // cannot fit the prompt in the current batch - will try next iter
                        if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                            continue;
                        }
                    }

                    // check that we are in the right batch_type, if not defer the slot
-                   bool slot_type = slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ? 1 : 0;
+                   const bool slot_type =
+                       slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
+                       slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
+
                    if (batch_type == -1) {
                        batch_type = slot_type;
                    } else if (batch_type != slot_type) {
@@ -2229,6 +2317,13 @@ struct server_context {
                        continue; // continue loop of slots
                    }

+                   if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                       send_rank(slot, batch_view);
+                       slot.release();
+                       slot.i_batch = -1;
+                       continue; // continue loop of slots
+                   }
+
                    // prompt evaluated for next-token prediction
                    slot.state = SLOT_STATE_GENERATING;
                } else if (slot.state != SLOT_STATE_GENERATING) {
@@ -3023,6 +3118,82 @@ int main(int argc, char ** argv) {
        res_ok(res, root);
    };

+   const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+       const json body = json::parse(req.body);
+
+       // TODO: implement
+       // int top_n = 1;
+       // if (body.count("top_n") != 1) {
+       //     top_n = body.at("top_n");
+       // } else {
+       //     res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+       //     return;
+       // }
+
+       json query;
+       if (body.count("query") == 1) {
+           query = body.at("query");
+           if (!query.is_string()) {
+               res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
+               return;
+           }
+       } else {
+           res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+           return;
+       }
+
+       json documents;
+       if (body.count("documents") != 0) {
+           documents = body.at("documents");
+           if (!documents.is_array() || documents.size() == 0) {
+               res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
+               return;
+           }
+       } else {
+           res_error(res, format_error_response("\"documents\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+           return;
+       }
+
+       // construct prompt object: array of ["query", "doc0", "doc1", ...]
+       json prompt;
+       prompt.push_back(query);
+       for (const auto & doc : documents) {
+           prompt.push_back(doc);
+       }
+
+       LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str());
+
+       // create and queue the task
+       json responses = json::array();
+       bool error = false;
+       {
+           std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
+           ctx_server.queue_results.add_waiting_tasks(tasks);
+           ctx_server.queue_tasks.post(tasks);
+
+           // get the result
+           std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+
+           ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+               for (const auto & res : results) {
+                   responses.push_back(res.data);
+               }
+           }, [&](const json & error_data) {
+               res_error(res, error_data);
+               error = true;
+           });
+       }
+
+       if (error) {
+           return;
+       }
+
+       // write JSON response
+       json root = format_response_rerank(body, responses);
+       res_ok(res, root);
+   };
+
    const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
        json result = json::array();
        for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
@@ -3119,6 +3290,7 @@ int main(int argc, char ** argv) {
    svr->Post("/embedding",           handle_embeddings); // legacy
    svr->Post("/embeddings",          handle_embeddings);
    svr->Post("/v1/embeddings",       handle_embeddings);
+   svr->Post("/v1/rerank",           handle_rerank);
    svr->Post("/tokenize",            handle_tokenize);
    svr->Post("/detokenize",          handle_detokenize);
    // LoRA adapters hotswap
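
For reference, a minimal client sketch against the new endpoint. It is illustrative only: it assumes a llama-server running locally on port 8080, reuses the same cpp-httplib and nlohmann::json headers the server itself builds on, and prints the raw response because format_response_rerank is not part of this diff. The "query" and "documents" fields are the two inputs validated by handle_rerank above.

// rerank_client.cpp - illustrative sketch, not part of the commit
#include <iostream>

#include "httplib.h"
#include "json.hpp"

using json = nlohmann::json;

int main() {
    // placeholder host/port for a server started with a reranking model
    httplib::Client cli("http://localhost:8080");

    // "query" and "documents" are the two fields handle_rerank requires
    const json body = {
        {"query", "what is panda?"},
        {"documents", json::array({
            "hi",
            "The giant panda, sometimes called a panda bear, is a bear species endemic to China."
        })},
    };

    auto res = cli.Post("/v1/rerank", body.dump(), "application/json");
    if (!res || res->status != 200) {
        std::cerr << "rerank request failed\n";
        return 1;
    }

    // the response layout comes from format_response_rerank (outside this
    // diff); the per-document entries carry the "index" and "rank" fields
    // filled in by send_rank, so we just print the raw body here
    std::cout << res->body << std::endl;
    return 0;
}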