@@ -1122,9 +1122,10 @@ struct llama_server_context
         queue_results.send(res);
     }
 
-    int request_completion(json data, bool infill, bool embedding, int multitask_id)
+    void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
     {
         task_server task;
+        task.id = task_id;
         task.target_id = 0;
         task.data = std::move(data);
         task.infill_mode = infill;
@@ -1135,11 +1136,11 @@ struct llama_server_context
         // when a completion task's prompt array is not a singleton, we split it into multiple requests
         if (task.data.count("prompt") && task.data.at("prompt").size() > 1)
         {
-            return split_multiprompt_task(task);
+            split_multiprompt_task(task_id, task);
         }
 
         // otherwise, it's a single-prompt task, we actually queue it
-        return queue_tasks.post(task);
+        queue_tasks.post(task);
     }
 
     // for multiple images processing
@@ -1218,25 +1219,30 @@ struct llama_server_context
         queue_tasks.post(task);
     }
 
-    int split_multiprompt_task(task_server& multiprompt_task)
+    void split_multiprompt_task(int multitask_id, task_server& multiprompt_task)
     {
         int prompt_count = multiprompt_task.data.at("prompt").size();
         assert(prompt_count > 1);
 
-        int multitask_id = queue_tasks.get_next_id();
+        // generate the IDs for all subtasks
         std::vector<int> subtask_ids(prompt_count);
         for (int i = 0; i < prompt_count; i++)
+        {
+            subtask_ids[i] = queue_tasks.get_new_id();
+        }
+
+        // queue up the multitask so we can track its subtask progression
+        queue_tasks.add_multitask(multitask_id, subtask_ids);
+
+        // add subtasks
+        for (int i = 0; i < prompt_count; i++)
         {
             json subtask_data = multiprompt_task.data;
             subtask_data["prompt"] = subtask_data["prompt"][i];
 
             // subtasks inherit everything else (infill mode, embedding mode, etc.)
-            subtask_ids[i] = request_completion(subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+            request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
         }
-
-        // queue up the multitask so we can track its subtask progression
-        queue_tasks.add_multitask(multitask_id, subtask_ids);
-        return multitask_id;
     }
 
     void process_single_task(task_server& task)
@@ -2493,8 +2499,9 @@ int main(int argc, char **argv)
                     return;
                 }
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, false, false, -1);
+                const int task_id = llama.queue_tasks.get_new_id();
                 llama.queue_results.add_waiting_task_id(task_id);
+                llama.request_completion(task_id, data, false, false, -1);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.queue_results.recv(task_id);
@@ -2505,9 +2512,8 @@ int main(int argc, char **argv)
                     {
                         res.status = 404;
                         res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
-                        llama.queue_results.remove_waiting_task_id(task_id);
-                        return;
                     }
+                    llama.queue_results.remove_waiting_task_id(task_id);
                 } else {
                     const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink)
                     {
@@ -2546,8 +2552,9 @@ int main(int argc, char **argv)
                                 break;
                             }
                         }
-                        sink.done();
+
                         llama.queue_results.remove_waiting_task_id(task_id);
+                        sink.done();
                         return true;
                     };
 
@@ -2592,8 +2599,9 @@ int main(int argc, char **argv)
                 }
                 json data = oaicompat_completion_params_parse(json::parse(req.body));
 
-                const int task_id = llama.request_completion(data, false, false, -1);
+                const int task_id = llama.queue_tasks.get_new_id();
                 llama.queue_results.add_waiting_task_id(task_id);
+                llama.request_completion(task_id, data, false, false, -1);
 
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
@@ -2608,9 +2616,8 @@ int main(int argc, char **argv)
                     } else {
                         res.status = 500;
                         res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
-                        llama.queue_results.remove_waiting_task_id(task_id);
-                        return;
                     }
+                    llama.queue_results.remove_waiting_task_id(task_id);
                 } else {
                     const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
                         while (true) {
@@ -2671,7 +2678,9 @@ int main(int argc, char **argv)
                     return;
                 }
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, true, false, -1);
+                const int task_id = llama.queue_tasks.get_new_id();
+                llama.queue_results.add_waiting_task_id(task_id);
+                llama.request_completion(task_id, data, true, false, -1);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.queue_results.recv(task_id);
@@ -2683,8 +2692,8 @@ int main(int argc, char **argv)
                    {
                        res.status = 404;
                        res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
-                       return;
                    }
+                   llama.queue_results.remove_waiting_task_id(task_id);
                } else {
                    const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) {
                        while (true)
@@ -2700,6 +2709,7 @@ int main(int argc, char **argv)
                            });
                            if (!sink.write(str.c_str(), str.size()))
                            {
+                               llama.queue_results.remove_waiting_task_id(task_id);
                                return false;
                            }
                            if (result.stop)
@@ -2713,8 +2723,8 @@ int main(int argc, char **argv)
                            }
                        }
 
+                       llama.queue_results.remove_waiting_task_id(task_id);
                        sink.done();
-
                        return true;
                    };
 
@@ -2788,8 +2798,16 @@ int main(int argc, char **argv)
                    image_data = "";
                }
 
-               const int task_id = llama.request_completion({ {"prompt", prompt}, {"n_predict", 0}, {"image_data", image_data} }, false, true, -1);
+               // create and queue the task
+               const int task_id = llama.queue_tasks.get_new_id();
+               llama.queue_results.add_waiting_task_id(task_id);
+               llama.request_completion(task_id, { {"prompt", prompt}, {"n_predict", 0}, {"image_data", image_data} }, false, true, -1);
+
+               // get the result
                task_result result = llama.queue_results.recv(task_id);
+               llama.queue_results.remove_waiting_task_id(task_id);
+
+               // send the result
                return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
            });
 
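Taken together, the hunks above move task-id allocation out of request_completion() and into the callers: each HTTP handler now calls queue_tasks.get_new_id(), registers the id with queue_results.add_waiting_task_id() before the task is posted, and calls remove_waiting_task_id() on every exit path. Below is a minimal sketch of that pattern for the non-streaming case, using only the llama_server_context members that appear in this diff; the helper name run_blocking_completion is invented for illustration and is not part of server.cpp.

// Sketch only: assumes the llama_server_context / task_result types and the
// queue member functions shown in the diff above; run_blocking_completion is
// a hypothetical helper, not code from the repository.
static json run_blocking_completion(llama_server_context & llama, json data)
{
    // the caller now allocates the task id itself...
    const int task_id = llama.queue_tasks.get_new_id();

    // ...and registers it with the result queue *before* posting the task,
    // so the result cannot arrive while nobody is waiting for it
    llama.queue_results.add_waiting_task_id(task_id);

    // queue the work (not infill, not embedding, no parent multitask)
    llama.request_completion(task_id, std::move(data), false, false, -1);

    // block until the result for exactly this task id arrives
    task_result result = llama.queue_results.recv(task_id);

    // always unregister the id so the waiting list does not leak entries
    llama.queue_results.remove_waiting_task_id(task_id);

    return result.result_json;
}

Streaming handlers follow the same shape, except that remove_waiting_task_id() must also run on the early-return path when sink.write() fails, which is what the hunks around the chunked content providers add.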