Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

common : fix parallel shard download interleaving output #6831

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
135 changes: 129 additions & 6 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@
#endif
#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
#define LLAMA_CURL_MAX_HEADER_LENGTH 256
#define LLAMA_PROGRESS_UPDATE_INTERVAL 1
#define LLAMA_PROGRESS_PERCENTAGE_WIDTH 10
#define LLAMA_DEFAULT_CONSOLE_WIDTH 80
#endif // LLAMA_USE_CURL

using json = nlohmann::ordered_json;
Expand Down Expand Up @@ -1866,7 +1869,87 @@ void llama_batch_add(

#ifdef LLAMA_USE_CURL

static bool llama_download_file(CURL * curl, const char * url, const char * path) {
// Download progress of a single shard, as last recorded by shard_progress_callback.
struct shard_file_progress {
    std::string filename;        // part of the shard URL after the last '/'
    double total_bytes    = 0.0; // total size reported by the progress callback; 0 until one is reported
    double received_bytes = 0.0; // bytes received so far, as reported by the progress callback
};

// Per-shard download progress, keyed by the URL string handed to the progress callback.
// The accesses visible in this file are made while holding progress_mutex.
std::map<std::string, shard_file_progress> progress_table;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It must be done without global variables.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think that would be easy to do, @TevinWang?

// Guards progress_table: locked by the progress callback that updates it and by the code that prints/polls it.
std::mutex progress_mutex;
// Collects log lines produced during shard downloads; the collected text is written to stderr in one piece
// once the progress loop considers all downloads complete.
// NOTE(review): the visible writers append to this stream without taking a lock - confirm this is safe
// when several shards are downloaded in parallel.
std::stringstream download_done_buffer;

static int shard_progress_callback(void* clientp, double dltotal, double dlnow, double ultotal, double ulnow) {
// upload not needed for downloading
(void) ultotal;
(void) ulnow;
char* url = static_cast<char*>(clientp);

std::lock_guard<std::mutex> lock(progress_mutex);

shard_file_progress& progress = progress_table[url];
progress.total_bytes = static_cast<double>(dltotal);
progress.received_bytes = static_cast<double>(dlnow);

std::string url_string = static_cast<std::string>(url);
progress.filename = url_string.substr(url_string.find_last_of('/') + 1);

return 0;
}

// Returns the width of the console in columns.
// Falls back to LLAMA_DEFAULT_CONSOLE_WIDTH when the width cannot be queried (for example when the
// output is redirected and is not a terminal) or when the reported width is not positive, so the
// caller never receives an uninitialized or zero value.
static int get_console_width() {
#ifdef _WIN32
    CONSOLE_SCREEN_BUFFER_INFO csbi;
    // the call fails when the handle does not refer to a console; csbi is not filled in then
    if (GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi) && csbi.dwSize.X > 0) {
        return csbi.dwSize.X;
    }
    return LLAMA_DEFAULT_CONSOLE_WIDTH;
#elif defined(__linux__) || defined(__APPLE__)
    struct winsize ws;
    // ioctl returns -1 when stdout is not a terminal; ws is left untouched in that case
    if (ioctl(STDOUT_FILENO, TIOCGWINSZ, &ws) == 0 && ws.ws_col > 0) {
        return ws.ws_col;
    }
    return LLAMA_DEFAULT_CONSOLE_WIDTH;
#else
    return LLAMA_DEFAULT_CONSOLE_WIDTH; // Default value
#endif
}

static void print_shard_progress_table(bool first_progress) {
if (first_progress) {
fprintf(stderr, "=========================\n");
} else {
// use updating output
{
std::lock_guard<std::mutex> lock(progress_mutex);
for (unsigned int i = 0; i < progress_table.size(); i++) {
fprintf(stderr, "\033[1A\033[K\033[1A\033[K");
}
fprintf(stderr, "\r");
}
}


int progress_bar_width = get_console_width() - LLAMA_PROGRESS_PERCENTAGE_WIDTH;

// Print the progress information for each downloading file
{
std::lock_guard<std::mutex> lock(progress_mutex);
for (const auto& entry : progress_table) {
shard_file_progress progress = entry.second;
int progress_width = static_cast<int>((progress.received_bytes / progress.total_bytes) * progress_bar_width);

fprintf(stderr, "%s\n", progress.filename.c_str());
fprintf(stderr, "[");
for (int i = 0; i < progress_width; ++i) {
fprintf(stderr, "=");
}
for (int i = progress_width; i < progress_bar_width; ++i) {
fprintf(stderr, " ");
}
fprintf(stderr, "] %d%%\n", static_cast<int>((progress.received_bytes / progress.total_bytes) * 100));
}
}
}

static bool llama_download_file(CURL * curl, const char * url, const char * path, bool is_shard) {
bool force_download = false;

// Set the URL, allow to follow http redirection
Expand Down Expand Up @@ -2001,6 +2084,12 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path

// display download progress
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);

// custom progress callback on sharded download
if (is_shard) {
curl_easy_setopt(curl, CURLOPT_PROGRESSFUNCTION, shard_progress_callback);
curl_easy_setopt(curl, CURLOPT_PROGRESSDATA, url);
}

// helper function to hide password in URL
auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
Expand Down Expand Up @@ -2046,7 +2135,11 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path
if (etag_file) {
fputs(headers.etag, etag_file);
fclose(etag_file);
fprintf(stderr, "%s: file etag saved %s: %s\n", __func__, etag_path, headers.etag);
if (is_shard) {
download_done_buffer << __func__ << ": file etag saved " << etag_path << ": " << headers.etag << "\n";
} else {
fprintf(stderr, "%s: file etag saved %s: %s\n", __func__, etag_path, headers.etag);
}
}
}

Expand All @@ -2056,8 +2149,11 @@ static bool llama_download_file(CURL * curl, const char * url, const char * path
if (last_modified_file) {
fputs(headers.last_modified, last_modified_file);
fclose(last_modified_file);
fprintf(stderr, "%s: file last modified saved %s: %s\n", __func__, last_modified_path,
headers.last_modified);
if (is_shard) {
    // shard downloads buffer the message so that it does not interleave with the progress table
    download_done_buffer << __func__ << ": file last modified saved " << last_modified_path << ": " << headers.last_modified << "\n";
} else {
    fprintf(stderr, "%s: file last modified saved %s: %s\n", __func__, last_modified_path,
            headers.last_modified);
}
}
}

Expand Down Expand Up @@ -2089,7 +2185,7 @@ struct llama_model * llama_load_model_from_url(
return NULL;
}

if (!llama_download_file(curl, model_url, path_model)) {
if (!llama_download_file(curl, model_url, path_model, false)) {
return NULL;
}

Expand Down Expand Up @@ -2148,12 +2244,39 @@ struct llama_model * llama_load_model_from_url(
llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);

auto * curl = curl_easy_init();
bool res = llama_download_file(curl, split_url, split_path);
bool res = llama_download_file(curl, split_url, split_path, true);
curl_easy_cleanup(curl);

return res;
}, idx));
}

// Redraw the progress table about every LLAMA_PROGRESS_UPDATE_INTERVAL seconds until every download
// task has finished. Completion is taken from the futures rather than from the byte counters in
// progress_table: the table is empty (or reports 0 of 0 bytes) before the first progress callback,
// which would end the loop too early, and a transfer that fails never reaches its total, which
// would keep the loop running forever.
bool first_progress = true;
while (true) {
    // wait for the downloads, but not longer than one update interval
    const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(LLAMA_PROGRESS_UPDATE_INTERVAL);

    bool all_complete = true;
    for (auto & f : futures_download) {
        // any status other than timeout means there is nothing left to wait for on this future
        if (f.wait_until(deadline) == std::future_status::timeout) {
            all_complete = false;
            break;
        }
    }

    // Print the progress table
    print_shard_progress_table(first_progress);
    first_progress = false;

    if (all_complete) {
        // NOTE(review): assumes the download tasks are the only writers of download_done_buffer,
        // so it can be read here without a lock once they have all finished - confirm
        fprintf(stderr, "%s", download_done_buffer.str().c_str());
        break;
    }
}

// Wait for all downloads to complete
for (auto & f : futures_download) {
Expand Down
Loading