Commit 8f36df8

server: fix a race condition caused by "request_completion"

1 parent d083c81 commit 8f36df8
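The race: request_completion() used to assign the task id only when queue_tasks.post() ran, then return it, so a fast worker could finish the task and publish its result before the HTTP handler had called add_waiting_task_id() with that id, leaving the result with no registered receiver. The fix reserves the id first, registers the waiter, and only then queues the task. A minimal sketch of the two orderings, condensed from the hunks below (not the literal server source):

{   // before (racy): the id exists only once the task is already visible to workers
    const int task_id = llama.request_completion(data, false, false, -1);
    llama.queue_results.add_waiting_task_id(task_id); // a worker may already have sent the result
}

{   // after (fixed): the waiter is registered before the task is queued
    const int task_id = llama.queue_tasks.get_new_id();
    llama.queue_results.add_waiting_task_id(task_id);
    llama.request_completion(task_id, data, false, false, -1);
}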

File tree

2 files changed (+44, -24 lines)

examples/server/server.cpp

Lines changed: 39 additions & 21 deletions
@@ -1122,9 +1122,10 @@ struct llama_server_context
         queue_results.send(res);
     }
 
-    int request_completion(json data, bool infill, bool embedding, int multitask_id)
+    void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
     {
         task_server task;
+        task.id = task_id;
         task.target_id = 0;
         task.data = std::move(data);
         task.infill_mode = infill;
@@ -1135,11 +1136,11 @@ struct llama_server_context
         // when a completion task's prompt array is not a singleton, we split it into multiple requests
         if (task.data.count("prompt") && task.data.at("prompt").size() > 1)
         {
-            return split_multiprompt_task(task);
+            split_multiprompt_task(task_id, task);
         }
 
         // otherwise, it's a single-prompt task, we actually queue it
-        return queue_tasks.post(task);
+        queue_tasks.post(task);
     }
 
     // for multiple images processing
@@ -1218,25 +1219,30 @@ struct llama_server_context
         queue_tasks.post(task);
     }
 
-    int split_multiprompt_task(task_server& multiprompt_task)
+    void split_multiprompt_task(int multitask_id, task_server& multiprompt_task)
     {
         int prompt_count = multiprompt_task.data.at("prompt").size();
         assert(prompt_count > 1);
 
-        int multitask_id = queue_tasks.get_next_id();
+        // generate an id for each subtask
         std::vector<int> subtask_ids(prompt_count);
         for (int i = 0; i < prompt_count; i++)
+        {
+            subtask_ids[i] = queue_tasks.get_new_id();
+        }
+
+        // queue up the multitask so we can track its subtask progression
+        queue_tasks.add_multitask(multitask_id, subtask_ids);
+
+        // add subtasks
+        for (int i = 0; i < prompt_count; i++)
         {
             json subtask_data = multiprompt_task.data;
             subtask_data["prompt"] = subtask_data["prompt"][i];
 
             // subtasks inherit everything else (infill mode, embedding mode, etc.)
-            subtask_ids[i] = request_completion(subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+            request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
         }
-
-        // queue up the multitask so we can track its subtask progression
-        queue_tasks.add_multitask(multitask_id, subtask_ids);
-        return multitask_id;
     }
 
     void process_single_task(task_server& task)
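The multitask path had the same hazard in reverse: add_multitask() used to run only after the subtasks were already posted, so an early-finishing subtask could report to a multitask record that did not exist yet. The reordered sequence, condensed from the hunk above as a sketch:

    // 1. reserve every subtask id up front
    std::vector<int> subtask_ids(prompt_count);
    for (int i = 0; i < prompt_count; i++) {
        subtask_ids[i] = queue_tasks.get_new_id();
    }
    // 2. register the multitask before any subtask can complete
    queue_tasks.add_multitask(multitask_id, subtask_ids);
    // 3. only now queue the subtasks under their pre-assigned ids
    for (int i = 0; i < prompt_count; i++) {
        json subtask_data = multiprompt_task.data;
        subtask_data["prompt"] = subtask_data["prompt"][i];
        request_completion(subtask_ids[i], subtask_data,
                           multiprompt_task.infill_mode,
                           multiprompt_task.embedding_mode, multitask_id);
    }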
@@ -2493,8 +2499,9 @@ int main(int argc, char **argv)
                 return;
             }
             json data = json::parse(req.body);
-            const int task_id = llama.request_completion(data, false, false, -1);
+            const int task_id = llama.queue_tasks.get_new_id();
             llama.queue_results.add_waiting_task_id(task_id);
+            llama.request_completion(task_id, data, false, false, -1);
             if (!json_value(data, "stream", false)) {
                 std::string completion_text;
                 task_result result = llama.queue_results.recv(task_id);
@@ -2505,9 +2512,8 @@ int main(int argc, char **argv)
                 {
                     res.status = 404;
                     res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
-                    llama.queue_results.remove_waiting_task_id(task_id);
-                    return;
                 }
+                llama.queue_results.remove_waiting_task_id(task_id);
             } else {
                 const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink)
                 {
@@ -2546,8 +2552,9 @@ int main(int argc, char **argv)
                         break;
                     }
                 }
-                sink.done();
+
                 llama.queue_results.remove_waiting_task_id(task_id);
+                sink.done();
                 return true;
             };

@@ -2592,8 +2599,9 @@ int main(int argc, char **argv)
             }
             json data = oaicompat_completion_params_parse(json::parse(req.body));
 
-            const int task_id = llama.request_completion(data, false, false, -1);
+            const int task_id = llama.queue_tasks.get_new_id();
             llama.queue_results.add_waiting_task_id(task_id);
+            llama.request_completion(task_id, data, false, false, -1);
 
             if (!json_value(data, "stream", false)) {
                 std::string completion_text;
@@ -2608,9 +2616,8 @@ int main(int argc, char **argv)
             } else {
                 res.status = 500;
                 res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
-                llama.queue_results.remove_waiting_task_id(task_id);
-                return;
             }
+            llama.queue_results.remove_waiting_task_id(task_id);
         } else {
             const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
                 while (true) {
@@ -2671,7 +2678,9 @@ int main(int argc, char **argv)
                 return;
             }
             json data = json::parse(req.body);
-            const int task_id = llama.request_completion(data, true, false, -1);
+            const int task_id = llama.queue_tasks.get_new_id();
+            llama.queue_results.add_waiting_task_id(task_id);
+            llama.request_completion(task_id, data, true, false, -1);
             if (!json_value(data, "stream", false)) {
                 std::string completion_text;
                 task_result result = llama.queue_results.recv(task_id);
@@ -2683,8 +2692,8 @@ int main(int argc, char **argv)
                 {
                     res.status = 404;
                     res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
-                    return;
                 }
+                llama.queue_results.remove_waiting_task_id(task_id);
             } else {
                 const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) {
                     while (true)
@@ -2700,6 +2709,7 @@ int main(int argc, char **argv)
                         });
                         if (!sink.write(str.c_str(), str.size()))
                         {
+                            llama.queue_results.remove_waiting_task_id(task_id);
                             return false;
                         }
                         if (result.stop)
if (result.stop)
@@ -2713,8 +2723,8 @@ int main(int argc, char **argv)
27132723
}
27142724
}
27152725

2726+
llama.queue_results.remove_waiting_task_id(task_id);
27162727
sink.done();
2717-
27182728
return true;
27192729
};
27202730

@@ -2788,8 +2798,16 @@ int main(int argc, char **argv)
                 image_data = "";
             }
 
-            const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
+            // create and queue the task
+            const int task_id = llama.queue_tasks.get_new_id();
+            llama.queue_results.add_waiting_task_id(task_id);
+            llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
+
+            // get the result
             task_result result = llama.queue_results.recv(task_id);
+            llama.queue_results.remove_waiting_task_id(task_id);
+
+            // send the result
             return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
         });

examples/server/utils.hpp

Lines changed: 5 additions & 3 deletions
@@ -203,7 +203,9 @@ struct llama_server_queue {
     // Add a new task to the end of the queue
     int post(task_server task) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
-        task.id = id++;
+        if (task.id == -1) {
+            task.id = id++;
+        }
         queue_tasks.push_back(std::move(task));
         condition_tasks.notify_one();
         return task.id;
@@ -215,8 +217,8 @@ struct llama_server_queue {
         queue_tasks_deferred.push_back(std::move(task));
     }
 
-    // Get the next task id
-    int get_next_id() {
+    // Get the next id for creating a new task
+    int get_new_id() {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         return id++;
     }
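The utils.hpp side is what lets a task arrive in post() with its id already reserved: task.id == -1 now means "unassigned", and only in that case does post() allocate an id under the queue mutex, so legacy callers keep working. The resulting endpoint lifecycle, condensed from the server.cpp hunks above (a sketch; error handling and the streaming branch omitted):

    const int task_id = llama.queue_tasks.get_new_id();        // reserve an id under the mutex
    llama.queue_results.add_waiting_task_id(task_id);          // register as a waiter first
    llama.request_completion(task_id, data, false, false, -1); // only then make the task visible
    task_result result = llama.queue_results.recv(task_id);    // result can no longer be dropped
    llama.queue_results.remove_waiting_task_id(task_id);       // always clean up the waiter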
