diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 846ef7e5fee4f..81d16fb107530 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -341,6 +341,65 @@ struct llama_client_slot
             {"t_total", t_prompt_processing + t_token_generation},
         });
     }
+
+    // context extension via Self-Extend
+    void grp_attn_update_params() {
+        int grpa_i = 0;
+        // copy to local variables
+        int32_t grpa_n = ga_n;
+        int32_t grpa_w = ga_w;
+        int32_t slot_npast = 0;
+        for (int k = 0; k < n_past; ++k)
+        {
+            while (slot_npast >= grpa_i + grpa_w) {
+                const int bd = (grpa_w/grpa_n)*(grpa_n - 1);
+                slot_npast -= bd;
+                grpa_i += grpa_w/grpa_n;
+            }
+            slot_npast++;
+        }
+        n_past_se = slot_npast;
+        ga_i = grpa_i;
+    }
+
+    int32_t grp_attn_calc_npast() {
+        int32_t slot_npast = n_past_se > 0 ? n_past_se : n_past;
+        // copy to local variables
+        int32_t grpa_i = ga_i;
+        int32_t grpa_n = ga_n;
+        int32_t grpa_w = ga_w;
+        while (slot_npast >= grpa_i + grpa_w) {
+            const int bd = (grpa_w/grpa_n)*(grpa_n - 1);
+            slot_npast -= bd;
+            grpa_i += grpa_w/grpa_n;
+        }
+        return slot_npast;
+    }
+
+    void grp_attn_shift(llama_context * ctx, const int32_t n_tokens) {
+        while (n_past_se >= ga_i + ga_w)
+        {
+            const int ib = (ga_n * ga_i) / ga_w;
+            const int bd = (ga_w / ga_n) * (ga_n - 1);
+            const int dd = (ga_w / ga_n) - ib * bd - ga_w;
+
+            LOG_TEE("\n");
+            LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i, n_past_se, ib * bd, ga_i + ib * bd, n_past_se + ib * bd);
+            LOG_TEE("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib * bd, ga_i + ib * bd + ga_w, ga_n, (ga_i + ib * bd) / ga_n, (ga_i + ib * bd + ga_w) / ga_n);
+            LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib * bd + ga_w, n_past_se + ib * bd, dd, ga_i + ib * bd + ga_w + dd, n_past_se + ib * bd + dd);
+
+            llama_kv_cache_seq_add(ctx, id, ga_i, n_past_se, ib * bd);
+            llama_kv_cache_seq_div(ctx, id, ga_i + ib * bd, ga_i + ib * bd + ga_w,ga_n);
+            llama_kv_cache_seq_add(ctx, id, ga_i + ib * bd + ga_w,n_past_se + ib * bd, dd);
+
+            n_past_se -= bd;
+
+            ga_i += ga_w / ga_n;
+
+            LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", n_past_se + bd, n_past_se, ga_i);
+        }
+        n_past_se += n_tokens;
+    }
 };
 
 struct llama_metrics {
@@ -1120,13 +1179,23 @@ struct llama_server_context
         return slot.images.size() > 0;
     }
 
-    void send_error(task_server& task, const std::string &error)
+    void send_error(task_server &task, const std::string &error)
     {
-        LOG_TEE("task %i - error: %s\n", task.id, error.c_str());
+        send_error(task.id, task.multitask_id, error);
+    }
+
+    void send_error(llama_client_slot &slot, const std::string &error)
+    {
+        send_error(slot.task_id, slot.multitask_id, error);
+    }
+
+    void send_error(int task_id, int multitask_id, const std::string &error)
+    {
+        LOG_TEE("task %i - error: %s\n", task_id, error.c_str());
         task_result res;
-        res.id = task.id;
-        res.multitask_id = task.multitask_id;
-        res.stop = false;
+        res.id = task_id;
+        res.multitask_id = multitask_id;
+        res.stop = true;
         res.error = true;
         res.result_json = { { "content", error } };
         queue_results.send(res);
@@ -1593,7 +1662,9 @@ struct llama_server_context
         queue_results.send(result);
     }
 
-    bool update_slots() {
+    void run_slots() {
+        bool has_next_response = false; // whether to schedule next slot run, to generate next token
+
         if (system_need_update)
         {
             LOG_INFO("updating system prompt", {});
@@ -1609,15 +1680,9 @@ struct llama_server_context
                 LOG_INFO("all slots are idle and system prompt is empty, clear the KV cache", {});
                 kv_cache_clear();
             }
-            return true;
+            return;
         }
 
-        LOG_VERBOSE("posting NEXT_RESPONSE", {});
-        task_server task;
-        task.type = TASK_TYPE_NEXT_RESPONSE;
-        task.target_id = -1;
-        queue_tasks.post(task);
-
         for (llama_client_slot &slot : slots)
         {
             if (slot.ga_n == 1)
@@ -1815,21 +1880,8 @@ struct llama_server_context
 
                 if (slot.ga_n != 1)
                 {
-                    int ga_i = 0;
-                    int32_t ga_n = slot.ga_n;
-                    int32_t ga_w = slot.ga_w;
-                    int32_t slot_npast = 0;
-                    for (int k = 0; k < slot.n_past; ++k)
-                    {
-                        while (slot_npast >= ga_i + ga_w) {
-                            const int bd = (ga_w/ga_n)*(ga_n - 1);
-                            slot_npast -= bd;
-                            ga_i += ga_w/ga_n;
-                        }
-                        slot_npast++;
-                    }
-                    slot.n_past_se = slot_npast;
-                    slot.ga_i = ga_i;
+                    // context extension via Self-Extend
+                    slot.grp_attn_update_params();
                 }
 
                 LOG_INFO("slot progression", {
@@ -1875,22 +1927,16 @@ struct llama_server_context
                 // process the prefix of first image
                 std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
 
-                int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-
-                int32_t ga_i = slot.ga_i;
-                int32_t ga_n = slot.ga_n;
-                int32_t ga_w = slot.ga_w;
+                int32_t slot_npast = slot.n_past;
 
                 for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past)
                 {
                     if (slot.ga_n != 1)
                     {
-                        while (slot_npast >= ga_i + ga_w) {
-                            const int bd = (ga_w/ga_n)*(ga_n - 1);
-                            slot_npast -= bd;
-                            ga_i += ga_w/ga_n;
-                        }
+                        // context extension via Self-Extend
+                        slot_npast = slot.grp_attn_calc_npast();
                     }
+
                     llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
                     slot_npast++;
                 }
@@ -1901,10 +1947,8 @@ struct llama_server_context
                         "slot_id", slot.id,
                         "task_id", slot.task_id,
                     });
-                    // FIXME @phymbert: to be properly tested
-                    // early returning without changing the slot state will block the slot for ever
-                    // no one at the moment is checking the return value
-                    return false;
+                    send_error(slot, "failed processing images");
+                    continue;
                 }
 
                 // extract the logits only for the last token
@@ -1922,9 +1966,9 @@ struct llama_server_context
         if (batch.n_tokens == 0)
         {
             all_slots_are_idle = true;
-            return true;
         }
 
+        // loop of n_batch
        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
        {
            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
@@ -1934,28 +1978,9 @@ struct llama_server_context
                 if (slot.ga_n != 1)
                 {
                     // context extension via Self-Extend
-                    while (slot.n_past_se >= slot.ga_i + slot.ga_w)
-                    {
-                        const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
-                        const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
-                        const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
-
-                        LOG_TEE("\n");
-                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
-                        LOG_TEE("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
-                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
-
-                        llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
-                        llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
-                        llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
-
-                        slot.n_past_se -= bd;
-
-                        slot.ga_i += slot.ga_w / slot.ga_n;
-
-                        LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
-                    }
-                    slot.n_past_se += n_tokens;
+                    // TODO @ngxson: What happens if we're retrying with a smaller n_batch?
+                    // By the second retry, "grp_attn_shift" has already been called
+                    slot.grp_attn_shift(ctx, n_tokens);
                 }
             }
 
@@ -1979,7 +2004,13 @@ struct llama_server_context
                 {
                     // if you get here, it means the KV cache is full - try increasing it via the context size
                     LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
-                    return false;
+                    for (auto & slot : slots)
+                    {
+                        send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
+                        slot.release();
+                    }
+                    has_next_response = false;
+                    break; // break loop of n_batch
                 }
 
                 LOG_TEE("%s : failed to find free space in the KV cache, retrying with smaller n_batch = %d\n", __func__, n_batch / 2);
@@ -1987,14 +2018,15 @@ struct llama_server_context
                 // retry with half the batch size to try to find a free slot in the KV cache
                 n_batch /= 2;
                 i -= n_batch;
-                continue;
+                continue; // continue loop of n_batch
             }
 
+            // loop of slots
             for (auto & slot : slots)
             {
                 if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens))
                 {
-                    continue;
+                    continue; // continue loop of slots
                 }
 
                 // prompt evaluated for embedding
@@ -2003,7 +2035,7 @@ struct llama_server_context
                     send_embedding(slot);
                     slot.release();
                     slot.i_batch = -1;
-                    continue;
+                    continue; // continue loop of slots
                 }
 
                 completion_token_output result;
@@ -2042,16 +2074,25 @@ struct llama_server_context
                     metrics.on_prediction(slot);
                 }
 
+                // if the slot has not yet finished its work, schedule the next run
+                if (slot.has_next_token)
+                {
+                    has_next_response = true;
+                }
+
                 slot.i_batch = -1;
             }
         }
 
-        LOG_VERBOSE("slots updated", {});
-        return true;
-    }
+        if (has_next_response) {
+            LOG_VERBOSE("schedule next slot run", {});
+            task_server task;
+            task.type = TASK_TYPE_NEXT_RESPONSE;
+            task.target_id = -1;
+            queue_tasks.post(task);
+        }
 
-    void run_on_all_tasks_finished() {
-        update_slots();
+        LOG_VERBOSE("slots run completed", {});
     }
 };
 
@@ -3494,7 +3535,7 @@ int main(int argc, char **argv)
                     bool running = true;
                     while (running)
                     {
-                        running = llama.update_slots();
+                        running = llama.run_slots();
                     }
                 }*/
                 //);
@@ -3516,8 +3557,8 @@ int main(int argc, char **argv)
         &llama_server_context::process_single_task, &llama, std::placeholders::_1));
     llama.queue_tasks.on_finish_multitask(std::bind(
         &llama_server_context::on_finish_multitask, &llama, std::placeholders::_1));
-    llama.queue_tasks.on_all_tasks_finished(std::bind(
-        &llama_server_context::run_on_all_tasks_finished, &llama));
+    llama.queue_tasks.on_run_slots(std::bind(
+        &llama_server_context::run_slots, &llama));
     llama.queue_results.on_multitask_update(std::bind(
         &llama_server_queue::update_multitask, &llama.queue_tasks,
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index d7abd7cbba71c..2cf50ab689d2d 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -37,10 +37,6 @@ extern bool server_log_json;
 #define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__)
 #define LOG_INFO(   MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
-//
-// parallel
-//
-
 enum server_state {
     SERVER_STATE_LOADING_MODEL,  // Server is starting up, model not fully loaded yet
     SERVER_STATE_READY,          // Server is ready and model is loaded
     SERVER_STATE_ERROR           // An error occurred, load_model failed
@@ -250,7 +246,7 @@ struct llama_server_queue {
     // callback functions
     std::function<void(task_server&)> callback_new_task;
     std::function<void(task_multi&)>  callback_finish_multitask;
-    std::function<void(void)>         callback_all_task_finished;
+    std::function<void(void)>         callback_run_slots;
 
     // Add a new task to the end of the queue
     int post(task_server task) {
@@ -283,14 +279,14 @@ struct llama_server_queue {
         callback_new_task = callback;
     }
 
-    // Register function to process a multitask
+    // Register function to process a multitask when it is finished
     void on_finish_multitask(std::function<void(task_multi&)> callback) {
         callback_finish_multitask = callback;
     }
 
-    // Register the function to be called when the batch of tasks is finished
-    void on_all_tasks_finished(std::function<void(void)> callback) {
-        callback_all_task_finished = callback;
+    // Register the function to be called when all slot data is ready to be processed
+    void on_run_slots(std::function<void(void)> callback) {
+        callback_run_slots = callback;
     }
 
     // Call when the state of one slot is changed
@@ -312,7 +308,13 @@ struct llama_server_queue {
         condition_tasks.notify_all();
     }
 
-    // Start the main loop.
+    /**
+     * The main loop consists of these steps:
+     * - Wait until a new task arrives
+     * - Process the task (i.e. maybe copy data into a slot)
+     * - Check if multitask is finished
+     * - Run all slots
+     */
     void start_loop() {
         running = true;
         while (true) {
@@ -331,8 +333,8 @@ struct llama_server_queue {
                 LOG_VERBOSE("callback_new_task", {{"task_id", task.id}});
                 callback_new_task(task);
             }
-            LOG_VERBOSE("callback_all_task_finished", {});
-            // process and update all the multitasks
+            LOG_VERBOSE("update_multitasks", {});
+            // check if we have any finished multitasks
             auto queue_iterator = queue_multitasks.begin();
             while (queue_iterator != queue_multitasks.end())
             {
@@ -349,8 +351,9 @@ struct llama_server_queue {
                     ++queue_iterator;
                 }
             }
-            // all tasks in the current loop is finished
-            callback_all_task_finished();
+            // all tasks in the current loop are processed, slot data is now ready
+            LOG_VERBOSE("callback_run_slots", {});
+            callback_run_slots();
         }
         LOG_VERBOSE("wait for new task", {});
         // wait for new task
@@ -408,16 +411,26 @@ struct llama_server_response {
     std::mutex mutex_results;
     std::condition_variable condition_results;
 
+    // add the task_id to the list of tasks waiting for a response
     void add_waiting_task_id(int task_id) {
         LOG_VERBOSE("waiting for task id", {{"task_id", task_id}});
         std::unique_lock<std::mutex> lock(mutex_results);
         waiting_task_ids.insert(task_id);
     }
 
+    // when the request is finished, we can remove the task associated with it
     void remove_waiting_task_id(int task_id) {
         LOG_VERBOSE("remove waiting for task id", {{"task_id", task_id}});
         std::unique_lock<std::mutex> lock(mutex_results);
         waiting_task_ids.erase(task_id);
+        // also clear pending results, just in case
+        for (int i = 0; i < (int) queue_results.size(); i++)
+        {
+            if (queue_results[i].id == task_id)
+            {
+                queue_results.erase(queue_results.begin() + i);
+            }
+        }
     }
 
     // This function blocks the thread until there is a response for this task_id
@@ -435,6 +448,7 @@
             {
                 assert(queue_results[i].multitask_id == -1);
                 task_result res = queue_results[i];
+                LOG_VERBOSE("got task result", {{"task_id", res.id}});
                 queue_results.erase(queue_results.begin() + i);
                 return res;
             }
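
Note on the Self-Extend helpers factored out above: grp_attn_update_params() and grp_attn_calc_npast() apply the same remapping step, and every time the running position crosses a ga_w-wide window the positions behind it shrink by bd = (ga_w/ga_n)*(ga_n - 1), which matches the shrinkage grp_attn_shift() later applies to the KV cache via llama_kv_cache_seq_div(). The standalone sketch below is my own illustration, not part of the patch; the function name self_extend_npast and the ga_n/ga_w values are made up for the example. It simply replays the patch's arithmetic in isolation so the mapping can be inspected outside the server:

// Standalone sketch, not part of the patch: replays the Self-Extend position
// remapping used by grp_attn_update_params()/grp_attn_calc_npast() above.
// The function name and the example ga_n/ga_w values are made up for illustration.
#include <cstdint>
#include <cstdio>
#include <initializer_list>

// map an absolute token position to its Self-Extend (group-attention) position
static int32_t self_extend_npast(int32_t n_past, int32_t ga_n, int32_t ga_w) {
    int32_t ga_i       = 0;  // start of the current window in remapped coordinates
    int32_t slot_npast = 0;  // remapped position being accumulated
    for (int32_t k = 0; k < n_past; ++k) {
        while (slot_npast >= ga_i + ga_w) {
            // each completed window of ga_w positions is divided by ga_n,
            // so the remapped past shrinks by bd = ga_w - ga_w/ga_n
            const int32_t bd = (ga_w/ga_n)*(ga_n - 1);
            slot_npast -= bd;
            ga_i       += ga_w/ga_n;
        }
        slot_npast++;
    }
    return slot_npast;
}

int main() {
    const int32_t ga_n = 4, ga_w = 512;  // example values, not server defaults
    for (int32_t n_past : {256, 1024, 4096, 16384}) {
        printf("n_past = %6d -> n_past_se = %6d\n", n_past, self_extend_npast(n_past, ga_n, ga_w));
    }
    return 0;
}

With ga_n = 4 and ga_w = 512 the remapped position grows at roughly 1/ga_n of the original rate once past the first window, which is what keeps a long prompt within the position range the model was trained on.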
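On the scheduling change: update_slots() used to post a TASK_TYPE_NEXT_RESPONSE task unconditionally at the top and signal failure through a bool return value that, per the removed FIXME, nobody checked. run_slots() now reports failures per slot via send_error() and only re-posts a NEXT_RESPONSE task when some slot still has_next_token. The sketch below is a heavily simplified illustration of that handshake, not the server's real types; all names (Task, Queue, tokens_left) are invented, and the real start_loop() blocks on a condition variable waiting for new tasks instead of exiting when idle:

// Standalone sketch, not part of the patch: the shape of the new scheduling
// handshake between the task queue and run_slots().
#include <cstdio>
#include <deque>
#include <functional>

enum task_type { TASK_NEXT_RESPONSE };

struct Task { task_type type; };

struct Queue {
    std::deque<Task>      tasks;
    std::function<void()> on_run_slots; // called once all queued tasks are processed

    void post(Task t) { tasks.push_back(t); }

    void start_loop() {
        while (!tasks.empty()) {
            // drain the queue (process_single_task in the real server)
            while (!tasks.empty()) {
                tasks.pop_front();
            }
            // slot data is ready -> run one decode step
            on_run_slots();
        }
    }
};

int main() {
    Queue queue;
    int tokens_left = 3; // stand-in for "some slot still has_next_token"

    queue.on_run_slots = [&]() {
        printf("decode step, %d token(s) left\n", tokens_left);
        if (--tokens_left > 0) {
            // the equivalent of run_slots() posting TASK_TYPE_NEXT_RESPONSE
            queue.post({TASK_NEXT_RESPONSE});
        }
    };

    queue.post({TASK_NEXT_RESPONSE}); // e.g. a completion request just arrived
    queue.start_loop();               // ends once no slot re-posts
    return 0;
}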