From a0ebdfcc5d27d0438fe1555b35596d847a47691f Mon Sep 17 00:00:00 2001 From: Pierrick HYMBERT Date: Sat, 16 Mar 2024 11:32:29 +0100 Subject: [PATCH] common: llama_load_model_from_url witch to libcurl dependency --- common/CMakeLists.txt | 14 +-- common/common.cpp | 173 +++++++++++-------------------------- examples/main/README.md | 1 + examples/server/README.md | 1 + examples/server/server.cpp | 6 +- 5 files changed, 64 insertions(+), 131 deletions(-) diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index d275ef5a65a57..79c3abdfede8e 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -47,14 +47,14 @@ if (BUILD_SHARED_LIBS) set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON) endif() -# Check for OpenSSL -find_package(OpenSSL QUIET) -if (OPENSSL_FOUND) - add_definitions(-DHAVE_OPENSSL) - include_directories(${OPENSSL_INCLUDE_DIR}) - link_libraries(${OPENSSL_LIBRARIES}) +# Check for curl +find_package(CURL QUIET) +if (CURL_FOUND) + add_definitions(-DHAVE_CURL) + include_directories(${CURL_INCLUDE_DIRS}) + link_libraries(${CURL_LIBRARIES}) else() - message(WARNING "OpenSSL not found. Building without model download support.") + message(INFO "libcurl not found. Building without model download support.") endif () diff --git a/common/common.cpp b/common/common.cpp index baa2ad2f9d62f..4f955df30a116 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -16,6 +16,9 @@ #include #include #include +#ifdef HAVE_CURL +#include +#endif #if defined(__APPLE__) && defined(__MACH__) #include @@ -531,6 +534,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } params.model = argv[i]; + } else if (arg == "-mu" || arg == "--model-url") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.model_url = argv[i]; } else if (arg == "-md" || arg == "--model-draft") { if (++i >= argc) { invalid_param = true; @@ -1131,6 +1140,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" layer range to apply the control vector(s) to, start and end inclusive\n"); printf(" -m FNAME, --model FNAME\n"); printf(" model path (default: %s)\n", params.model.c_str()); + printf(" -mu MODEL_URL, --model-url MODEL_URL\n"); + printf(" model download url (default: %s)\n", params.model_url.c_str()); printf(" -md FNAME, --model-draft FNAME\n"); printf(" draft model for speculative decoding\n"); printf(" -ld LOGDIR, --logdir LOGDIR\n"); @@ -1376,150 +1387,70 @@ void llama_batch_add( batch.n_tokens++; } +#ifdef HAVE_CURL struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, struct llama_model_params params) { -#ifdef HAVE_OPENSSL - // Initialize OpenSSL - SSL_library_init(); - SSL_load_error_strings(); - OpenSSL_add_all_algorithms(); - - // Parse the URL to extract host, path, user, and password - char host[256]; - char path[256]; - char userpass[256]; - - if (sscanf(model_url, "https://%255[^/]/%255s", host, path) != 2) { - fprintf(stderr, "%s: invalid URL format: %s\n", __func__, model_url); - return nullptr; - } - - if (strstr(host, "@")) { - sscanf(host, "%[^@]@%s", userpass, host); - } - - // Create an SSL context - auto ctx = SSL_CTX_new(TLS_client_method()); - if (!ctx) { - fprintf(stderr, "%s: error creating SSL context\n", __func__); - return nullptr; - } - - // Set up certificate verification - SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, nullptr); - - // Load trusted CA certificates based on platform - const char* ca_cert_path = nullptr; -#ifdef _WIN32 - ca_cert_path = "C:\\path\\to\\ca-certificates.crt"; // Windows path (FIXME) -#elif __APPLE__ - ca_cert_path = "/etc/ssl/cert.pem"; // macOS path -#else - ca_cert_path = "/etc/ssl/certs/ca-certificates.crt"; // Linux path -#endif + // Initialize libcurl + curl_global_init(CURL_GLOBAL_DEFAULT); + auto curl = curl_easy_init(); - if (!SSL_CTX_load_verify_locations(ctx, ca_cert_path, nullptr)) { - fprintf(stderr, "%s: error loading CA certificates\n", __func__); - SSL_CTX_free(ctx); - return nullptr; - } - - // Create an SSL connection - auto bio = BIO_new_ssl_connect(ctx); - if (!bio) { - fprintf(stderr, "%s: error creating SSL connection\n", __func__); - SSL_CTX_free(ctx); - return nullptr; - } - // Set the hostname - if (!BIO_set_conn_hostname(bio, host)) { - fprintf(stderr, "%s: unable to set connection hostname %s\n", __func__, host); - BIO_free_all(bio); - SSL_CTX_free(ctx); + if (!curl) { + curl_global_cleanup(); + fprintf(stderr, "%s: error initializing lib curl\n", __func__); return nullptr; } - // Construct the HTTP request - char request[1024]; - snprintf(request, sizeof(request), "GET /%s HTTP/1.1\r\nHost: %s\r\nAccept: */*\r\nUser-Agent: llama-client\r\nConnection: close\r\n", path, host); + // Set the URL + curl_easy_setopt(curl, CURLOPT_URL, model_url); + curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); + curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); - // Add Authorization header if user credentials are available - if (strlen(userpass) > 0) { - char auth_header[256]; - snprintf(auth_header, sizeof(auth_header), "Authorization: Basic %s\r\n", userpass); - strcat(request, auth_header); - } - - // End of headers - strcat(request, "\r\n"); - - // Send the request - fprintf(stdout, "%s: downloading model from https://%s/%s to %s ...\n", __func__, host, path, path_model); - if (!BIO_puts(bio, request)) { - fprintf(stderr, "%s: error sending HTTP request https://%s/%s\n", __func__, host, path); - BIO_free_all(bio); - SSL_CTX_free(ctx); + // Set the output file + auto outfile = fopen(path_model, "wb"); + if (!outfile) { + curl_easy_cleanup(curl); + curl_global_cleanup(); + fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path_model); return nullptr; } + curl_easy_setopt(curl, CURLOPT_WRITEDATA, outfile); - // Read the response status line - char status_line[256]; - if (BIO_gets(bio, status_line, sizeof(status_line)) <= 0) { - fprintf(stderr, "%s: error reading response status line\n", __func__); - BIO_free_all(bio); - SSL_CTX_free(ctx); + // start the download + fprintf(stdout, "%s: downloading model from %s to %s ...\n", __func__, model_url, path_model); + auto res = curl_easy_perform(curl); + if (res != CURLE_OK) { + fclose(outfile); + curl_easy_cleanup(curl); + curl_global_cleanup(); + fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res)); return nullptr; } - // Verify HTTP status code - if (strncmp(status_line, "HTTP/1.1 200", 12) != 0) { - fprintf(stderr, "%s: HTTP request failed: %s\n", __func__, status_line); - BIO_free_all(bio); - SSL_CTX_free(ctx); + int http_code = 0; + curl_easy_getinfo (curl, CURLINFO_RESPONSE_CODE, &http_code); + if (http_code < 200 || http_code >= 400) { + fclose(outfile); + curl_easy_cleanup(curl); + curl_global_cleanup(); + fprintf(stderr, "%s: invalid http status code failed: %d\n", __func__, http_code); return nullptr; } - // Skip response headers - char buffer[4096]; - int n_bytes_received; - while ((n_bytes_received = BIO_read(bio, buffer, sizeof(buffer))) > 0) { - // Look for the end of headers (empty line) - if (strstr(buffer, "\r\n\r\n")) { - break; - } - } - - // Read and save the file content - FILE* outfile = fopen(path_model, "wb"); - if (!outfile) { - fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path_model); - BIO_free_all(bio); - SSL_CTX_free(ctx); - return nullptr; - } - - int n_bytes_received_total = 0; - while ((n_bytes_received = BIO_read(bio, buffer, sizeof(buffer))) > 0) { - fwrite(buffer, 1, n_bytes_received, outfile); - n_bytes_received_total += n_bytes_received; - if (n_bytes_received_total % (1024 * 1024) == 0) { - fprintf(stdout, "%s: model downloading %dGi %s ...\n", __func__, n_bytes_received_total / 1024 / 1024, path_model); - } - } - fclose(outfile); - // Clean up - BIO_free_all(bio); - SSL_CTX_free(ctx); - fprintf(stdout, "%s: model downloaded from https://%s/%s to %s.\n", __func__, host, path, path_model); + fclose(outfile); + curl_easy_cleanup(curl); + curl_global_cleanup(); return llama_load_model_from_file(path_model, params); +} #else - LLAMA_LOG_ERROR("llama.cpp built without SSL support, downloading from url not supported.\n", __func__); +struct llama_model * llama_load_model_from_url(const char *, const char *, + struct llama_model_params) { + fprintf(stderr, "%s: llama.cpp built without SSL support, downloading from url not supported.\n", __func__); return nullptr; -#endif } +#endif std::tuple llama_init_from_gpt_params(gpt_params & params) { auto mparams = llama_model_params_from_gpt_params(params); diff --git a/examples/main/README.md b/examples/main/README.md index 7f84e42623274..daaa807d55952 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -67,6 +67,7 @@ main.exe -m models\7B\ggml-model.bin --ignore-eos -n -1 --random-prompt In this section, we cover the most commonly used options for running the `main` program with the LLaMA models: - `-m FNAME, --model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.bin`). +- `-mu MODEL_URL --model MODEL_URL`: Specify a remote http url to download the file (e.g https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf). - `-i, --interactive`: Run the program in interactive mode, allowing you to provide input directly and receive real-time responses. - `-ins, --instruct`: Run the program in instruction mode, which is particularly useful when working with Alpaca models. - `-n N, --n-predict N`: Set the number of tokens to predict when generating text. Adjusting this value can influence the length of the generated text. diff --git a/examples/server/README.md b/examples/server/README.md index 8f8454affaecd..df1ccce9bebe0 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -20,6 +20,7 @@ The project is under active development, and we are [looking for feedback and co - `-tb N, --threads-batch N`: Set the number of threads to use during batch and prompt processing. If not specified, the number of threads will be set to the number of threads used for generation. - `--threads-http N`: number of threads in the http server pool to process requests (default: `max(std::thread::hardware_concurrency() - 1, --parallel N + 2)`) - `-m FNAME`, `--model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.gguf`). +- `-mu MODEL_URL --model MODEL_URL`: Specify a remote http url to download the file (e.g https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf). - `-a ALIAS`, `--alias ALIAS`: Set an alias for the model. The alias will be returned in API responses. - `-c N`, `--ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference. The size may differ in other models, for example, baichuan models were build with a context of 4096. - `-ngl N`, `--n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance. diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 5e1020009cbf1..d2a8e541d3305 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2195,8 +2195,8 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co } printf(" -m FNAME, --model FNAME\n"); printf(" model path (default: %s)\n", params.model.c_str()); - printf(" -u MODEL_URL, --url MODEL_URL\n"); - printf(" model url (default: %s)\n", params.model_url.c_str()); + printf(" -mu MODEL_URL, --model-url MODEL_URL\n"); + printf(" model download url (default: %s)\n", params.model_url.c_str()); printf(" -a ALIAS, --alias ALIAS\n"); printf(" set an alias for the model, will be added as `model` field in completion response\n"); printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); @@ -2319,7 +2319,7 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, break; } params.model = argv[i]; - } else if (arg == "-u" || arg == "--model-url") { + } else if (arg == "-mu" || arg == "--model-url") { if (++i >= argc) { invalid_param = true; break;