
Commit 6235c62

server : add rerank endpoint
ggml-ci
1 parent 125a067 commit 6235c62

3 files changed: +209 additions, -14 deletions

common/arg.cpp

Lines changed: 1 addition & 1 deletion
@@ -1093,7 +1093,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         }
     ).set_sparam());
     add_opt(llama_arg(
-        {"--pooling"}, "{none,mean,cls,last}",
+        {"--pooling"}, "{none,mean,cls,last, rank}",
         "pooling type for embeddings, use model default if unspecified",
         [](gpt_params & params, const std::string & value) {
             /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
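Note: this hunk only updates the help text for --pooling; the value parser below it presumably gains a matching branch for "rank". A rough sketch of that branch (not part of this hunk; LLAMA_POOLING_TYPE_RANK and the exact error handling are assumed):

            // sketch only - the branch handling the new "rank" value lies outside this hunk
            /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
            else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
            else if (value == "cls")  { params.pooling_type = LLAMA_POOLING_TYPE_CLS;  }
            else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
            else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; } // assumed new branch
            else { throw std::invalid_argument("invalid value"); }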

examples/server/server.cpp

Lines changed: 184 additions & 12 deletions
@@ -92,6 +92,7 @@ enum server_task_type {
 enum server_task_cmpl_type {
     SERVER_TASK_CMPL_TYPE_NORMAL,
     SERVER_TASK_CMPL_TYPE_EMBEDDING,
+    SERVER_TASK_CMPL_TYPE_RERANK,
     SERVER_TASK_CMPL_TYPE_INFILL,
 };

@@ -172,6 +173,7 @@ struct server_slot {
     std::vector<completion_token_output> generated_token_probs;
 
     server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+
     bool has_next_token = true;
     bool truncated = false;
     bool stopped_eos = false;
@@ -954,8 +956,17 @@ struct server_context {
                 slot.prompt = *prompt;
             } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
                 slot.prompt = prompt->at(0);
+            } else if (prompt->is_array() && prompt->size() > 1) {
+                // array of strings
+                for (const auto & el : *prompt) {
+                    if (!el.is_string()) {
+                        send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
+                        return false;
+                    }
+                }
+                slot.prompt = *prompt;
             } else {
-                send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST);
+                send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
                 return false;
             }
         }
@@ -1389,6 +1400,7 @@ struct server_context {
 
                 res.data = json {
                     {"embedding", std::vector<float>(n_embd, 0.0f)},
+                    {"index", slot.index},
                 };
 
                 continue;
@@ -1407,6 +1419,44 @@ struct server_context {
         queue_results.send(res);
     }
 
+    void send_rank(const server_slot & slot, const llama_batch & batch) {
+        server_task_result res;
+        res.id = slot.id_task;
+        res.error = false;
+        res.stop = true;
+
+        for (int i = 0; i < batch.n_tokens; ++i) {
+            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
+                continue;
+            }
+
+            const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+            if (embd == NULL) {
+                embd = llama_get_embeddings_ith(ctx, i);
+            }
+
+            if (embd == NULL) {
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
+
+                res.data = json {
+                    {"index", slot.index},
+                    {"rank", -1e6},
+                };
+
+                continue;
+            }
+
+            res.data = json {
+                {"index", slot.index},
+                {"rank", embd[0]},
+            };
+        }
+
+        SLT_DBG(slot, "sending rank, res = '%s'\n", res.data.dump().c_str());
+
+        queue_results.send(res);
+    }
+
     //
     // Functions to create new task(s) and receive result(s)
     //
@@ -1442,13 +1492,23 @@ struct server_context {
         // otherwise, it's a multiple-prompt task, we break it into smaller tasks
         else if (prompt.is_array()) {
             std::vector<json> prompts = prompt;
-            for (size_t i = 0; i < prompts.size(); i++) {
-                const auto & e = prompts[i];
-                if (e.is_string() || json_is_array_of_numbers(e)) {
-                    data["index"] = i;
-                    create_task(data, true, e);
-                } else {
-                    throw std::runtime_error(error_msg);
+            if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                for (size_t i = 1; i < prompts.size(); i++) {
+                    json qd;
+                    qd.push_back(prompts[0]);
+                    qd.push_back(prompts[i]);
+                    data["index"] = i - 1;
+                    create_task(data, true, qd);
+                }
+            } else {
+                for (size_t i = 0; i < prompts.size(); i++) {
+                    const auto & e = prompts[i];
+                    if (e.is_string() || json_is_array_of_numbers(e)) {
+                        data["index"] = i;
+                        create_task(data, true, e);
+                    } else {
+                        throw std::runtime_error(error_msg);
+                    }
                 }
             }
         }
@@ -1492,7 +1552,9 @@ struct server_context {
                 return;
             }
 
-            size_t idx = result.data["index"];
+            const size_t idx = result.data["index"];
+            GGML_ASSERT(idx < results.size() && "index out of range");
+
             results[idx] = result;
         }
         result_handler(results);
@@ -1951,6 +2013,29 @@ struct server_context {
                     }
 
                     prompt_tokens = embd_inp;
+                } else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                    // require slot.prompt to be array of 2 strings
+                    if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
+                        SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
+                        slot.release();
+                        send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
+                        continue;
+                    }
+
+                    // prompt: <s>query</s><s>doc</s>
+                    prompt_tokens.clear();
+                    prompt_tokens.push_back(llama_token_bos(model));
+                    {
+                        const auto part = tokenize(slot.prompt[0], false);
+                        prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                    }
+                    prompt_tokens.push_back(llama_token_eos(model));
+                    prompt_tokens.push_back(llama_token_bos(model));
+                    {
+                        const auto part = tokenize(slot.prompt[1], false);
+                        prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                    }
+                    prompt_tokens.push_back(llama_token_eos(model));
                 } else {
                     prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
                 }
@@ -1970,7 +2055,7 @@ struct server_context {
                     continue;
                 }
 
-                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
+                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                     // this prompt is too large to process - discard it
                     if (slot.n_prompt_tokens > n_ubatch) {
                         slot.release();
@@ -2048,15 +2133,18 @@ struct server_context {
                     slot.n_prompt_tokens_processed = 0;
                 }
 
-                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
+                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                     // cannot fit the prompt in the current batch - will try next iter
                     if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                         continue;
                     }
                 }
 
                 // check that we are in the right batch_type, if not defer the slot
-                bool slot_type = slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ? 1 : 0;
+                const bool slot_type =
+                    slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
+                    slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
+
                 if (batch_type == -1) {
                     batch_type = slot_type;
                 } else if (batch_type != slot_type) {
@@ -2229,6 +2317,13 @@ struct server_context {
                         continue; // continue loop of slots
                     }
 
+                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                        send_rank(slot, batch_view);
+                        slot.release();
+                        slot.i_batch = -1;
+                        continue; // continue loop of slots
+                    }
+
                     // prompt evaluated for next-token prediction
                     slot.state = SLOT_STATE_GENERATING;
                 } else if (slot.state != SLOT_STATE_GENERATING) {
@@ -3023,6 +3118,82 @@ int main(int argc, char ** argv) {
         res_ok(res, root);
     };
 
+    const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+        const json body = json::parse(req.body);
+
+        // TODO: implement
+        //int top_n = 1;
+        //if (body.count("top_n") != 1) {
+        //    top_n = body.at("top_n");
+        //} else {
+        //    res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+        //    return;
+        //}
+
+        json query;
+        if (body.count("query") == 1) {
+            query = body.at("query");
+            if (!query.is_string()) {
+                res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
+                return;
+            }
+        } else {
+            exit(0);
+            res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+
+        json documents;
+        if (body.count("documents") != 0) {
+            documents = body.at("documents");
+            if (!documents.is_array() || documents.size() == 0) {
+                res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
+                return;
+            }
+        } else {
+            res_error(res, format_error_response("\"documents\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+
+        // construct prompt object: array of ["query", "doc0", "doc1", ...]
+        json prompt;
+        prompt.push_back(query);
+        for (const auto & doc : documents) {
+            prompt.push_back(doc);
+        }
+
+        LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str());
+
+        // create and queue the task
+        json responses = json::array();
+        bool error = false;
+        {
+            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
+            ctx_server.queue_results.add_waiting_tasks(tasks);
+            ctx_server.queue_tasks.post(tasks);
+
+            // get the result
+            std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+
+            ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+                for (const auto & res : results) {
+                    responses.push_back(res.data);
+                }
+            }, [&](const json & error_data) {
+                res_error(res, error_data);
+                error = true;
+            });
+        }
+
+        if (error) {
+            return;
+        }
+
+        // write JSON response
+        json root = format_response_rerank(body, responses);
+        res_ok(res, root);
+    };
+
     const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
         json result = json::array();
         for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
@@ -3119,6 +3290,7 @@ int main(int argc, char ** argv) {
     svr->Post("/embedding", handle_embeddings); // legacy
     svr->Post("/embeddings", handle_embeddings);
     svr->Post("/v1/embeddings", handle_embeddings);
+    svr->Post("/v1/rerank", handle_rerank);
     svr->Post("/tokenize", handle_tokenize);
     svr->Post("/detokenize", handle_detokenize);
     // LoRA adapters hotswap
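For reference, a minimal client-side sketch of the new endpoint (not part of the commit), assuming a server reachable at localhost:8080 and reusing the same cpp-httplib and nlohmann::json libraries the server example already vendors; the request fields match what handle_rerank reads and the response fields match what format_response_rerank emits:

#include "httplib.h"         // cpp-httplib, bundled with the server example
#include <nlohmann/json.hpp>
#include <cstdio>

using json = nlohmann::json;

int main() {
    httplib::Client cli("localhost", 8080); // assumed host/port

    const json body = {
        {"query", "what is panda?"},
        {"documents", json::array({"hi", "it is a bear", "the giant panda is a bear native to China"})},
    };

    auto res = cli.Post("/v1/rerank", body.dump(), "application/json");
    if (res && res->status == 200) {
        // expected shape (per format_response_rerank):
        //   {"model": ..., "object": "list", "usage": {...},
        //    "results": [{"index": 0, "relevance_score": ...}, ...]}
        printf("%s\n", res->body.c_str());
    }
    return 0;
}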

examples/server/utils.hpp

Lines changed: 24 additions & 1 deletion
@@ -537,7 +537,7 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
     json res = json {
         {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", "list"},
-        {"usage", json {
+        {"usage", json { // TODO: fill
             {"prompt_tokens", 0},
             {"total_tokens", 0}
         }},
@@ -547,6 +547,29 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
     return res;
 }
 
+static json format_response_rerank(const json & request, const json & ranks) {
+    json data = json::array();
+    int i = 0;
+    for (const auto & rank : ranks) {
+        data.push_back(json{
+            {"index", i++},
+            {"relevance_score", json_value(rank, "rank", 0.0)},
+        });
+    }
+
+    json res = json {
+        {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+        {"object", "list"},
+        {"usage", json { // TODO: fill
+            {"prompt_tokens", 0},
+            {"total_tokens", 0}
+        }},
+        {"results", data}
+    };
+
+    return res;
+}
+
 static bool is_valid_utf8(const std::string & str) {
     const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
     const unsigned char* end = bytes + str.length();
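A small usage sketch of format_response_rerank (hypothetical inputs; assumes the includes and helpers already present in utils.hpp):

    // hypothetical request body and per-task results collected from the rerank tasks
    json request = {{"model", "some-reranker"}};

    json ranks = json::array();
    ranks.push_back({{"index", 0}, {"rank", -4.2}});
    ranks.push_back({{"index", 1}, {"rank", 7.9}});

    // produces: {"model": "some-reranker", "object": "list", "usage": {...},
    //            "results": [{"index": 0, "relevance_score": -4.2},
    //                        {"index": 1, "relevance_score": 7.9}]}
    json out = format_response_rerank(request, ranks);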
