@@ -1122,9 +1122,10 @@ struct llama_server_context
         queue_results.send(res);
     }

-    int request_completion(json data, bool infill, bool embedding, int multitask_id)
+    void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
     {
         task_server task;
+        task.id = task_id;
         task.target_id = 0;
         task.data = std::move(data);
         task.infill_mode = infill;
@@ -1135,11 +1136,11 @@ struct llama_server_context
         // when a completion task's prompt array is not a singleton, we split it into multiple requests
         if (task.data.count("prompt") && task.data.at("prompt").size() > 1)
         {
-            return split_multiprompt_task(task);
+            split_multiprompt_task(task_id, task);
         }

         // otherwise, it's a single-prompt task, we actually queue it
-        return queue_tasks.post(task);
+        queue_tasks.post(task);
     }

     // for multiple images processing
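
Net effect of the two hunks above: request_completion() no longer allocates a task id or returns one; the caller now supplies the id and looks up the result itself. Condensed from the handler hunks later in this diff (error handling omitted), every call site now follows the same sequence:

    const int task_id = llama.queue_tasks.get_new_id();        // reserve an id up front
    llama.queue_results.add_waiting_task_id(task_id);          // register interest in the result
    llama.request_completion(task_id, data, false, false, -1); // queue the task under that id
    task_result result = llama.queue_results.recv(task_id);    // block until a result arrives
    llama.queue_results.remove_waiting_task_id(task_id);       // always release the waiting id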
@@ -1218,25 +1219,30 @@ struct llama_server_context
         queue_tasks.post(task);
     }

-    int split_multiprompt_task(task_server& multiprompt_task)
+    void split_multiprompt_task(int multitask_id, task_server& multiprompt_task)
     {
         int prompt_count = multiprompt_task.data.at("prompt").size();
         assert(prompt_count > 1);

-        int multitask_id = queue_tasks.get_new_id();
+        // generate all the ID for subtask
         std::vector<int> subtask_ids(prompt_count);
         for (int i = 0; i < prompt_count; i++)
+        {
+            subtask_ids[i] = queue_tasks.get_new_id();
+        }
+
+        // queue up the multitask so we can track its subtask progression
+        queue_tasks.add_multitask(multitask_id, subtask_ids);
+
+        // add subtasks
+        for (int i = 0; i < prompt_count; i++)
         {
             json subtask_data = multiprompt_task.data;
             subtask_data["prompt"] = subtask_data["prompt"][i];

             // subtasks inherit everything else (infill mode, embedding mode, etc.)
-            subtask_ids[i] = request_completion(subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+            request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
         }
-
-        // queue up the multitask so we can track its subtask progression
-        queue_tasks.add_multitask(multitask_id, subtask_ids);
-        return multitask_id;
     }

     void process_single_task(task_server& task)
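
The reordering inside split_multiprompt_task() is the heart of the change: all subtask ids are generated first, the multitask is registered with queue_tasks.add_multitask() before anything is posted, and only then are the subtasks queued. In the old code the subtasks were posted first and the multitask registered last, so a subtask that finished quickly could report completion before the multitask record existed. The toy program below models only that ordering constraint; it is a sketch, not the real llama_server_queue, and its update_multitask() is a stand-in for the server's actual aggregation logic (which a worker thread calls as each subtask finishes):

    // toy model of why the multitask must be registered before its subtasks can finish
    #include <cstdio>
    #include <map>
    #include <mutex>
    #include <set>
    #include <vector>

    struct toy_queue {
        std::mutex m;
        int next_id = 0;
        std::map<int, std::set<int>> multitasks; // multitask id -> pending subtask ids

        int get_new_id() {
            std::lock_guard<std::mutex> lock(m);
            return next_id++;
        }
        void add_multitask(int multitask_id, const std::vector<int> & subtask_ids) {
            std::lock_guard<std::mutex> lock(m);
            multitasks[multitask_id] = std::set<int>(subtask_ids.begin(), subtask_ids.end());
        }
        // stand-in for the per-subtask bookkeeping; with the old ordering this could
        // run before add_multitask() and the completion would be silently lost
        void update_multitask(int multitask_id, int subtask_id) {
            std::lock_guard<std::mutex> lock(m);
            auto it = multitasks.find(multitask_id);
            if (it == multitasks.end()) {
                std::printf("lost update: multitask %d is not registered yet\n", multitask_id);
                return;
            }
            it->second.erase(subtask_id);
            if (it->second.empty()) {
                std::printf("multitask %d complete\n", multitask_id);
            }
        }
    };

    int main() {
        toy_queue q;
        const int multitask_id = q.get_new_id();
        std::vector<int> subtask_ids = { q.get_new_id(), q.get_new_id() };

        q.add_multitask(multitask_id, subtask_ids); // register first, as the patch now does
        for (int id : subtask_ids) {
            q.update_multitask(multitask_id, id);   // each subtask finishing
        }
        return 0;
    }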
@@ -2493,8 +2499,9 @@ int main(int argc, char **argv)
                     return;
                 }
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, false, false, -1);
+                const int task_id = llama.queue_tasks.get_new_id();
                 llama.queue_results.add_waiting_task_id(task_id);
+                llama.request_completion(task_id, data, false, false, -1);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.queue_results.recv(task_id);
@@ -2505,9 +2512,8 @@ int main(int argc, char **argv)
                     {
                         res.status = 404;
                         res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
-                        llama.queue_results.remove_waiting_task_id(task_id);
-                        return;
                     }
+                    llama.queue_results.remove_waiting_task_id(task_id);
                 } else {
                     const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink)
                     {
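
The two hunks above also stop the non-streaming error path from returning early: remove_waiting_task_id() now runs exactly once after both the success and the error branch. A scope guard would make that invariant harder to break; the sketch below is only a hypothetical alternative for illustration, not something this patch adds, and it assumes llama.queue_results is of a type (here called llama_server_response) exposing the add/remove calls used throughout the diff:

    // hypothetical RAII guard: registers the waiting id on construction and always
    // releases it on scope exit, whichever path the handler takes
    struct waiting_task_guard {
        llama_server_response & queue;
        int task_id;
        waiting_task_guard(llama_server_response & q, int id) : queue(q), task_id(id) {
            queue.add_waiting_task_id(task_id);
        }
        ~waiting_task_guard() {
            queue.remove_waiting_task_id(task_id);
        }
    };

The streaming handlers could not use such a guard directly: the waiting id has to outlive the handler body and be released inside the chunked_content_provider lambda, which is why the patch releases it manually there.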
@@ -2546,8 +2552,9 @@ int main(int argc, char **argv)
                                 break;
                             }
                         }
-                        sink.done();
+
                         llama.queue_results.remove_waiting_task_id(task_id);
+                        sink.done();
                         return true;
                     };

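In the streaming branch the cleanup order flips: the waiting id is released before sink.done() closes the stream, and (in the infill handler further down) also before the early return when sink.write() fails. Condensed and trimmed from the streaming hunks in this diff, the provider now has roughly this shape:

    const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink)
    {
        while (true)
        {
            task_result result = llama.queue_results.recv(task_id);
            const std::string str = "data: " + result.result_json.dump() + "\n\n"; // SSE payload (simplified)
            if (!sink.write(str.c_str(), str.size()))
            {
                llama.queue_results.remove_waiting_task_id(task_id); // client disconnected
                return false;
            }
            if (result.stop || result.error)
            {
                break;
            }
        }
        llama.queue_results.remove_waiting_task_id(task_id); // stop waiting before closing the sink
        sink.done();
        return true;
    };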
@@ -2592,8 +2599,9 @@ int main(int argc, char **argv)
                 }
                 json data = oaicompat_completion_params_parse(json::parse(req.body));

-                const int task_id = llama.request_completion(data, false, false, -1);
+                const int task_id = llama.queue_tasks.get_new_id();
                 llama.queue_results.add_waiting_task_id(task_id);
+                llama.request_completion(task_id, data, false, false, -1);

                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
@@ -2608,9 +2616,8 @@ int main(int argc, char **argv)
                     } else {
                         res.status = 500;
                         res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
-                        llama.queue_results.remove_waiting_task_id(task_id);
-                        return;
                     }
+                    llama.queue_results.remove_waiting_task_id(task_id);
                 } else {
                     const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
                         while (true) {
@@ -2671,7 +2678,9 @@ int main(int argc, char **argv)
                     return;
                 }
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, true, false, -1);
+                const int task_id = llama.queue_tasks.get_new_id();
+                llama.queue_results.add_waiting_task_id(task_id);
+                llama.request_completion(task_id, data, true, false, -1);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.queue_results.recv(task_id);
@@ -2683,8 +2692,8 @@ int main(int argc, char **argv)
                     {
                         res.status = 404;
                         res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
-                        return;
                     }
+                    llama.queue_results.remove_waiting_task_id(task_id);
                 } else {
                     const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) {
                         while (true)
@@ -2700,6 +2709,7 @@ int main(int argc, char **argv)
                             });
                             if (!sink.write(str.c_str(), str.size()))
                             {
+                                llama.queue_results.remove_waiting_task_id(task_id);
                                 return false;
                             }
                             if (result.stop)
@@ -2713,8 +2723,8 @@ int main(int argc, char **argv)
                             }
                         }

+                        llama.queue_results.remove_waiting_task_id(task_id);
                         sink.done();
-
                         return true;
                     };

@@ -2788,8 +2798,16 @@ int main(int argc, char **argv)
                     image_data = "";
                 }

-                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0 }, {"image_data", image_data} }, false, true, -1);
+                // create and queue the task
+                const int task_id = llama.queue_tasks.get_new_id();
+                llama.queue_results.add_waiting_task_id(task_id);
+                llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0 }, {"image_data", image_data} }, false, true, -1);
+
+                // get the result
                 task_result result = llama.queue_results.recv(task_id);
+                llama.queue_results.remove_waiting_task_id(task_id);
+
+                // send the result
                 return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
             });

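All of the handlers above register the waiting id before posting the task and release it once they are done with results. The toy program below models one plausible reason for that ordering (a response queue that drops results nobody is waiting for); it is a sketch, not the real llama_server_response, whose internals may differ:

    #include <condition_variable>
    #include <cstdio>
    #include <map>
    #include <mutex>
    #include <set>
    #include <string>

    struct toy_response_queue {
        std::mutex m;
        std::condition_variable cv;
        std::set<int> waiting;              // ids some handler is waiting on
        std::map<int, std::string> results; // delivered results by task id

        void add_waiting_task_id(int id)    { std::lock_guard<std::mutex> lock(m); waiting.insert(id); }
        void remove_waiting_task_id(int id) { std::lock_guard<std::mutex> lock(m); waiting.erase(id); results.erase(id); }

        // worker side: keep only results that someone is waiting for
        void send(int id, std::string result) {
            std::lock_guard<std::mutex> lock(m);
            if (waiting.count(id) == 0) {
                std::printf("dropping result for unknown task %d\n", id);
                return;
            }
            results[id] = std::move(result);
            cv.notify_all();
        }

        // handler side: block until a result for this id arrives
        std::string recv(int id) {
            std::unique_lock<std::mutex> lock(m);
            cv.wait(lock, [&] { return results.count(id) > 0; });
            return results[id];
        }
    };

    int main() {
        toy_response_queue q;
        const int task_id = 0;
        q.add_waiting_task_id(task_id);          // register first, as the handlers do
        q.send(task_id, "{\"content\": \"ok\"}"); // a worker finishing the task
        std::printf("%s\n", q.recv(task_id).c_str());
        q.remove_waiting_task_id(task_id);       // release the id when done
        return 0;
    }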