diff --git a/common/common.cpp b/common/common.cpp index b6143e41c02cb..8deeff2a896e9 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -68,6 +68,9 @@ #endif #define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 #define LLAMA_CURL_MAX_HEADER_LENGTH 256 +#define LLAMA_PROGRESS_UPDATE_INTERVAL 1 +#define LLAMA_PROGRESS_PERCENTAGE_WIDTH 10 +#define LLAMA_DEFAULT_CONSOLE_WIDTH 80 #endif // LLAMA_USE_CURL using json = nlohmann::ordered_json; @@ -1866,7 +1869,87 @@ void llama_batch_add( #ifdef LLAMA_USE_CURL -static bool llama_download_file(CURL * curl, const char * url, const char * path) { +struct shard_file_progress { + std::string filename; + double total_bytes; + double received_bytes; +}; + +std::map progress_table; +std::mutex progress_mutex; +std::stringstream download_done_buffer; + +static int shard_progress_callback(void* clientp, double dltotal, double dlnow, double ultotal, double ulnow) { + // upload not needed for downloading + (void) ultotal; + (void) ulnow; + char* url = static_cast(clientp); + + std::lock_guard lock(progress_mutex); + + shard_file_progress& progress = progress_table[url]; + progress.total_bytes = static_cast(dltotal); + progress.received_bytes = static_cast(dlnow); + + std::string url_string = static_cast(url); + progress.filename = url_string.substr(url_string.find_last_of('/') + 1); + + return 0; +} + +// function to get the console width +static int get_console_width() { +#ifdef _WIN32 + CONSOLE_SCREEN_BUFFER_INFO csbi; + GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi); + return csbi.dwSize.X; +#elif defined(__linux__) || defined(__APPLE__) + struct winsize ws; + ioctl(STDOUT_FILENO, TIOCGWINSZ, &ws); + return ws.ws_col; +#else + return LLAMA_DEFAULT_CONSOLE_WIDTH; // Default value +#endif +} + +static void print_shard_progress_table(bool first_progress) { + if (first_progress) { + fprintf(stderr, "=========================\n"); + } else { + // use updating output + { + std::lock_guard lock(progress_mutex); + for (unsigned int i = 0; i < progress_table.size(); i++) { + fprintf(stderr, "\033[1A\033[K\033[1A\033[K"); + } + fprintf(stderr, "\r"); + } + } + + + int progress_bar_width = get_console_width() - LLAMA_PROGRESS_PERCENTAGE_WIDTH; + + // Print the progress information for each downloading file + { + std::lock_guard lock(progress_mutex); + for (const auto& entry : progress_table) { + shard_file_progress progress = entry.second; + int progress_width = static_cast((progress.received_bytes / progress.total_bytes) * progress_bar_width); + + fprintf(stderr, "%s\n", progress.filename.c_str()); + fprintf(stderr, "["); + for (int i = 0; i < progress_width; ++i) { + fprintf(stderr, "="); + } + for (int i = progress_width; i < progress_bar_width; ++i) { + fprintf(stderr, " "); + } + fprintf(stderr, "] %d%%\n", static_cast((progress.received_bytes / progress.total_bytes) * 100)); + } + } +} + +static bool llama_download_file(CURL * curl, const char * url, const char * path, bool is_shard) { bool force_download = false; // Set the URL, allow to follow http redirection @@ -2002,6 +2085,12 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path // display download progress curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); + // custom progress callback on sharded download + if (is_shard) { + curl_easy_setopt(curl, CURLOPT_PROGRESSFUNCTION, shard_progress_callback); + curl_easy_setopt(curl, CURLOPT_PROGRESSDATA, url); + } + // helper function to hide password in URL auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string { std::size_t protocol_pos = url.find("://"); @@ -2046,7 +2135,11 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path if (etag_file) { fputs(headers.etag, etag_file); fclose(etag_file); - fprintf(stderr, "%s: file etag saved %s: %s\n", __func__, etag_path, headers.etag); + if (is_shard) { + download_done_buffer << __func__ << ": file etag saved " << etag_path << ": " << headers.etag << "\n"; + } else { + fprintf(stderr, "%s: file etag saved %s: %s\n", __func__, etag_path, headers.etag); + } } } @@ -2056,8 +2149,11 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path if (last_modified_file) { fputs(headers.last_modified, last_modified_file); fclose(last_modified_file); - fprintf(stderr, "%s: file last modified saved %s: %s\n", __func__, last_modified_path, - headers.last_modified); + if (is_shard) { + download_done_buffer << __func__ << ": unable to rename file: " << path_temporary << " to " << path << "\n"; + } else { + fprintf(stderr, "%s: unable to rename file: %s to %s\n", __func__, path_temporary, path); + } } } @@ -2089,7 +2185,7 @@ struct llama_model * llama_load_model_from_url( return NULL; } - if (!llama_download_file(curl, model_url, path_model)) { + if (!llama_download_file(curl, model_url, path_model, false)) { return NULL; } @@ -2148,13 +2244,40 @@ struct llama_model * llama_load_model_from_url( llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split); auto * curl = curl_easy_init(); - bool res = llama_download_file(curl, split_url, split_path); + bool res = llama_download_file(curl, split_url, split_path, true); curl_easy_cleanup(curl); return res; }, idx)); } + bool first_progress = true; + while (true) { + // Print the progress table periodically + std::this_thread::sleep_for(std::chrono::seconds(LLAMA_PROGRESS_UPDATE_INTERVAL)); + // Print the progress table header + print_shard_progress_table(first_progress); + first_progress = false; + + // Check if all downloads are complete + bool all_complete = true; + { + std::lock_guard lock(progress_mutex); + for (const auto& entry : progress_table) { + const shard_file_progress& progress = entry.second; + if (progress.received_bytes < progress.total_bytes) { + all_complete = false; + break; + } + } + } + + if (all_complete) { + fprintf(stderr, "%s", download_done_buffer.str().c_str()); + break; + } + } + // Wait for all downloads to complete for (auto & f : futures_download) { if (!f.get()) {