From aa450bea4b475caa3430907b6f76fbea40e943e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?DAN=E2=84=A2?= Date: Sat, 16 Mar 2024 11:24:06 -0400 Subject: [PATCH] Add flag to track found arguments. --- common/common.cpp | 124 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 122 insertions(+), 2 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 7cf732ddfe0c9..1b0ba849398de 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -151,7 +151,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { std::replace(arg.begin(), arg.end(), '_', '-'); } + bool arg_found = false; if (arg == "-s" || arg == "--seed") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -159,6 +161,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.seed = std::stoul(argv[i]); } if (arg == "-t" || arg == "--threads") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -169,6 +172,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { } } if (arg == "-tb" || arg == "--threads-batch") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -179,6 +183,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { } } if (arg == "-td" || arg == "--threads-draft") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -189,6 +194,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { } } if (arg == "-tbd" || arg == "--threads-batch-draft") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -199,6 +205,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { } } if (arg == "-p" || arg == "--prompt") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -206,9 +213,11 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.prompt = argv[i]; } if (arg == "-e" || arg == "--escape") { + arg_found = true; params.escape = true; } if (arg == "--prompt-cache") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -216,12 +225,15 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.path_prompt_cache = argv[i]; } if (arg == "--prompt-cache-all") { + arg_found = true; params.prompt_cache_all = true; } if (arg == "--prompt-cache-ro") { + arg_found = true; params.prompt_cache_ro = true; } if (arg == "-bf" || arg == "--binary-file") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -240,6 +252,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { fprintf(stderr, "Read %zu bytes from binary file %s\n", params.prompt.size(), argv[i]); } if (arg == "-f" || arg == "--file") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -258,6 +271,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { } } if (arg == "-n" || arg == "--n-predict") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -265,6 +279,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.n_predict = std::stoi(argv[i]); } if (arg == "--top-k") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -272,6 +287,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { sparams.top_k = std::stoi(argv[i]); } if (arg == "-c" || arg == "--ctx-size") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -279,6 +295,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.n_ctx = std::stoi(argv[i]); } if (arg == "--grp-attn-n" || arg == "-gan") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -287,6 +304,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.grp_attn_n = std::stoi(argv[i]); } if (arg == "--grp-attn-w" || arg == "-gaw") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -295,6 +313,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.grp_attn_w = std::stoi(argv[i]); } if (arg == "--rope-freq-base") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -302,6 +321,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.rope_freq_base = std::stof(argv[i]); } if (arg == "--rope-freq-scale") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -309,6 +329,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.rope_freq_scale = std::stof(argv[i]); } if (arg == "--rope-scaling") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -320,6 +341,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { else { invalid_param = true; break; } } if (arg == "--rope-scale") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -327,6 +349,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.rope_freq_scale = 1.0f/std::stof(argv[i]); } if (arg == "--yarn-orig-ctx") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -334,6 +357,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.yarn_orig_ctx = std::stoi(argv[i]); } if (arg == "--yarn-ext-factor") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -341,6 +365,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.yarn_ext_factor = std::stof(argv[i]); } if (arg == "--yarn-attn-factor") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -348,6 +373,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.yarn_attn_factor = std::stof(argv[i]); } if (arg == "--yarn-beta-fast") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -355,6 +381,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.yarn_beta_fast = std::stof(argv[i]); } if (arg == "--yarn-beta-slow") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -362,6 +389,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.yarn_beta_slow = std::stof(argv[i]); } if (arg == "--pooling") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -373,6 +401,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { else { invalid_param = true; break; } } if (arg == "--defrag-thold" || arg == "-dt") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -380,6 +409,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.defrag_thold = std::stof(argv[i]); } if (arg == "--samplers") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -388,6 +418,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { sparams.samplers_sequence = sampler_types_from_names(sampler_names, true); } if (arg == "--sampling-seq") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -395,6 +426,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { sparams.samplers_sequence = sampler_types_from_chars(argv[i]); } if (arg == "--top-p") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -402,6 +434,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { sparams.top_p = std::stof(argv[i]); } if (arg == "--min-p") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -409,6 +442,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { sparams.min_p = std::stof(argv[i]); } if (arg == "--temp") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -417,6 +451,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { sparams.temp = std::max(sparams.temp, 0.0f); } if (arg == "--tfs") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -424,6 +459,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { sparams.tfs_z = std::stof(argv[i]); } if (arg == "--typical") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -431,6 +467,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { sparams.typical_p = std::stof(argv[i]); } if (arg == "--repeat-last-n") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -439,6 +476,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n); } if (arg == "--repeat-penalty") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -446,6 +484,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { sparams.penalty_repeat = std::stof(argv[i]); } if (arg == "--frequency-penalty") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -453,6 +492,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { sparams.penalty_freq = std::stof(argv[i]); } if (arg == "--presence-penalty") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -460,6 +500,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { sparams.penalty_present = std::stof(argv[i]); } if (arg == "--dynatemp-range") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -467,6 +508,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { sparams.dynatemp_range = std::stof(argv[i]); } if (arg == "--dynatemp-exp") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -474,6 +516,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { sparams.dynatemp_exponent = std::stof(argv[i]); } if (arg == "--mirostat") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -481,6 +524,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { sparams.mirostat = std::stoi(argv[i]); } if (arg == "--mirostat-lr") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -488,6 +532,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { sparams.mirostat_eta = std::stof(argv[i]); } if (arg == "--mirostat-ent") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -495,6 +540,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { sparams.mirostat_tau = std::stof(argv[i]); } if (arg == "--cfg-negative-prompt") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -502,6 +548,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { sparams.cfg_negative_prompt = argv[i]; } if (arg == "--cfg-negative-prompt-file") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -518,6 +565,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { } } if (arg == "--cfg-scale") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -525,6 +573,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { sparams.cfg_scale = std::stof(argv[i]); } if (arg == "-b" || arg == "--batch-size") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -532,6 +581,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.n_batch = std::stoi(argv[i]); } if (arg == "-ub" || arg == "--ubatch-size") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -539,6 +589,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.n_ubatch = std::stoi(argv[i]); } if (arg == "--keep") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -546,6 +597,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.n_keep = std::stoi(argv[i]); } if (arg == "--draft") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -553,6 +605,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.n_draft = std::stoi(argv[i]); } if (arg == "--chunks") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -560,6 +613,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.n_chunks = std::stoi(argv[i]); } if (arg == "-np" || arg == "--parallel") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -567,6 +621,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.n_parallel = std::stoi(argv[i]); } if (arg == "-ns" || arg == "--sequences") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -574,6 +629,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.n_sequences = std::stoi(argv[i]); } if (arg == "--p-split" || arg == "-ps") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -581,6 +637,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.p_split = std::stof(argv[i]); } if (arg == "-m" || arg == "--model") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -588,6 +645,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.model = argv[i]; } if (arg == "-md" || arg == "--model-draft") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -595,6 +653,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.model_draft = argv[i]; } if (arg == "-a" || arg == "--alias") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -602,6 +661,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.model_alias = argv[i]; } if (arg == "--lora") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -610,6 +670,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.use_mmap = false; } if (arg == "--lora-scaled") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -623,6 +684,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.use_mmap = false; } if (arg == "--lora-base") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -630,6 +692,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.lora_base = argv[i]; } if (arg == "--control-vector") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -637,6 +700,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.control_vectors.push_back({ 1.0f, argv[i], }); } if (arg == "--control-vector-scaled") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -649,6 +713,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.control_vectors.push_back({ std::stof(argv[i]), fname, }); } if (arg == "--control-vector-layer-range") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -661,6 +726,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.control_vector_layer_end = std::stoi(argv[i]); } if (arg == "--mmproj") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -668,6 +734,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.mmproj = argv[i]; } if (arg == "--image") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -675,51 +742,67 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.image = argv[i]; } if (arg == "-i" || arg == "--interactive") { + arg_found = true; params.interactive = true; } if (arg == "--embedding") { + arg_found = true; params.embedding = true; } if (arg == "--interactive-first") { + arg_found = true; params.interactive_first = true; } if (arg == "-ins" || arg == "--instruct") { + arg_found = true; params.instruct = true; } if (arg == "-cml" || arg == "--chatml") { + arg_found = true; params.chatml = true; } if (arg == "--infill") { + arg_found = true; params.infill = true; } if (arg == "-dkvc" || arg == "--dump-kv-cache") { + arg_found = true; params.dump_kv_cache = true; } if (arg == "-nkvo" || arg == "--no-kv-offload") { + arg_found = true; params.no_kv_offload = true; } if (arg == "-ctk" || arg == "--cache-type-k") { + arg_found = true; params.cache_type_k = argv[++i]; } if (arg == "-ctv" || arg == "--cache-type-v") { + arg_found = true; params.cache_type_v = argv[++i]; } if (arg == "--multiline-input") { + arg_found = true; params.multiline_input = true; } if (arg == "--simple-io") { + arg_found = true; params.simple_io = true; } if (arg == "-cb" || arg == "--cont-batching") { + arg_found = true; params.cont_batching = true; } if (arg == "--color") { + arg_found = true; params.use_color = true; } if (arg == "--mlock") { + arg_found = true; params.use_mlock = true; } if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -731,6 +814,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { } } if (arg == "--gpu-layers-draft" || arg == "-ngld" || arg == "--n-gpu-layers-draft") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -742,6 +826,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { } } if (arg == "--main-gpu" || arg == "-mg") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -752,6 +837,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { #endif // GGML_USE_CUBLAS_SYCL } if (arg == "--split-mode" || arg == "-sm") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -777,6 +863,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { } if (arg == "--tensor-split" || arg == "-ts") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -803,9 +890,11 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { #endif // GGML_USE_CUBLAS_SYCL } if (arg == "--no-mmap") { + arg_found = true; params.use_mmap = false; } if (arg == "--numa") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -817,12 +906,15 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { else { invalid_param = true; break; } } if (arg == "--verbose-prompt") { + arg_found = true; params.verbose_prompt = true; } if (arg == "--no-display-prompt") { + arg_found = true; params.display_prompt = false; } if (arg == "-r" || arg == "--reverse-prompt") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -830,6 +922,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.antiprompt.emplace_back(argv[i]); } if (arg == "-ld" || arg == "--logdir") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -841,6 +934,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { } } if (arg == "--save-all-logits" || arg == "--kl-divergence-base") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -848,9 +942,11 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.logits_file = argv[i]; } if (arg == "--perplexity" || arg == "--all-logits") { + arg_found = true; params.logits_all = true; } if (arg == "--ppl-stride") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -858,6 +954,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.ppl_stride = std::stoi(argv[i]); } if (arg == "-ptc" || arg == "--print-token-count") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -865,6 +962,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.n_print = std::stoi(argv[i]); } if (arg == "--ppl-output-type") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -872,9 +970,11 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.ppl_output_type = std::stoi(argv[i]); } if (arg == "--hellaswag") { + arg_found = true; params.hellaswag = true; } if (arg == "--hellaswag-tasks") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -882,9 +982,11 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.hellaswag_tasks = std::stoi(argv[i]); } if (arg == "--winogrande") { + arg_found = true; params.winogrande = true; } if (arg == "--winogrande-tasks") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -892,9 +994,11 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.winogrande_tasks = std::stoi(argv[i]); } if (arg == "--multiple-choice") { + arg_found = true; params.multiple_choice = true; } if (arg == "--multiple-choice-tasks") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -902,15 +1006,19 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.multiple_choice_tasks = std::stoi(argv[i]); } if (arg == "--kl-divergence") { + arg_found = true; params.kl_divergence = true; } if (arg == "--ignore-eos") { + arg_found = true; params.ignore_eos = true; } if (arg == "--no-penalize-nl") { + arg_found = true; sparams.penalize_nl = false; } if (arg == "-l" || arg == "--logit-bias") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -931,21 +1039,25 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { } } if (arg == "-h" || arg == "--help") { + arg_found = true; return false; - } if (arg == "--version") { + arg_found = true; fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT); fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET); exit(0); } if (arg == "--random-prompt") { + arg_found = true; params.random_prompt = true; } if (arg == "--in-prefix-bos") { + arg_found = true; params.input_prefix_bos = true; } if (arg == "--in-prefix") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -953,6 +1065,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.input_prefix = argv[i]; } if (arg == "--in-suffix") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -960,6 +1073,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.input_suffix = argv[i]; } if (arg == "--grammar") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -967,6 +1081,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { sparams.grammar = argv[i]; } if (arg == "--grammar-file") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -984,6 +1099,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { ); } if (arg == "--override-kv") { + arg_found = true; if (++i >= argc) { invalid_param = true; break; @@ -1028,10 +1144,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { // Parse args for logging parameters } if ( log_param_single_parse( argv[i] ) ) { + arg_found = true; // Do nothing, log_param_single_parse automatically does it's thing // and returns if a match was found and parsed. } if ( log_param_pair_parse( /*check_but_dont_parse*/ true, argv[i] ) ) { + arg_found = true; // We have a matching known parameter requiring an argument, // now we need to check if there is anything after this argv // and flag invalid_param or parse it. @@ -1047,7 +1165,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { #endif // LOG_DISABLE_LOGS } - throw std::invalid_argument("error: unknown argument: " + arg); + if (!arg_found) { + throw std::invalid_argument("error: unknown argument: " + arg); + } } if (invalid_param) { throw std::invalid_argument("error: invalid parameter for argument: " + arg);