From 86e7299ef5dff0f388922dc6fcbce009e99d8005 Mon Sep 17 00:00:00 2001 From: "Derrick T. Woolworth" Date: Sat, 6 Jul 2024 15:32:04 -0500 Subject: [PATCH] added support for Authorization Bearer tokens when downloading model (#8307) * added support for Authorization Bearer tokens * removed auth_token, removed set_ function, other small fixes * Update common/common.cpp --------- Co-authored-by: Xuan Son Nguyen --- common/common.cpp | 44 +++++++++++++++++++++++++++++++++++++------- common/common.h | 6 ++++-- 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index c548bcb2857a8..fc0f3b350ec6e 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -190,6 +190,12 @@ int32_t cpu_get_num_math() { // CLI argument parsing // +void gpt_params_handle_hf_token(gpt_params & params) { + if (params.hf_token.empty() && std::getenv("HF_TOKEN")) { + params.hf_token = std::getenv("HF_TOKEN"); + } +} + void gpt_params_handle_model_default(gpt_params & params) { if (!params.hf_repo.empty()) { // short-hand to avoid specifying --hf-file -> default it to --model @@ -237,6 +243,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { gpt_params_handle_model_default(params); + gpt_params_handle_hf_token(params); + if (params.escape) { string_process_escapes(params.prompt); string_process_escapes(params.input_prefix); @@ -652,6 +660,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.model_url = argv[i]; return true; } + if (arg == "-hft" || arg == "--hf-token") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.hf_token = argv[i]; + return true; + } if (arg == "-hfr" || arg == "--hf-repo") { CHECK_ARG params.hf_repo = argv[i]; @@ -1576,6 +1592,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-mu, --model-url MODEL_URL", "model download url (default: unused)" }); options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" }); options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" }); + options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: value from HF_TOKEN environment variable)" }); options.push_back({ "retrieval" }); options.push_back({ "retrieval", " --context-file FNAME", "file to load context from (repeat to specify multiple files)" }); @@ -2015,9 +2032,9 @@ std::tuple llama_init_from_gpt_par llama_model * model = nullptr; if (!params.hf_repo.empty() && !params.hf_file.empty()) { - model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams); + model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams); } else if (!params.model_url.empty()) { - model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams); + model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams); } else { model = llama_load_model_from_file(params.model.c_str(), mparams); } @@ -2205,7 +2222,7 @@ static bool starts_with(const std::string & str, const std::string & prefix) { return str.rfind(prefix, 0) == 0; } -static bool llama_download_file(const std::string & url, const std::string & path) { +static bool llama_download_file(const std::string & url, const std::string & path, const std::string & hf_token) { // Initialize libcurl std::unique_ptr curl(curl_easy_init(), &curl_easy_cleanup); @@ -2220,6 +2237,15 @@ static bool llama_download_file(const std::string & url, const std::string & pat curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); + // Check if hf-token or bearer-token was specified + if (!hf_token.empty()) { + std::string auth_header = "Authorization: Bearer "; + auth_header += hf_token.c_str(); + struct curl_slist *http_headers = NULL; + http_headers = curl_slist_append(http_headers, auth_header.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers); + } + #if defined(_WIN32) // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of // operating system. Currently implemented under MS-Windows. @@ -2415,6 +2441,7 @@ static bool llama_download_file(const std::string & url, const std::string & pat struct llama_model * llama_load_model_from_url( const char * model_url, const char * path_model, + const char * hf_token, const struct llama_model_params & params) { // Basic validation of the model_url if (!model_url || strlen(model_url) == 0) { @@ -2422,7 +2449,7 @@ struct llama_model * llama_load_model_from_url( return NULL; } - if (!llama_download_file(model_url, path_model)) { + if (!llama_download_file(model_url, path_model, hf_token)) { return NULL; } @@ -2470,14 +2497,14 @@ struct llama_model * llama_load_model_from_url( // Prepare download in parallel std::vector> futures_download; for (int idx = 1; idx < n_split; idx++) { - futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split](int download_idx) -> bool { + futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, hf_token](int download_idx) -> bool { char split_path[PATH_MAX] = {0}; llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split); char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0}; llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split); - return llama_download_file(split_url, split_path); + return llama_download_file(split_url, split_path, hf_token); }, idx)); } @@ -2496,6 +2523,7 @@ struct llama_model * llama_load_model_from_hf( const char * repo, const char * model, const char * path_model, + const char * hf_token, const struct llama_model_params & params) { // construct hugging face model url: // @@ -2511,7 +2539,7 @@ struct llama_model * llama_load_model_from_hf( model_url += "/resolve/main/"; model_url += model; - return llama_load_model_from_url(model_url.c_str(), path_model, params); + return llama_load_model_from_url(model_url.c_str(), path_model, hf_token, params); } #else @@ -2519,6 +2547,7 @@ struct llama_model * llama_load_model_from_hf( struct llama_model * llama_load_model_from_url( const char * /*model_url*/, const char * /*path_model*/, + const char * /*hf_token*/, const struct llama_model_params & /*params*/) { fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__); return nullptr; @@ -2528,6 +2557,7 @@ struct llama_model * llama_load_model_from_hf( const char * /*repo*/, const char * /*model*/, const char * /*path_model*/, + const char * /*hf_token*/, const struct llama_model_params & /*params*/) { fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__); return nullptr; diff --git a/common/common.h b/common/common.h index dabf02b598b3b..184a53dc09064 100644 --- a/common/common.h +++ b/common/common.h @@ -108,6 +108,7 @@ struct gpt_params { std::string model_draft = ""; // draft model for speculative decoding std::string model_alias = "unknown"; // model alias std::string model_url = ""; // model url to download + std::string hf_token = ""; // HF token std::string hf_repo = ""; // HF repo std::string hf_file = ""; // HF file std::string prompt = ""; @@ -256,6 +257,7 @@ struct gpt_params { bool spm_infill = false; // suffix/prefix/middle pattern for infill }; +void gpt_params_handle_hf_token(gpt_params & params); void gpt_params_handle_model_default(gpt_params & params); bool gpt_params_parse_ex (int argc, char ** argv, gpt_params & params); @@ -311,8 +313,8 @@ std::tuple llama_init_from_gpt_par struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params); struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params); -struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const struct llama_model_params & params); -struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const struct llama_model_params & params); +struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params); +struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params); // Batch utils