Improve progress bar
Set the default width to whatever the terminal is. Also fixed a small bug around
the default n_gpu_layers value.

Signed-off-by: Eric Curtin <[email protected]>
ericcurtin committed Dec 14, 2024
1 parent 56eea07 commit 474206d
Showing 1 changed file with 136 additions and 73 deletions.
209 changes: 136 additions & 73 deletions examples/run/run.cpp
@@ -1,6 +1,7 @@
#if defined(_WIN32)
# include <windows.h>
#else
# include <sys/ioctl.h>
# include <unistd.h>
#endif

@@ -29,7 +30,6 @@
class Opt {
public:
int init(int argc, const char ** argv) {
construct_help_str_();
// Parse arguments
if (parse(argc, argv)) {
printe("Error: Failed to parse arguments.\n");
@@ -48,14 +48,54 @@ class Opt {

std::string model_;
std::string user_;
int context_size_ = 2048, ngl_ = -1;
int context_size_ = -1, ngl_ = -1;
bool verbose_ = false;

private:
std::string help_str_;
bool help_ = false;

void construct_help_str_() {
help_str_ =
int parse(int argc, const char ** argv) {
int positional_args_i = 0;
for (int i = 1; i < argc; ++i) {
if (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0) {
if (i + 1 >= argc) {
return 1;
}

context_size_ = std::atoi(argv[++i]);
} else if (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--ngl") == 0) {
if (i + 1 >= argc) {
return 1;
}

ngl_ = std::atoi(argv[++i]);
} else if (strcmp(argv[i], "-v") == 0 || strcmp(argv[i], "--verbose") == 0 ||
strcmp(argv[i], "--log-verbose") == 0) {
verbose_ = true;
} else if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) {
help_ = true;
return 0;
} else if (!positional_args_i) {
if (!argv[i][0] || argv[i][0] == '-') {
return 1;
}

++positional_args_i;
model_ = argv[i];
} else if (positional_args_i == 1) {
++positional_args_i;
user_ = argv[i];
} else {
user_ += " " + std::string(argv[i]);
}
}

return model_.empty(); // model_ is the only required value
}
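
A quick illustration of the positional handling above (hypothetical invocation, not part of this commit): the first non-flag argument becomes the model, and everything after it is joined into the user prompt with single spaces.

const char * argv[] = { "llama-run", "-c", "4096", "model.gguf", "tell", "me", "a", "joke" };
Opt opt;
opt.init(8, argv);
// opt.context_size_ == 4096, opt.model_ == "model.gguf",
// opt.user_         == "tell me a joke"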

// -v, --verbose, --log-verbose
void help() const {
printf(
"Description:\n"
" Runs a llm\n"
"\n"
@@ -64,15 +104,9 @@ class Opt {
"\n"
"Options:\n"
" -c, --context-size <value>\n"
" Context size (default: " +
std::to_string(context_size_);
help_str_ +=
")\n"
" Context size (default: %d)\n"
" -n, --ngl <value>\n"
" Number of GPU layers (default: " +
std::to_string(ngl_);
help_str_ +=
")\n"
" Number of GPU layers (default: %d)\n"
" -h, --help\n"
" Show help message\n"
"\n"
@@ -96,43 +130,10 @@ class Opt {
" llama-run https://example.com/some-file1.gguf\n"
" llama-run some-file2.gguf\n"
" llama-run file://some-file3.gguf\n"
" llama-run --ngl 99 some-file4.gguf\n"
" llama-run --ngl 99 some-file5.gguf Hello World\n";
" llama-run --ngl 999 some-file4.gguf\n"
" llama-run --ngl 999 some-file5.gguf Hello World\n",
llama_context_default_params().n_batch, llama_model_default_params().n_gpu_layers);
}

int parse(int argc, const char ** argv) {
int positional_args_i = 0;
for (int i = 1; i < argc; ++i) {
if (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0) {
if (i + 1 >= argc) {
return 1;
}

context_size_ = std::atoi(argv[++i]);
} else if (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--ngl") == 0) {
if (i + 1 >= argc) {
return 1;
}

ngl_ = std::atoi(argv[++i]);
} else if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) {
help_ = true;
return 0;
} else if (!positional_args_i) {
++positional_args_i;
model_ = argv[i];
} else if (positional_args_i == 1) {
++positional_args_i;
user_ = argv[i];
} else {
user_ += " " + std::string(argv[i]);
}
}

return model_.empty(); // model_ is the only required value
}

void help() const { printf("%s", help_str_.c_str()); }
};

struct progress_data {
Expand All @@ -151,6 +152,18 @@ struct FileDeleter {

typedef std::unique_ptr<FILE, FileDeleter> FILE_ptr;

static int get_terminal_width() {
#if defined(_WIN32)
CONSOLE_SCREEN_BUFFER_INFO csbi;
GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi);
return csbi.srWindow.Right - csbi.srWindow.Left + 1;
#else
struct winsize w;
ioctl(STDOUT_FILENO, TIOCGWINSZ, &w);
return w.ws_col;
#endif
}
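
Note that neither branch above checks the return value of the query; if it fails (for example when output is not attached to a terminal), the struct contents are indeterminate. A hedged variant, not part of this commit, that falls back to a conventional 80 columns, using the same headers already included at the top of the file:

static int get_terminal_width_checked() {  // hypothetical name, illustration only
#if defined(_WIN32)
    CONSOLE_SCREEN_BUFFER_INFO csbi;
    if (!GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi)) {
        return 80;  // query failed: assume a classic 80-column terminal
    }
    return csbi.srWindow.Right - csbi.srWindow.Left + 1;
#else
    struct winsize w;
    if (ioctl(STDOUT_FILENO, TIOCGWINSZ, &w) != 0 || w.ws_col == 0) {
        return 80;  // not a tty or zero-width report: fall back to 80
    }
    return w.ws_col;
#endif
}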

#ifdef LLAMA_USE_CURL
class CurlWrapper {
public:
@@ -270,9 +283,9 @@ class CurlWrapper {

static std::string human_readable_size(curl_off_t size) {
static const char * suffix[] = { "B", "KB", "MB", "GB", "TB" };
char length = sizeof(suffix) / sizeof(suffix[0]);
int i = 0;
double dbl_size = size;
char length = sizeof(suffix) / sizeof(suffix[0]);
int i = 0;
double dbl_size = size;
if (size > 1024) {
for (i = 0; (size / 1024) > 0 && i < length - 1; i++, size /= 1024) {
dbl_size = size / 1024.0;
@@ -293,27 +306,75 @@ class CurlWrapper {

total_to_download += data->file_size;
const curl_off_t now_downloaded_plus_file_size = now_downloaded + data->file_size;
const curl_off_t percentage = (now_downloaded_plus_file_size * 100) / total_to_download;
const curl_off_t pos = (percentage / 5);
const curl_off_t percentage = calculate_percentage(now_downloaded_plus_file_size, total_to_download);
std::string progress_prefix = generate_progress_prefix(percentage);

const double speed = calculate_speed(now_downloaded, data->start_time);
const double tim = (total_to_download - now_downloaded) / speed;
std::string progress_suffix =
generate_progress_suffix(now_downloaded_plus_file_size, total_to_download, speed, tim);

const int progress_bar_width = calculate_progress_bar_width(progress_prefix, progress_suffix);
std::string progress_bar;
for (int i = 0; i < 20; ++i) {
progress_bar.append((i < pos) ? "█" : " ");
}
generate_progress_bar(progress_bar_width, percentage, progress_bar);

// Calculate download speed and estimated time to completion
const auto now = std::chrono::steady_clock::now();
const std::chrono::duration<double> elapsed_seconds = now - data->start_time;
const double speed = now_downloaded / elapsed_seconds.count();
const double estimated_time = (total_to_download - now_downloaded) / speed;
printe("\r%ld%% |%s| %s/%s %.2f MB/s %s ", percentage, progress_bar.c_str(),
human_readable_size(now_downloaded).c_str(), human_readable_size(total_to_download).c_str(),
speed / (1024 * 1024), human_readable_time(estimated_time).c_str());
fflush(stderr);
print_progress(progress_prefix, progress_bar, progress_suffix);
data->printed = true;

return 0;
}

static curl_off_t calculate_percentage(curl_off_t now_downloaded_plus_file_size, curl_off_t total_to_download) {
return (now_downloaded_plus_file_size * 100) / total_to_download;
}
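
As a sanity check of the integer math (assumed values): with 512 MiB of a 1 GiB download complete, (536870912 * 100) / 1073741824 == 50. Since curl_off_t is 64-bit, the multiplication by 100 cannot overflow for any realistic file size.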

static std::string generate_progress_prefix(curl_off_t percentage) {
std::ostringstream progress_output;
progress_output << percentage << "% |";
return progress_output.str();
}

static double calculate_speed(curl_off_t now_downloaded, const std::chrono::steady_clock::time_point & start_time) {
const auto now = std::chrono::steady_clock::now();
const std::chrono::duration<double> elapsed_seconds = now - start_time;
return now_downloaded / elapsed_seconds.count();
}

static std::string generate_progress_suffix(curl_off_t now_downloaded_plus_file_size, curl_off_t total_to_download,
double speed, double estimated_time) {
std::ostringstream progress_output;
progress_output << human_readable_size(now_downloaded_plus_file_size).c_str() << "/"
<< human_readable_size(total_to_download).c_str() << " " << std::fixed << std::setprecision(2)
<< speed / (1024 * 1024) << " MB/s " << human_readable_time(estimated_time).c_str();
return progress_output.str();
}

static int calculate_progress_bar_width(const std::string & progress_prefix, const std::string & progress_suffix) {
int progress_bar_width = get_terminal_width() - progress_prefix.size() - progress_suffix.size() - 5;
if (progress_bar_width < 10) {
progress_bar_width = 10;
}
return progress_bar_width;
}
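
Worked example with assumed numbers: on an 80-column terminal, a prefix such as "42% |" is 5 characters and a suffix such as "512.00 MB/1.00 GB 2.00 MB/s 4m 16s" is 34, leaving 80 - 5 - 34 - 5 = 36 cells for the bar itself; the constant 5 budgets for the closing "| " separator and trailing padding, and anything narrower than 10 cells is clamped up to 10.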

static std::string generate_progress_bar(int progress_bar_width, curl_off_t percentage,
std::string & progress_bar) {
const curl_off_t pos = (percentage * progress_bar_width) / 100;
for (int i = 0; i < progress_bar_width; ++i) {
progress_bar.append((i < pos) ? "█" : " ");
}

return progress_bar;
}

static void print_progress(const std::string & progress_prefix, const std::string & progress_bar,
const std::string & progress_suffix) {
std::ostringstream progress_output;
progress_output << progress_prefix << progress_bar << "| " << progress_suffix;
printe("\r%*s\r%s", get_terminal_width(), " ", progress_output.str().c_str());
fflush(stderr);
}
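
The format string above is the whole redraw trick: "\r%*s\r" returns the cursor to column 0, overwrites the previous line with a single space padded out to the full terminal width, then returns to column 0 again before printing the fresh line. A minimal standalone sketch of the idiom (assumes a terminal that honours carriage return):

#include <cstdio>

int main() {
    const int width = 80;  // stand-in for get_terminal_width()
    std::fprintf(stderr, "\r%*s\r%s", width, " ", "42% |████      | ...");
    return 0;
}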

// Function to write data to a file
static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) {
FILE * out = static_cast<FILE *>(stream);
@@ -467,6 +528,7 @@ class LlamaData {
llama_model_params model_params = llama_model_default_params();
model_params.n_gpu_layers = opt.ngl_ >= 0 ? opt.ngl_ : model_params.n_gpu_layers;
resolve_model(opt.model_);
printe("Loading model");
llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), model_params));
if (!model) {
printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str());
Expand All @@ -478,8 +540,7 @@ class LlamaData {
// Initializes the context with the specified parameters
llama_context_ptr initialize_context(const llama_model_ptr & model, const int n_ctx) {
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = n_ctx;
ctx_params.n_batch = n_ctx;
ctx_params.n_ctx = ctx_params.n_batch = n_ctx >= 0 ? n_ctx : ctx_params.n_batch;
llama_context_ptr context(llama_new_context_with_model(model.get(), ctx_params));
if (!context) {
printe("%s: error: failed to create the llama_context\n", __func__);
@@ -642,8 +703,9 @@ static int handle_user_input(std::string & user_input, const std::string & user_
}

printf(
"\r "
"\r\033[32m> \033[0m");
"\r%*s"
"\r\033[32m> \033[0m",
get_terminal_width(), " ");
return read_user_input(user_input); // Returns true if input ends the loop
}
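
The same width-padded erase is reused here, followed by ANSI colour codes (assumes a VT-compatible terminal): "\033[32m" switches the foreground to green for the "> " prompt and "\033[0m" resets the attributes afterwards.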

@@ -682,8 +744,9 @@ static int chat_loop(LlamaData & llama_data, const std::string & user_) {
return 0;
}

static void log_callback(const enum ggml_log_level level, const char * text, void *) {
if (level == GGML_LOG_LEVEL_ERROR) {
static void log_callback(const enum ggml_log_level level, const char * text, void * p) {
const Opt * opt = static_cast<Opt *>(p);
if (opt->verbose_ || level == GGML_LOG_LEVEL_ERROR) {
printe("%s", text);
}
}
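
The Opt instance arrives through the void * user-data argument that llama_log_set() registers below, so verbose gating no longer needs a global; the only requirement is that opt outlives every log call. A minimal sketch of the wiring, with assumed surrounding code:

Opt opt;
opt.init(argc, argv);
llama_log_set(log_callback, &opt);  // &opt comes back as `p` on every callback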
@@ -721,7 +784,7 @@ int main(int argc, const char ** argv) {
opt.user_ += read_pipe_data();
}

llama_log_set(log_callback, nullptr);
llama_log_set(log_callback, &opt);
LlamaData llama_data;
if (llama_data.init(opt)) {
return 1;