From a86ad841f103e471ac0fd7ee8852d1eb5015ce89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Y=C3=BCg?= Date: Tue, 10 Dec 2024 17:22:34 +0000 Subject: [PATCH] server : add flag to disable the web-ui (#10762) (#10751) Co-authored-by: eugenio.segala --- common/arg.cpp | 7 ++++++ examples/server/README.md | 1 + examples/server/server.cpp | 30 ++++++++++++++---------- examples/server/tests/unit/test_basic.py | 18 ++++++++++++++ examples/server/tests/utils.py | 3 +++ 5 files changed, 46 insertions(+), 13 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 0db59f70187e8..808ec1c3e65fc 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1711,6 +1711,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.public_path = value; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_STATIC_PATH")); + add_opt(common_arg( + {"--no-webui"}, + string_format("Disable the Web UI (default: %s)", params.webui ? "enabled" : "disabled"), + [](common_params & params) { + params.webui = false; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_WEBUI")); add_opt(common_arg( {"--embedding", "--embeddings"}, string_format("restrict to only support embedding use case; use only with dedicated embedding models (default: %s)", params.embedding ? "enabled" : "disabled"), diff --git a/examples/server/README.md b/examples/server/README.md index 117c52d3f1310..6294f541fc7d7 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -146,6 +146,7 @@ The project is under active development, and we are [looking for feedback and co | `--host HOST` | ip address to listen (default: 127.0.0.1)
(env: LLAMA_ARG_HOST) | | `--port PORT` | port to listen (default: 8080)
(env: LLAMA_ARG_PORT) | | `--path PATH` | path to serve static files from (default: )
(env: LLAMA_ARG_STATIC_PATH) | +| `--no-webui` | disable the Web UI
(env: LLAMA_ARG_NO_WEBUI) | | `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)
(env: LLAMA_ARG_EMBEDDINGS) | | `--reranking, --rerank` | enable reranking endpoint on server (default: disabled)
(env: LLAMA_ARG_RERANKING) | | `--api-key KEY` | API key to use for authentication (default: none)
(env: LLAMA_API_KEY) | diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 47bfd6c4aac90..8cb992470a302 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3815,20 +3815,24 @@ int main(int argc, char ** argv) { // Router // - // register static assets routes - if (!params.public_path.empty()) { - // Set the base directory for serving static files - bool is_found = svr->set_mount_point("/", params.public_path); - if (!is_found) { - LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str()); - return 1; - } + if (!params.webui) { + LOG_INF("Web UI is disabled\n"); } else { - // using embedded static index.html - svr->Get("/", [](const httplib::Request &, httplib::Response & res) { - res.set_content(reinterpret_cast(index_html), index_html_len, "text/html; charset=utf-8"); - return false; - }); + // register static assets routes + if (!params.public_path.empty()) { + // Set the base directory for serving static files + bool is_found = svr->set_mount_point("/", params.public_path); + if (!is_found) { + LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str()); + return 1; + } + } else { + // using embedded static index.html + svr->Get("/", [](const httplib::Request &, httplib::Response & res) { + res.set_content(reinterpret_cast(index_html), index_html_len, "text/html; charset=utf-8"); + return false; + }); + } } // register API routes diff --git a/examples/server/tests/unit/test_basic.py b/examples/server/tests/unit/test_basic.py index 1d512401685fb..1485de8ceb3fc 100644 --- a/examples/server/tests/unit/test_basic.py +++ b/examples/server/tests/unit/test_basic.py @@ -1,4 +1,5 @@ import pytest +import requests from utils import * server = ServerPreset.tinyllama2() @@ -76,3 +77,20 @@ def test_load_split_model(): }) assert res.status_code == 200 assert match_regex("(little|girl)+", res.body["content"]) + + +def test_no_webui(): + global server + # default: webui enabled + server.start() + url = f"http://{server.server_host}:{server.server_port}" + res = requests.get(url) + assert res.status_code == 200 + assert "" in res.text + server.stop() + + # with --no-webui + server.no_webui = True + server.start() + res = requests.get(url) + assert res.status_code == 404 diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index 7c89b9cd37505..d988ccf5e3061 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -72,6 +72,7 @@ class ServerProcess: disable_ctx_shift: int | None = False draft_min: int | None = None draft_max: int | None = None + no_webui: bool | None = None # session variables process: subprocess.Popen | None = None @@ -158,6 +159,8 @@ def start(self, timeout_seconds: int = 10) -> None: server_args.extend(["--draft-max", self.draft_max]) if self.draft_min: server_args.extend(["--draft-min", self.draft_min]) + if self.no_webui: + server_args.append("--no-webui") args = [str(arg) for arg in [server_path, *server_args]] print(f"bench: starting server with: {' '.join(args)}")