From 1454a5235276c50c3f531bbee9bec410eb973387 Mon Sep 17 00:00:00 2001
From: Eric Curtin
Date: Mon, 25 Nov 2024 21:13:36 -0500
Subject: [PATCH] Opt class for positional argument handling

Added support for positional arguments `MODEL` and `PROMPT`.

Signed-off-by: Eric Curtin
---
 examples/run/run.cpp | 187 +++++++++++++++++++++----------------------
 1 file changed, 92 insertions(+), 95 deletions(-)

diff --git a/examples/run/run.cpp b/examples/run/run.cpp
index cac2faefcc256f..4fcbd666896086 100644
--- a/examples/run/run.cpp
+++ b/examples/run/run.cpp
@@ -4,109 +4,115 @@
 #    include <windows.h>
 #endif

-#include <climits>
 #include <cstdio>
 #include <cstring>
 #include <iostream>
 #include <memory>
 #include <string>
-#include <unordered_map>
 #include <vector>

 #include "llama-cpp.h"

 typedef std::unique_ptr<char[]> char_array_ptr;

-struct Argument {
-    std::string flag;
-    std::string help_text;
-};
-
-struct Options {
-    std::string model_path, prompt_non_interactive;
-    int ngl = 99;
-    int n_ctx = 2048;
-};
+class Opt {
+  public:
+    int init_opt(int argc, const char ** argv) {
+        construct_help_str_();
+        // Parse arguments
+        if (parse(argc, argv)) {
+            fprintf(stderr, "Error: Failed to parse arguments.\n");
+            help();
+            return 1;
+        }

-class ArgumentParser {
-  public:
-    ArgumentParser(const char * program_name) : program_name(program_name) {}
+        // If help is requested, show help and exit
+        if (help_) {
+            help();
+            return 2;
+        }

-    void add_argument(const std::string & flag, std::string & var, const std::string & help_text = "") {
-        string_args[flag] = &var;
-        arguments.push_back({flag, help_text});
+        return 0;  // Success
     }

-    void add_argument(const std::string & flag, int & var, const std::string & help_text = "") {
-        int_args[flag] = &var;
-        arguments.push_back({flag, help_text});
+    const char * model_ = nullptr;
+    std::string prompt_;
+    int context_size_ = 2048, ngl_ = 0;
+
+  private:
+    std::string help_str_;
+    bool help_ = false;
+
+    void construct_help_str_() {
+        help_str_ =
+            "Description:\n"
+            "  Runs a llm\n"
+            "\n"
+            "Usage:\n"
+            "  llama-run [options] MODEL [PROMPT]\n"
+            "\n"
+            "Options:\n"
+            "  -c, --context-size <value>\n"
+            "      Context size (default: " +
+            std::to_string(context_size_);
+        help_str_ +=
+            ")\n"
+            "  -n, --ngl <value>\n"
+            "      Number of GPU layers (default: " +
+            std::to_string(ngl_);
+        help_str_ +=
+            ")\n"
+            "  -h, --help\n"
+            "      Show help message\n"
+            "\n"
+            "Examples:\n"
+            "  llama-run your_model.gguf\n"
+            "  llama-run --ngl 99 your_model.gguf\n"
+            "  llama-run --ngl 99 your_model.gguf\n";
     }

     int parse(int argc, const char ** argv) {
+        if (parse_arguments(argc, argv) || !model_) {
+            return 1;
+        }
+
+        return 0;
+    }
+
+    int parse_arguments(int argc, const char ** argv) {
+        int positional_args_i = 0;
         for (int i = 1; i < argc; ++i) {
-            std::string arg = argv[i];
-            if (string_args.count(arg)) {
-                if (i + 1 < argc) {
-                    *string_args[arg] = argv[++i];
-                } else {
-                    fprintf(stderr, "error: missing value for %s\n", arg.c_str());
-                    print_usage();
+            if (std::strcmp(argv[i], "-c") == 0 || std::strcmp(argv[i], "--context-size") == 0) {
+                if (i + 1 >= argc) {
                     return 1;
                 }
-            } else if (int_args.count(arg)) {
-                if (i + 1 < argc) {
-                    if (parse_int_arg(argv[++i], *int_args[arg]) != 0) {
-                        fprintf(stderr, "error: invalid value for %s: %s\n", arg.c_str(), argv[i]);
-                        print_usage();
-                        return 1;
-                    }
-                } else {
-                    fprintf(stderr, "error: missing value for %s\n", arg.c_str());
-                    print_usage();
+
+                context_size_ = std::atoi(argv[++i]);
+            } else if (std::strcmp(argv[i], "-n") == 0 || std::strcmp(argv[i], "--ngl") == 0) {
+                if (i + 1 >= argc) {
                     return 1;
                 }
+
+                ngl_ = std::atoi(argv[++i]);
+            } else if (std::strcmp(argv[i], "-h") == 0 || std::strcmp(argv[i], "--help") == 0) {
+                help_ = true;
+                model_ = "";  // placeholder so the missing-model check passes when only help is requested
+                return 0;
+            } else if (!positional_args_i) {
+                ++positional_args_i;
+                model_ = argv[i];
+            } else if (positional_args_i == 1) {
+                ++positional_args_i;
+                prompt_ = argv[i];
             } else {
-                fprintf(stderr, "error: unrecognized argument %s\n", arg.c_str());
-                print_usage();
-                return 1;
+                prompt_ += " " + std::string(argv[i]);
             }
         }

-        if (string_args["-m"]->empty()) {
-            fprintf(stderr, "error: -m is required\n");
-            print_usage();
-            return 1;
-        }
-
         return 0;
     }

-  private:
-    const char * program_name;
-    std::unordered_map<std::string, std::string *> string_args;
-    std::unordered_map<std::string, int *> int_args;
-    std::vector<Argument> arguments;
-
-    int parse_int_arg(const char * arg, int & value) {
-        char * end;
-        const long val = std::strtol(arg, &end, 10);
-        if (*end == '\0' && val >= INT_MIN && val <= INT_MAX) {
-            value = static_cast<int>(val);
-            return 0;
-        }
-        return 1;
-    }
-
-    void print_usage() const {
-        printf("\nUsage:\n");
-        printf("  %s [OPTIONS]\n\n", program_name);
-        printf("Options:\n");
-        for (const auto & arg : arguments) {
-            printf("  %-10s %s\n", arg.flag.c_str(), arg.help_text.c_str());
-        }
-
-        printf("\n");
-    }
+    void help() const { printf("%s", help_str_.c_str()); }
 };

 class LlamaData {
@@ -116,13 +122,13 @@ class LlamaData {
     llama_context_ptr context;
     std::vector<llama_chat_message> messages;

-    int init(const Options & opt) {
-        model = initialize_model(opt.model_path, opt.ngl);
+    int init(const Opt & opt) {
+        model = initialize_model(opt.model_, opt.ngl_);
         if (!model) {
             return 1;
         }

-        context = initialize_context(model, opt.n_ctx);
+        context = initialize_context(model, opt.context_size_);
         if (!context) {
             return 1;
         }
@@ -134,6 +140,7 @@ class LlamaData {
   private:
     // Initializes the model and returns a unique pointer to it
     llama_model_ptr initialize_model(const std::string & model_path, const int ngl) {
+        ggml_backend_load_all();
         llama_model_params model_params = llama_model_default_params();
         model_params.n_gpu_layers = ngl;

@@ -273,19 +280,6 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
     return 0;
 }

-static int parse_arguments(const int argc, const char ** argv, Options & opt) {
-    ArgumentParser parser(argv[0]);
-    parser.add_argument("-m", opt.model_path, "model");
-    parser.add_argument("-p", opt.prompt_non_interactive, "prompt");
-    parser.add_argument("-c", opt.n_ctx, "context_size");
-    parser.add_argument("-ngl", opt.ngl, "n_gpu_layers");
-    if (parser.parse(argc, argv)) {
-        return 1;
-    }
-
-    return 0;
-}
-
 static int read_user_input(std::string & user) {
     std::getline(std::cin, user);
     return user.empty();  // Indicate an error or empty input
@@ -382,17 +376,20 @@ static std::string read_pipe_data() {
 }

 int main(int argc, const char ** argv) {
-    Options opt;
-    if (parse_arguments(argc, argv, opt)) {
+    Opt opt;
+    const int opt_ret = opt.init_opt(argc, argv);
+    if (opt_ret == 2) {
+        return 0;
+    } else if (opt_ret) {
         return 1;
     }

     if (!is_stdin_a_terminal()) {
-        if (!opt.prompt_non_interactive.empty()) {
-            opt.prompt_non_interactive += "\n\n";
+        if (!opt.prompt_.empty()) {
+            opt.prompt_ += "\n\n";
         }

-        opt.prompt_non_interactive += read_pipe_data();
+        opt.prompt_ += read_pipe_data();
     }

     llama_log_set(log_callback, nullptr);
@@ -401,7 +398,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }

-    if (chat_loop(llama_data, opt.prompt_non_interactive)) {
+    if (chat_loop(llama_data, opt.prompt_)) {
         return 1;
     }