Skip to content

Commit 9487165

Browse files
committed
make task creation scoped
1 parent d8656a1 commit 9487165

File tree

1 file changed

+127
-107
lines changed

1 file changed

+127
-107
lines changed

examples/server/server.cpp

Lines changed: 127 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -3633,14 +3633,17 @@ int main(int argc, char ** argv) {
36333633
}
36343634

36353635
// request slots data using task queue
3636-
server_task task(SERVER_TASK_TYPE_METRICS);
3637-
task.id = ctx_server.queue_tasks.get_new_id();
3638-
ctx_server.queue_results.add_waiting_task_id(task.id);
3639-
ctx_server.queue_tasks.post(std::move(task), true); // high-priority task
3636+
int task_id = ctx_server.queue_tasks.get_new_id();
3637+
{
3638+
server_task task(SERVER_TASK_TYPE_METRICS);
3639+
task.id = task_id;
3640+
ctx_server.queue_results.add_waiting_task_id(task_id);
3641+
ctx_server.queue_tasks.post(std::move(task), true); // high-priority task
3642+
}
36403643

36413644
// get the result
3642-
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
3643-
ctx_server.queue_results.remove_waiting_task_id(task.id);
3645+
server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
3646+
ctx_server.queue_results.remove_waiting_task_id(task_id);
36443647

36453648
if (result->is_error()) {
36463649
res_error(res, result->to_json());
@@ -3669,16 +3672,17 @@ int main(int argc, char ** argv) {
36693672
}
36703673

36713674
// request slots data using task queue
3672-
server_task task(SERVER_TASK_TYPE_METRICS);
3673-
task.id = ctx_server.queue_tasks.get_new_id();
3674-
task.metrics_reset_bucket = true;
3675-
3676-
ctx_server.queue_results.add_waiting_task_id(task.id);
3677-
ctx_server.queue_tasks.post(std::move(task), true); // high-priority task
3675+
int task_id = ctx_server.queue_tasks.get_new_id();
3676+
{
3677+
server_task task(SERVER_TASK_TYPE_METRICS);
3678+
task.id = task_id;
3679+
ctx_server.queue_results.add_waiting_task_id(task_id);
3680+
ctx_server.queue_tasks.post(std::move(task), true); // high-priority task
3681+
}
36783682

36793683
// get the result
3680-
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
3681-
ctx_server.queue_results.remove_waiting_task_id(task.id);
3684+
server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
3685+
ctx_server.queue_results.remove_waiting_task_id(task_id);
36823686

36833687
if (result->is_error()) {
36843688
res_error(res, result->to_json());
@@ -3775,17 +3779,20 @@ int main(int argc, char ** argv) {
37753779
}
37763780
std::string filepath = params.slot_save_path + filename;
37773781

3778-
server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
3779-
task.id = ctx_server.queue_tasks.get_new_id();
3780-
task.slot_action.slot_id = id_slot;
3781-
task.slot_action.filename = filename;
3782-
task.slot_action.filepath = filepath;
3782+
int task_id = ctx_server.queue_tasks.get_new_id();
3783+
{
3784+
server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
3785+
task.id = ctx_server.queue_tasks.get_new_id();
3786+
task.slot_action.slot_id = id_slot;
3787+
task.slot_action.filename = filename;
3788+
task.slot_action.filepath = filepath;
37833789

3784-
ctx_server.queue_results.add_waiting_task_id(task.id);
3785-
ctx_server.queue_tasks.post(std::move(task));
3790+
ctx_server.queue_results.add_waiting_task_id(task_id);
3791+
ctx_server.queue_tasks.post(std::move(task));
3792+
}
37863793

3787-
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
3788-
ctx_server.queue_results.remove_waiting_task_id(task.id);
3794+
server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
3795+
ctx_server.queue_results.remove_waiting_task_id(task_id);
37893796

37903797
if (result->is_error()) {
37913798
res_error(res, result->to_json());
@@ -3804,17 +3811,20 @@ int main(int argc, char ** argv) {
38043811
}
38053812
std::string filepath = params.slot_save_path + filename;
38063813

3807-
server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
3808-
task.id = ctx_server.queue_tasks.get_new_id();
3809-
task.slot_action.slot_id = id_slot;
3810-
task.slot_action.filename = filename;
3811-
task.slot_action.filepath = filepath;
3814+
int task_id = ctx_server.queue_tasks.get_new_id();
3815+
{
3816+
server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
3817+
task.id = ctx_server.queue_tasks.get_new_id();
3818+
task.slot_action.slot_id = id_slot;
3819+
task.slot_action.filename = filename;
3820+
task.slot_action.filepath = filepath;
38123821

3813-
ctx_server.queue_results.add_waiting_task_id(task.id);
3814-
ctx_server.queue_tasks.post(std::move(task));
3822+
ctx_server.queue_results.add_waiting_task_id(task_id);
3823+
ctx_server.queue_tasks.post(std::move(task));
3824+
}
38153825

3816-
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
3817-
ctx_server.queue_results.remove_waiting_task_id(task.id);
3826+
server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
3827+
ctx_server.queue_results.remove_waiting_task_id(task_id);
38183828

38193829
if (result->is_error()) {
38203830
res_error(res, result->to_json());
@@ -3826,15 +3836,18 @@ int main(int argc, char ** argv) {
38263836
};
38273837

38283838
const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
3829-
server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
3830-
task.id = ctx_server.queue_tasks.get_new_id();
3831-
task.slot_action.slot_id = id_slot;
3839+
int task_id = ctx_server.queue_tasks.get_new_id();
3840+
{
3841+
server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
3842+
task.id = ctx_server.queue_tasks.get_new_id();
3843+
task.slot_action.slot_id = id_slot;
38323844

3833-
ctx_server.queue_results.add_waiting_task_id(task.id);
3834-
ctx_server.queue_tasks.post(std::move(task));
3845+
ctx_server.queue_results.add_waiting_task_id(task_id);
3846+
ctx_server.queue_tasks.post(std::move(task));
3847+
}
38353848

3836-
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
3837-
ctx_server.queue_results.remove_waiting_task_id(task.id);
3849+
server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
3850+
ctx_server.queue_results.remove_waiting_task_id(task_id);
38383851

38393852
if (result->is_error()) {
38403853
res_error(res, result->to_json());
@@ -3938,45 +3951,48 @@ int main(int argc, char ** argv) {
39383951
}
39393952

39403953
auto completion_id = gen_chatcmplid();
3941-
std::vector<server_task> tasks;
3942-
3943-
try {
3944-
const auto & prompt = data.at("prompt");
3945-
// TODO: this log can become very long, put it behind a flag or think about a more compact format
3946-
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
3947-
3948-
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
3949-
tasks.reserve(tokenized_prompts.size());
3950-
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
3951-
server_task task = server_task(type);
3952-
3953-
task.id = ctx_server.queue_tasks.get_new_id();
3954-
task.index = i;
3955-
3956-
task.prompt_tokens = std::move(tokenized_prompts[i]);
3957-
task.params = server_task::params_from_json_cmpl(
3958-
ctx_server.ctx,
3959-
ctx_server.params_base,
3960-
data);
3961-
task.id_selected_slot = json_value(data, "id_slot", -1);
3962-
3963-
// OAI-compat
3964-
task.params.oaicompat = oaicompat;
3965-
task.params.oaicompat_cmpl_id = completion_id;
3966-
// oaicompat_model is already populated by params_from_json_cmpl
3954+
std::unordered_set<int> task_ids;
3955+
{
3956+
std::vector<server_task> tasks;
39673957

3968-
tasks.push_back(std::move(task));
3958+
try {
3959+
const auto & prompt = data.at("prompt");
3960+
// TODO: this log can become very long, put it behind a flag or think about a more compact format
3961+
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
3962+
3963+
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
3964+
tasks.reserve(tokenized_prompts.size());
3965+
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
3966+
server_task task = server_task(type);
3967+
3968+
task.id = ctx_server.queue_tasks.get_new_id();
3969+
task.index = i;
3970+
3971+
task.prompt_tokens = std::move(tokenized_prompts[i]);
3972+
task.params = server_task::params_from_json_cmpl(
3973+
ctx_server.ctx,
3974+
ctx_server.params_base,
3975+
data);
3976+
task.id_selected_slot = json_value(data, "id_slot", -1);
3977+
3978+
// OAI-compat
3979+
task.params.oaicompat = oaicompat;
3980+
task.params.oaicompat_cmpl_id = completion_id;
3981+
// oaicompat_model is already populated by params_from_json_cmpl
3982+
3983+
tasks.push_back(std::move(task));
3984+
}
3985+
} catch (const std::exception & e) {
3986+
res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
3987+
return;
39693988
}
3970-
} catch (const std::exception & e) {
3971-
res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
3972-
return;
3973-
}
39743989

3975-
ctx_server.queue_results.add_waiting_tasks(tasks);
3976-
ctx_server.queue_tasks.post(std::move(tasks));
3990+
task_ids = server_task::get_list_id(tasks);
3991+
ctx_server.queue_results.add_waiting_tasks(tasks);
3992+
ctx_server.queue_tasks.post(std::move(tasks));
3993+
}
39773994

39783995
bool stream = json_value(data, "stream", false);
3979-
const auto task_ids = server_task::get_list_id(tasks);
39803996

39813997
if (!stream) {
39823998
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
@@ -4268,6 +4284,7 @@ int main(int argc, char ** argv) {
42684284
// create and queue the task
42694285
json responses = json::array();
42704286
bool error = false;
4287+
std::unordered_set<int> task_ids;
42714288
{
42724289
std::vector<server_task> tasks;
42734290
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
@@ -4283,24 +4300,23 @@ int main(int argc, char ** argv) {
42834300
tasks.push_back(std::move(task));
42844301
}
42854302

4303+
task_ids = server_task::get_list_id(tasks);
42864304
ctx_server.queue_results.add_waiting_tasks(tasks);
42874305
ctx_server.queue_tasks.post(std::move(tasks));
4306+
}
42884307

4289-
// get the result
4290-
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
4308+
// get the result
4309+
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
4310+
for (auto & res : results) {
4311+
GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
4312+
responses.push_back(res->to_json());
4313+
}
4314+
}, [&](const json & error_data) {
4315+
res_error(res, error_data);
4316+
error = true;
4317+
}, req.is_connection_closed);
42914318

4292-
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
4293-
for (auto & res : results) {
4294-
GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
4295-
responses.push_back(res->to_json());
4296-
}
4297-
}, [&](const json & error_data) {
4298-
res_error(res, error_data);
4299-
error = true;
4300-
}, req.is_connection_closed);
4301-
4302-
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
4303-
}
4319+
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
43044320

43054321
if (error) {
43064322
return;
@@ -4367,6 +4383,7 @@ int main(int argc, char ** argv) {
43674383
// create and queue the task
43684384
json responses = json::array();
43694385
bool error = false;
4386+
std::unordered_set<int> task_ids;
43704387
{
43714388
std::vector<server_task> tasks;
43724389
std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
@@ -4379,23 +4396,21 @@ int main(int argc, char ** argv) {
43794396
tasks.push_back(std::move(task));
43804397
}
43814398

4399+
task_ids = server_task::get_list_id(tasks);
43824400
ctx_server.queue_results.add_waiting_tasks(tasks);
43834401
ctx_server.queue_tasks.post(std::move(tasks));
4384-
4385-
// get the result
4386-
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
4387-
4388-
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
4389-
for (auto & res : results) {
4390-
GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
4391-
responses.push_back(res->to_json());
4392-
}
4393-
}, [&](const json & error_data) {
4394-
res_error(res, error_data);
4395-
error = true;
4396-
}, req.is_connection_closed);
43974402
}
43984403

4404+
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
4405+
for (auto & res : results) {
4406+
GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
4407+
responses.push_back(res->to_json());
4408+
}
4409+
}, [&](const json & error_data) {
4410+
res_error(res, error_data);
4411+
error = true;
4412+
}, req.is_connection_closed);
4413+
43994414
if (error) {
44004415
return;
44014416
}
@@ -4431,14 +4446,19 @@ int main(int argc, char ** argv) {
44314446
res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
44324447
return;
44334448
}
4434-
server_task task(SERVER_TASK_TYPE_SET_LORA);
4435-
task.id = ctx_server.queue_tasks.get_new_id();
4436-
task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body);
4437-
ctx_server.queue_results.add_waiting_task_id(task.id);
4438-
ctx_server.queue_tasks.post(std::move(task));
44394449

4440-
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
4441-
ctx_server.queue_results.remove_waiting_task_id(task.id);
4450+
int task_id = ctx_server.queue_tasks.get_new_id();
4451+
{
4452+
server_task task(SERVER_TASK_TYPE_SET_LORA);
4453+
task.id = ctx_server.queue_tasks.get_new_id();
4454+
task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body);
4455+
ctx_server.queue_results.add_waiting_task_id(task_id);
4456+
ctx_server.queue_tasks.post(std::move(task));
4457+
}
4458+
4459+
// get the result
4460+
server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
4461+
ctx_server.queue_results.remove_waiting_task_id(task_id);
44424462

44434463
if (result->is_error()) {
44444464
res_error(res, result->to_json());

0 commit comments

Comments
 (0)