From c8f28909dcbdff6954993a76c8a8af3bf5d1a39e Mon Sep 17 00:00:00 2001 From: farbod Date: Fri, 6 Sep 2024 15:58:12 +0330 Subject: [PATCH] implements a retry function to avoid duplication --- common/common.cpp | 59 ++++++++++++++++++++++++----------------------- 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index cb33b95ea79df..fe992ddb51863 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2262,6 +2262,28 @@ static bool starts_with(const std::string & str, const std::string & prefix) { return str.rfind(prefix, 0) == 0; } +static bool perform_with_retry(const std::string& url, CURL* curl, int max_attempts, int retry_delay_seconds) { + int remaining_attempts = max_attempts; + + while (remaining_attempts > 0) { + fprintf(stderr, "%s: Trying to download from %s (attempt %d of %d)...\n", __func__ , url.c_str(), max_attempts - remaining_attempts + 1, max_attempts); + + CURLcode res = curl_easy_perform(curl); + if (res == CURLE_OK) { + return true; + } + + int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000; + fprintf(stderr, "%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay); + + remaining_attempts--; + std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay)); + } + + fprintf(stderr, "%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts); + return false; +} + static bool llama_download_file(const std::string & url, const std::string & path, const std::string & hf_token) { // Initialize libcurl @@ -2365,20 +2387,9 @@ static bool llama_download_file(const std::string & url, const std::string & pat curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast(header_callback)); curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); - int remaining_request_attempts = CURL_MAX_RETRY; - while (remaining_request_attempts > 0){ - CURLcode res = curl_easy_perform(curl.get()); - if (res != CURLE_OK) { - int exponential_backoff_delay = std::pow(CURL_RETRY_DELAY_SECONDS, (CURL_MAX_RETRY - remaining_request_attempts)) * 1000; - fprintf(stderr, "%s: curl_easy_perform() failed: %s, retrying after %d miliseconds\n", __func__, curl_easy_strerror(res), exponential_backoff_delay); - remaining_request_attempts--; - std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay)); - } else if (remaining_request_attempts <= 0) { - fprintf(stderr, "%s: curl_easy_perform() failed\n", __func__); - return false; - } else { - break; - } + bool was_perform_successful = perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS); + if (!was_perform_successful) { + return false; } long http_code = 0; @@ -2452,23 +2463,13 @@ static bool llama_download_file(const std::string & url, const std::string & pat }; // start the download - int remaining_attempts = CURL_MAX_RETRY; - while (remaining_attempts > 0){ - fprintf(stderr, "%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__, + fprintf(stderr, "%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__, llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str()); - auto res = curl_easy_perform(curl.get()); - if (res != CURLE_OK) { - int exponential_backoff_delay = std::pow(CURL_RETRY_DELAY_SECONDS, (CURL_MAX_RETRY - remaining_attempts)) * 1000; - fprintf(stderr, "\n%s: curl_easy_perform() failed: %s, retrying after %d miliseconnds\n", __func__, curl_easy_strerror(res), exponential_backoff_delay); - remaining_attempts--; - std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay)); - } else if (remaining_attempts <= 0) { - fprintf(stderr, "%s: curl_easy_perform() failed\n", __func__); - return false; - } else { - break; - } + bool was_perform_successful = perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS); + if (!was_perform_successful) { + return false; } + long http_code = 0; curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code);