From cd108e641dbdedd8c5641c4cec1762f751f38136 Mon Sep 17 00:00:00 2001
From: Behnam M <58621210+ibehnam@users.noreply.github.com>
Date: Wed, 10 Jan 2024 14:56:05 -0500
Subject: [PATCH] server : add a `/health` endpoint (#4860)

* added /health endpoint to the server

* added comments on the additional /health endpoint

* Better handling of server state

When the model is being loaded, the server state is `LOADING_MODEL`. If model-loading fails, the server state becomes `ERROR`, otherwise it becomes `READY`. The `/health` endpoint provides more granular messages now according to the server_state value.

* initialized server_state

* fixed a typo

* starting http server before initializing the model

* Update server.cpp

* Update server.cpp

* fixes

* fixes

* fixes

* made ServerState atomic and turned two-line spaces into one-line
---
 examples/server/server.cpp | 199 +++++++++++++++++++++----------------
 1 file changed, 113 insertions(+), 86 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 6c7fcd176c87f..1cca634d5461f 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -26,6 +26,7 @@
 #include <mutex>
 #include <chrono>
 #include <condition_variable>
+#include <atomic>
 
 #ifndef SERVER_VERBOSE
 #define SERVER_VERBOSE 1
@@ -146,6 +147,12 @@ static std::vector<uint8_t> base64_decode(const std::string & encoded_string)
 // parallel
 //
 
+enum ServerState {
+    LOADING_MODEL,  // Server is starting up, model not fully loaded yet
+    READY,          // Server is ready and model is loaded
+    ERROR           // An error occurred, load_model failed
+};
+
 enum task_type {
     COMPLETION_TASK,
     CANCEL_TASK
@@ -2453,7 +2460,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
     }
 }
 
-
 static std::string random_string()
 {
     static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
@@ -2790,15 +2796,117 @@ int main(int argc, char **argv)
                                 {"system_info", llama_print_system_info()},
                             });
 
-    // load the model
-    if (!llama.load_model(params))
+    httplib::Server svr;
+
+    std::atomic<ServerState> server_state{LOADING_MODEL};
+
+    svr.set_default_headers({{"Server", "llama.cpp"},
+                             {"Access-Control-Allow-Origin", "*"},
+                             {"Access-Control-Allow-Headers", "content-type"}});
+
+    svr.Get("/health", [&](const httplib::Request&, httplib::Response& res) {
+        ServerState current_state = server_state.load();
+        switch(current_state) {
+            case READY:
+                res.set_content(R"({"status": "ok"})", "application/json");
+                res.status = 200; // HTTP OK
+                break;
+            case LOADING_MODEL:
+                res.set_content(R"({"status": "loading model"})", "application/json");
+                res.status = 503; // HTTP Service Unavailable
+                break;
+            case ERROR:
+                res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json");
+                res.status = 500; // HTTP Internal Server Error
+                break;
+        }
+    });
+
+    svr.set_logger(log_server_request);
+
+    svr.set_exception_handler([](const httplib::Request &, httplib::Response &res, std::exception_ptr ep)
+            {
+                const char fmt[] = "500 Internal Server Error\n%s";
+                char buf[BUFSIZ];
+                try
+                {
+                    std::rethrow_exception(std::move(ep));
+                }
+                catch (std::exception &e)
+                {
+                    snprintf(buf, sizeof(buf), fmt, e.what());
+                }
+                catch (...)
+                {
+                    snprintf(buf, sizeof(buf), fmt, "Unknown Exception");
+                }
+                res.set_content(buf, "text/plain; charset=utf-8");
+                res.status = 500;
+            });
+
+    svr.set_error_handler([](const httplib::Request &, httplib::Response &res)
+            {
+                if (res.status == 401)
+                {
+                    res.set_content("Unauthorized", "text/plain; charset=utf-8");
+                }
+                if (res.status == 400)
+                {
+                    res.set_content("Invalid request", "text/plain; charset=utf-8");
+                }
+                else if (res.status == 404)
+                {
+                    res.set_content("File Not Found", "text/plain; charset=utf-8");
+                    res.status = 404;
+                }
+            });
+
+    // set timeouts and change hostname and port
+    svr.set_read_timeout (sparams.read_timeout);
+    svr.set_write_timeout(sparams.write_timeout);
+
+    if (!svr.bind_to_port(sparams.hostname, sparams.port))
     {
+        fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port);
         return 1;
     }
 
-    llama.initialize();
+    // Set the base directory for serving static files
+    svr.set_base_dir(sparams.public_path);
 
-    httplib::Server svr;
+    // to make it ctrl+clickable:
+    LOG_TEE("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
+
+    std::unordered_map<std::string, std::string> log_data;
+    log_data["hostname"] = sparams.hostname;
+    log_data["port"] = std::to_string(sparams.port);
+
+    if (!sparams.api_key.empty()) {
+        log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4);
+    }
+
+    LOG_INFO("HTTP server listening", log_data);
+    // run the HTTP server in a thread - see comment below
+    std::thread t([&]()
+            {
+                if (!svr.listen_after_bind())
+                {
+                    server_state.store(ERROR);
+                    return 1;
+                }
+
+                return 0;
+            });
+
+    // load the model
+    if (!llama.load_model(params))
+    {
+        server_state.store(ERROR);
+        return 1;
+    } else {
+        llama.initialize();
+        server_state.store(READY);
+    }
 
     // Middleware for API key validation
     auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
@@ -2826,10 +2934,6 @@ int main(int argc, char **argv)
         return false;
     };
 
-    svr.set_default_headers({{"Server", "llama.cpp"},
-                             {"Access-Control-Allow-Origin", "*"},
-                             {"Access-Control-Allow-Headers", "content-type"}});
-
     // this is only called if no index.html is found in the public --path
     svr.Get("/", [](const httplib::Request &, httplib::Response &res)
             {
@@ -2937,8 +3041,6 @@ int main(int argc, char **argv)
                 }
             });
 
-
-
     svr.Get("/v1/models", [&params](const httplib::Request&, httplib::Response& res)
             {
                 std::time_t t = std::time(0);
@@ -3157,81 +3259,6 @@ int main(int argc, char **argv)
                 return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
             });
 
-    svr.set_logger(log_server_request);
-
-    svr.set_exception_handler([](const httplib::Request &, httplib::Response &res, std::exception_ptr ep)
-            {
-                const char fmt[] = "500 Internal Server Error\n%s";
-                char buf[BUFSIZ];
-                try
-                {
-                    std::rethrow_exception(std::move(ep));
-                }
-                catch (std::exception &e)
-                {
-                    snprintf(buf, sizeof(buf), fmt, e.what());
-                }
-                catch (...)
-                {
-                    snprintf(buf, sizeof(buf), fmt, "Unknown Exception");
-                }
-                res.set_content(buf, "text/plain; charset=utf-8");
-                res.status = 500;
-            });
-
-    svr.set_error_handler([](const httplib::Request &, httplib::Response &res)
-            {
-                if (res.status == 401)
-                {
-                    res.set_content("Unauthorized", "text/plain; charset=utf-8");
-                }
-                if (res.status == 400)
-                {
-                    res.set_content("Invalid request", "text/plain; charset=utf-8");
-                }
-                else if (res.status == 404)
-                {
-                    res.set_content("File Not Found", "text/plain; charset=utf-8");
-                    res.status = 404;
-                }
-            });
-
-    // set timeouts and change hostname and port
-    svr.set_read_timeout (sparams.read_timeout);
-    svr.set_write_timeout(sparams.write_timeout);
-
-    if (!svr.bind_to_port(sparams.hostname, sparams.port))
-    {
-        fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port);
-        return 1;
-    }
-
-    // Set the base directory for serving static files
-    svr.set_base_dir(sparams.public_path);
-
-    // to make it ctrl+clickable:
-    LOG_TEE("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
-
-    std::unordered_map<std::string, std::string> log_data;
-    log_data["hostname"] = sparams.hostname;
-    log_data["port"] = std::to_string(sparams.port);
-
-    if (!sparams.api_key.empty()) {
-        log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4);
-    }
-
-    LOG_INFO("HTTP server listening", log_data);
-    // run the HTTP server in a thread - see comment below
-    std::thread t([&]()
-            {
-                if (!svr.listen_after_bind())
-                {
-                    return 1;
-                }
-
-                return 0;
-            });
-
     // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
     //     "Bus error: 10" - this is on macOS, it does not crash on Linux
     //std::thread t2([&]()