diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt
index 350bbdf7f7b1b..d275ef5a65a57 100644
--- a/common/CMakeLists.txt
+++ b/common/CMakeLists.txt
@@ -47,6 +47,16 @@ if (BUILD_SHARED_LIBS)
     set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
 endif()
 
+# Check for OpenSSL
+find_package(OpenSSL QUIET)
+if (OPENSSL_FOUND)
+    add_definitions(-DHAVE_OPENSSL)
+    include_directories(${OPENSSL_INCLUDE_DIR})
+    link_libraries(${OPENSSL_LIBRARIES})
+else()
+    message(WARNING "OpenSSL not found. Building without model download support.")
+endif()
+
 set(TARGET common)
diff --git a/common/common.cpp b/common/common.cpp
index 4912237e0d0f1..baa2ad2f9d62f 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1376,10 +1376,189 @@ void llama_batch_add(
     batch.n_tokens++;
 }
 
+struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model,
+                                               struct llama_model_params params) {
+#ifdef HAVE_OPENSSL
+    // Initialize OpenSSL
+    SSL_library_init();
+    SSL_load_error_strings();
+    OpenSSL_add_all_algorithms();
+
+    // Parse the URL to extract host, path and optional user:password credentials.
+    // The buffers are zero-initialized so the strlen(userpass) check below is
+    // well defined when the URL carries no credentials.
+    char host[256]     = {0};
+    char path[256]     = {0};
+    char userpass[256] = {0};
+
+    if (sscanf(model_url, "https://%255[^/]/%255s", host, path) != 2) {
+        fprintf(stderr, "%s: invalid URL format: %s\n", __func__, model_url);
+        return nullptr;
+    }
+
+    // Split "user:password@host"; sscanf must not read and write the same
+    // buffer, so the stripped host goes through a temporary
+    if (strchr(host, '@')) {
+        char host_only[256] = {0};
+        if (sscanf(host, "%255[^@]@%255s", userpass, host_only) == 2) {
+            strcpy(host, host_only);
+        }
+    }
+
+    // Create an SSL context
+    auto ctx = SSL_CTX_new(TLS_client_method());
+    if (!ctx) {
+        fprintf(stderr, "%s: error creating SSL context\n", __func__);
+        return nullptr;
+    }
+
+    // Set up certificate verification
+    SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, nullptr);
+
+    // Load trusted CA certificates based on platform
+    const char * ca_cert_path = nullptr;
+#ifdef _WIN32
+    ca_cert_path = "C:\\path\\to\\ca-certificates.crt"; // Windows path (FIXME)
+#elif __APPLE__
+    ca_cert_path = "/etc/ssl/cert.pem"; // macOS path
+#else
+    ca_cert_path = "/etc/ssl/certs/ca-certificates.crt"; // Linux path
+#endif
+
+    if (!SSL_CTX_load_verify_locations(ctx, ca_cert_path, nullptr)) {
+        fprintf(stderr, "%s: error loading CA certificates\n", __func__);
+        SSL_CTX_free(ctx);
+        return nullptr;
+    }
+
+    // Create an SSL connection
+    auto bio = BIO_new_ssl_connect(ctx);
+    if (!bio) {
+        fprintf(stderr, "%s: error creating SSL connection\n", __func__);
+        SSL_CTX_free(ctx);
+        return nullptr;
+    }
+
+    // SSL_VERIFY_PEER alone does not check that the certificate matches the
+    // host we are contacting, so enable hostname verification and send SNI
+    SSL * ssl = nullptr;
+    BIO_get_ssl(bio, &ssl); // always succeeds for a BIO_new_ssl_connect() chain
+    SSL_set_tlsext_host_name(ssl, host);
+    SSL_set1_host(ssl, host);
+
+    // Set the hostname and port; BIO_set_conn_hostname() expects "host:port"
+    // (explicit ports in the URL are not handled here)
+    char host_port[512];
+    snprintf(host_port, sizeof(host_port), "%s:443", host);
+    if (!BIO_set_conn_hostname(bio, host_port)) {
+        fprintf(stderr, "%s: unable to set connection hostname %s\n", __func__, host);
+        BIO_free_all(bio);
+        SSL_CTX_free(ctx);
+        return nullptr;
+    }
+
+    // Construct the HTTP request
+    char request[1024];
+    int  n_request = snprintf(request, sizeof(request),
+            "GET /%s HTTP/1.1\r\nHost: %s\r\nAccept: */*\r\nUser-Agent: llama-client\r\nConnection: close\r\n", path, host);
+
+    // Add an Authorization header if credentials were given; HTTP Basic auth
+    // requires the "user:password" pair base64-encoded, not sent verbatim
+    if (strlen(userpass) > 0) {
+        unsigned char userpass_b64[512] = {0};
+        EVP_EncodeBlock(userpass_b64, (const unsigned char *) userpass, (int) strlen(userpass));
+        n_request += snprintf(request + n_request, sizeof(request) - n_request, "Authorization: Basic %s\r\n", userpass_b64);
+    }
+
+    // End of headers
+    snprintf(request + n_request, sizeof(request) - n_request, "\r\n");
+
+    // Send the request; writing to a connect BIO opens the connection
+    fprintf(stdout, "%s: downloading model from https://%s/%s to %s ...\n", __func__, host, path, path_model);
+    if (BIO_puts(bio, request) <= 0) {
+        fprintf(stderr, "%s: error sending HTTP request https://%s/%s\n", __func__, host, path);
+        BIO_free_all(bio);
+        SSL_CTX_free(ctx);
+        return nullptr;
+    }
+
+    // Read the response status line
+    char status_line[256];
+    if (BIO_gets(bio, status_line, sizeof(status_line)) <= 0) {
+        fprintf(stderr, "%s: error reading response status line\n", __func__);
+        BIO_free_all(bio);
+        SSL_CTX_free(ctx);
+        return nullptr;
+    }
+
+    // Verify the HTTP status code; note that redirects are not followed
+    if (strncmp(status_line, "HTTP/1.1 200", 12) != 0) {
+        fprintf(stderr, "%s: HTTP request failed: %s\n", __func__, status_line);
+        BIO_free_all(bio);
+        SSL_CTX_free(ctx);
+        return nullptr;
+    }
+
+    // Skip the response headers by reading line by line until the blank line
+    // that terminates them; a chunked read here could swallow body bytes
+    char header_line[1024];
+    while (BIO_gets(bio, header_line, sizeof(header_line)) > 0) {
+        if (strcmp(header_line, "\r\n") == 0) {
+            break;
+        }
+    }
+
+    // Read and save the file content
+    FILE * outfile = fopen(path_model, "wb");
+    if (!outfile) {
+        fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path_model);
+        BIO_free_all(bio);
+        SSL_CTX_free(ctx);
+        return nullptr;
+    }
+
+    char   buffer[4096];
+    int    n_bytes_received;
+    size_t n_bytes_received_total = 0; // size_t: model files routinely exceed INT_MAX bytes
+    while ((n_bytes_received = BIO_read(bio, buffer, sizeof(buffer))) > 0) {
+        if (fwrite(buffer, 1, n_bytes_received, outfile) != (size_t) n_bytes_received) {
+            fprintf(stderr, "%s: error writing to local file %s\n", __func__, path_model);
+            fclose(outfile);
+            BIO_free_all(bio);
+            SSL_CTX_free(ctx);
+            return nullptr;
+        }
+        // Report progress whenever another MiB boundary is crossed; an exact
+        // multiple test almost never fires with variable-sized chunks
+        size_t prev_total = n_bytes_received_total;
+        n_bytes_received_total += n_bytes_received;
+        if (n_bytes_received_total / (1024 * 1024) != prev_total / (1024 * 1024)) {
+            fprintf(stdout, "%s: model downloading %zu MiB to %s ...\n", __func__, n_bytes_received_total / (1024 * 1024), path_model);
+        }
+    }
+    fclose(outfile);
+
+    // Clean up
+    BIO_free_all(bio);
+    SSL_CTX_free(ctx);
+    fprintf(stdout, "%s: model downloaded from https://%s/%s to %s.\n", __func__, host, path, path_model);
+
+    return llama_load_model_from_file(path_model, params);
+#else
+    fprintf(stderr, "%s: llama.cpp built without SSL support, downloading from url not supported\n", __func__);
+    return nullptr;
+#endif
+}
+
 std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
     auto mparams = llama_model_params_from_gpt_params(params);
 
-    llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);
+    llama_model * model = nullptr;
+    if (!params.model_url.empty()) {
+        model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams);
+    } else {
+        model = llama_load_model_from_file(params.model.c_str(), mparams);
+    }
     if (model == NULL) {
         fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
         return std::make_tuple(nullptr, nullptr);
diff --git a/common/common.h b/common/common.h
index 687f3425e8544..b9b59211254f2 100644
--- a/common/common.h
+++ b/common/common.h
@@ -17,6 +17,12 @@
 #include <unordered_map>
 #include <tuple>
 
+#ifdef HAVE_OPENSSL
+#include <openssl/ssl.h>
+#include <openssl/bio.h>
+#include <openssl/evp.h> // EVP_EncodeBlock, used for HTTP Basic auth
+#endif
+
 #ifdef _WIN32
 #define DIRECTORY_SEPARATOR '\\'
 #else
@@ -89,6 +95,7 @@ struct gpt_params {
     struct llama_sampling_params sparams;
 
     std::string model = "models/7B/ggml-model-f16.gguf"; // model path
+    std::string model_url = ""; // model url to download
     std::string model_draft = ""; // draft model for speculative decoding
     std::string model_alias = "unknown"; // model alias
     std::string prompt = "";
@@ -191,6 +198,9 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_pa
 struct llama_model_params   llama_model_params_from_gpt_params  (const gpt_params & params);
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
 
+struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model,
+                                               struct llama_model_params params);
+
 // Batch utils
 
 void llama_batch_clear(struct llama_batch & batch);
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 895d608fdcc06..5e1020009cbf1 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2195,6 +2195,8 @@ static void server_print_usage(const char * argv0, const gpt_params & params, const server_params & sparams)
     }
     printf("  -m FNAME, --model FNAME\n");
     printf("                        model path (default: %s)\n", params.model.c_str());
+    printf("  -u MODEL_URL, --model-url MODEL_URL\n");
+    printf("                        model url (default: %s)\n", params.model_url.c_str());
     printf("  -a ALIAS, --alias ALIAS\n");
     printf("                        set an alias for the model, will be added as `model` field in completion response\n");
     printf("  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
@@ -2317,6 +2319,12 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, gpt_params & params)
                 break;
             }
             params.model = argv[i];
+        } else if (arg == "-u" || arg == "--model-url") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.model_url = argv[i];
         } else if (arg == "-a" || arg == "--alias") {
             if (++i >= argc) {
                 invalid_param = true;
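
Usage note (not part of the patch): with the above applied, the download path can be exercised through the server binary as `./server -u <url> -m <local path>`, or programmatically via the new entry point. Below is a minimal sketch of a direct caller, assuming a llama.cpp build of this era (zero-argument llama_backend_init()) compiled with HAVE_OPENSSL; the URL and file paths are hypothetical placeholders.

    // Minimal sketch of calling llama_load_model_from_url() directly; the URL
    // and local destination below are hypothetical placeholders.
    #include "common.h"
    #include "llama.h"

    #include <cstdio>

    int main() {
        llama_backend_init();

        llama_model_params mparams = llama_model_default_params();

        // Downloads the file over HTTPS to the local path, then loads it via
        // llama_load_model_from_file(); returns nullptr if the build lacks
        // OpenSSL support or the transfer fails.
        llama_model * model = llama_load_model_from_url(
            "https://example.com/models/ggml-model-q4_0.gguf", // hypothetical URL
            "models/ggml-model-q4_0.gguf",                     // local destination
            mparams);
        if (model == nullptr) {
            fprintf(stderr, "download or load failed\n");
            llama_backend_free();
            return 1;
        }

        llama_free_model(model);
        llama_backend_free();
        return 0;
    }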