From 72b12c3be7f7d8b2e0d1fb703e6d6973caff6493 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen
Date: Mon, 29 Jan 2024 12:58:17 -0800
Subject: [PATCH] Bump llama.cpp to b1999

This requires an upstream change to support graceful termination,
carried as a patch.
---
 llm/ext_server/ext_server.cpp | 57 +++++++++++++---------
 llm/llama.cpp                 |  2 +-
 llm/patches/01-cache.diff     |  8 ++--
 llm/patches/02-shutdown.diff  | 90 +++++++++++++++++++++++++++++++++++
 4 files changed, 130 insertions(+), 27 deletions(-)
 create mode 100644 llm/patches/02-shutdown.diff

diff --git a/llm/ext_server/ext_server.cpp b/llm/ext_server/ext_server.cpp
index 635a1f68..b59b46d2 100644
--- a/llm/ext_server/ext_server.cpp
+++ b/llm/ext_server/ext_server.cpp
@@ -26,13 +26,13 @@
 
 // Expose the llama server as a callable extern "C" API
 llama_server_context *llama = NULL;
-std::atomic<bool> ext_server_running(false);
 std::thread ext_server_thread;
 
 void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) {
   assert(err != NULL && sparams != NULL);
   log_set_target(stderr);
   if (!sparams->verbose_logging) {
+    server_verbose = true;
     log_disable();
   }
 
@@ -122,18 +122,23 @@ void llama_server_start() {
   assert(llama != NULL);
   // TODO mutex to protect thread creation
   ext_server_thread = std::thread([&]() {
-    ext_server_running = true;
     try {
       LOG_TEE("llama server main loop starting\n");
       ggml_time_init();
-      while (ext_server_running.load()) {
-        if (!llama->update_slots()) {
-          LOG_TEE(
-              "unexpected error in llama server update_slots - exiting main "
-              "loop\n");
-          break;
-        }
-      }
+      llama->queue_tasks.on_new_task(std::bind(
+          &llama_server_context::process_single_task, llama, std::placeholders::_1));
+      llama->queue_tasks.on_finish_multitask(std::bind(
+          &llama_server_context::on_finish_multitask, llama, std::placeholders::_1));
+      llama->queue_tasks.on_all_tasks_finished(std::bind(
+          &llama_server_context::run_on_all_tasks_finished, llama));
+      llama->queue_results.on_multitask_update(std::bind(
+          &llama_server_queue::update_multitask,
+          &llama->queue_tasks,
+          std::placeholders::_1,
+          std::placeholders::_2,
+          std::placeholders::_3
+      ));
+      llama->queue_tasks.start_loop();
     } catch (std::exception &e) {
       LOG_TEE("caught exception in llama server main loop: %s\n", e.what());
     } catch (...) {
@@ -146,13 +151,10 @@
 
 void llama_server_stop() {
   assert(llama != NULL);
-  // TODO - too verbose, remove once things are solid
-  LOG_TEE("requesting llama server shutdown\n");
-  ext_server_running = false;
-
-  // unblocks the update_slots() loop so it can clean up and exit
-  llama->request_cancel(0);
-
+  LOG_TEE("\ninitiating shutdown - draining remaining tasks...\n");
+  // This may take a while for any pending tasks to drain
+  // TODO - consider a timeout to cancel tasks if it's taking too long
+  llama->queue_tasks.terminate();
   ext_server_thread.join();
   delete llama;
   llama = NULL;
@@ -165,7 +167,9 @@ void llama_server_completion(const char *json_req, ext_server_resp_t *resp) {
   resp->msg[0] = '\0';
   try {
     json data = json::parse(json_req);
-    resp->id = llama->request_completion(data, false, false, -1);
+    resp->id = llama->queue_tasks.get_new_id();
+    llama->queue_results.add_waiting_task_id(resp->id);
+    llama->request_completion(resp->id, data, false, false, -1);
   } catch (std::exception &e) {
     snprintf(resp->msg, resp->msg_len, "exception %s", e.what());
   } catch (...) {
@@ -183,16 +187,22 @@ void llama_server_completion_next_result(const int task_id,
   resp->json_resp = NULL;
   std::string result_json;
   try {
-    task_result result = llama->next_result(task_id);
+    task_result result = llama->queue_results.recv(task_id);
     result_json = result.result_json.dump(-1, ' ', false,
                                           json::error_handler_t::replace);
     resp->id = result.id;
     resp->stop = result.stop;
     resp->error = result.error;
     if (result.error) {
+      LOG_TEE("next result cancel on error\n");
       llama->request_cancel(task_id);
+      LOG_TEE("next result removing waiting task ID: %d\n", task_id);
+      llama->queue_results.remove_waiting_task_id(task_id);
     } else if (result.stop) {
+      LOG_TEE("next result cancel on stop\n");
       llama->request_cancel(task_id);
+      LOG_TEE("next result removing waiting task ID: %d\n", task_id);
+      llama->queue_results.remove_waiting_task_id(task_id);
     }
   } catch (std::exception &e) {
     resp->error = true;
@@ -223,6 +233,7 @@ void llama_server_completion_cancel(const int task_id, ext_server_resp_t *err) {
   err->msg[0] = '\0';
   try {
     llama->request_cancel(task_id);
+    llama->queue_results.remove_waiting_task_id(task_id);
   } catch (std::exception &e) {
     err->id = -1;
     snprintf(err->msg, err->msg_len, "exception %s", e.what());
@@ -307,13 +318,15 @@ void llama_server_embedding(const char *json_req, char **json_resp,
     } else {
       prompt = "";
     }
-    const int task_id = llama->request_completion(
-        {{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
-    task_result result = llama->next_result(task_id);
+    const int task_id = llama->queue_tasks.get_new_id();
+    llama->queue_results.add_waiting_task_id(task_id);
+    llama->request_completion(task_id, {{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
+    task_result result = llama->queue_results.recv(task_id);
     std::string result_json = result.result_json.dump();
     const std::string::size_type size = result_json.size() + 1;
     *json_resp = new char[size];
     snprintf(*json_resp, size, "%s", result_json.c_str());
+    llama->queue_results.remove_waiting_task_id(task_id);
   } catch (std::exception &e) {
     err->id = -1;
     snprintf(err->msg, err->msg_len, "exception %s", e.what());
diff --git a/llm/llama.cpp b/llm/llama.cpp
index cd4fddb2..d2f650cb 160000
--- a/llm/llama.cpp
+++ b/llm/llama.cpp
@@ -1 +1 @@
-Subproject commit cd4fddb29f81d6a1f6d51a0c016bc6b486d68def
+Subproject commit d2f650cb5b04ee2726663e79b47da5efe196ce00
diff --git a/llm/patches/01-cache.diff b/llm/patches/01-cache.diff
index f8392495..79f8d002 100644
--- a/llm/patches/01-cache.diff
+++ b/llm/patches/01-cache.diff
@@ -1,8 +1,8 @@
 diff --git a/examples/server/server.cpp b/examples/server/server.cpp
-index 0462fbd2..4fa7b57f 100644
+index a48582ad..9fffffd8 100644
 --- a/examples/server/server.cpp
 +++ b/examples/server/server.cpp
-@@ -1857,12 +1857,6 @@ struct llama_server_context
+@@ -1564,12 +1564,6 @@ struct llama_server_context
                      LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
                  }
 
@@ -15,8 +15,8 @@ index 0462fbd2..4fa7b57f 100644
                  if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0)
                  {
                      // we have to evaluate at least 1 token to generate logits.
-@@ -1870,6 +1864,12 @@ struct llama_server_context
-                     slot.n_past--;
+@@ -1581,6 +1575,12 @@ struct llama_server_context
+                 }
                  }
 
 +                LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past);
diff --git a/llm/patches/02-shutdown.diff b/llm/patches/02-shutdown.diff
new file mode 100644
index 00000000..4c247cc0
--- /dev/null
+++ b/llm/patches/02-shutdown.diff
@@ -0,0 +1,90 @@
+diff --git a/examples/server/server.cpp b/examples/server/server.cpp
+index 11dd82c3..311495a8 100644
+--- a/examples/server/server.cpp
++++ b/examples/server/server.cpp
+@@ -28,6 +28,7 @@
+ #include
+ #include
+ #include
++#include <signal.h>
+
+ using json = nlohmann::json;
+
+@@ -2394,6 +2395,9 @@ static void append_to_generated_text_from_generated_token_probs(llama_server_context
+     }
+ }
+
++std::function<void(int)> shutdown_handler;
++inline void signal_handler(int signal) { shutdown_handler(signal); }
++
+ int main(int argc, char **argv)
+ {
+ #if SERVER_VERBOSE != 1
+@@ -3014,8 +3018,14 @@ int main(int argc, char **argv)
+         std::placeholders::_2,
+         std::placeholders::_3
+     ));
+-    llama.queue_tasks.start_loop();
+
++    shutdown_handler = [&](int) {
++        llama.queue_tasks.terminate();
++    };
++    signal(SIGTERM, signal_handler);
++    signal(SIGINT, signal_handler);
++    llama.queue_tasks.start_loop();
++    svr.stop();
+     t.join();
+
+     llama_backend_free();
+diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
+index 70cce072..2acb1eab 100644
+--- a/examples/server/utils.hpp
++++ b/examples/server/utils.hpp
+@@ -6,6 +6,7 @@
+ #include
+ #include
+ #include
++#include <atomic>
+
+ #include "json.hpp"
+
+@@ -190,6 +191,7 @@ inline std::string format_chatml(std::vector<json> messages)
+ struct llama_server_queue {
+     int id = 0;
+     std::mutex mutex_tasks;
++    std::atomic<bool> running;
+     // queues
+     std::vector<task_server> queue_tasks;
+     std::vector<task_server> queue_tasks_deferred;
+@@ -248,9 +250,15 @@ struct llama_server_queue {
+         queue_tasks_deferred.clear();
+     }
+
+-    // Start the main loop. This call is blocking
+-    [[noreturn]]
++    // end the start_loop routine
++    void terminate() {
++        running = false;
++        condition_tasks.notify_all();
++    }
++
++    // Start the main loop.
+     void start_loop() {
++        running = true;
+         while (true) {
+             // new task arrived
+             LOG_VERBOSE("have new task", {});
+@@ -294,8 +302,12 @@ struct llama_server_queue {
+             {
+                 std::unique_lock<std::mutex> lock(mutex_tasks);
+                 if (queue_tasks.empty()) {
++                    if (!running.load()) {
++                        LOG_VERBOSE("ending start_loop", {});
++                        return;
++                    }
+                     condition_tasks.wait(lock, [&]{
+-                        return !queue_tasks.empty();
++                        return (!queue_tasks.empty() || !running.load());
+                     });
+                 }
+             }
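
Illustrative sketch (not code from this patch or from llama.cpp): a minimal, self-contained C++ program showing the graceful-termination pattern that 02-shutdown.diff adds to llama_server_queue. The toy_task_queue type and its integer tasks are invented for illustration; the real queue stores task_server objects and sets `running` at the top of start_loop(), whereas this sketch starts with `running` already true so the self-terminating demo avoids a startup race. The shape is otherwise the same: terminate() flips an atomic flag and notifies the condition variable, the blocking loop drains whatever work is left and returns, and a file-scope std::function bridges the C signal handler to the queue object.

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <csignal>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>

struct toy_task_queue {                 // hypothetical type, not from the patch
    std::mutex mutex_tasks;
    std::condition_variable condition_tasks;
    std::queue<int> tasks;
    std::atomic<bool> running{true};    // the patch sets this in start_loop() instead

    // Enqueue work and wake the loop.
    void post(int task) {
        {
            std::lock_guard<std::mutex> lock(mutex_tasks);
            tasks.push(task);
        }
        condition_tasks.notify_one();
    }

    // Mirrors llama_server_queue::terminate(): flip the flag, then wake every
    // waiter so start_loop() can observe it and return.
    void terminate() {
        running = false;
        condition_tasks.notify_all();
    }

    // Blocking loop, analogous to start_loop(): drain pending tasks, and once
    // the queue is empty either wait for more work or exit after terminate().
    void start_loop() {
        while (true) {
            std::unique_lock<std::mutex> lock(mutex_tasks);
            condition_tasks.wait(lock, [&] {
                return !tasks.empty() || !running.load();
            });
            if (tasks.empty() && !running.load()) {
                return;  // graceful exit: nothing left to drain
            }
            int task = tasks.front();
            tasks.pop();
            lock.unlock();
            std::cout << "processed task " << task << "\n";
        }
    }
};

// Same shape as the patch: a file-scope std::function lets the C-style signal
// handler reach the queue object captured in main().
std::function<void(int)> shutdown_handler;
inline void signal_handler(int sig) { shutdown_handler(sig); }

int main() {
    toy_task_queue queue;
    shutdown_handler = [&](int) { queue.terminate(); };
    std::signal(SIGINT, signal_handler);
    std::signal(SIGTERM, signal_handler);

    for (int i = 0; i < 3; i++) {
        queue.post(i);
    }

    // Simulate an external `kill` shortly after startup so the demo exits on
    // its own; pressing Ctrl-C would exercise the same path.
    std::thread killer([] {
        std::this_thread::sleep_for(std::chrono::milliseconds(100));
        std::raise(SIGTERM);
    });

    queue.start_loop();  // blocks until terminate() is observed
    killer.join();
    std::cout << "graceful shutdown complete\n";
    return 0;
}

As with the patch's shutdown_handler, terminate() ends up running inside a signal handler; condition_variable::notify_all() is not formally async-signal-safe, a shortcut this sketch copies deliberately because the handler only fires once at shutdown.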