From c82301c79c31c9cebc6bc826ff1b903a1c765844 Mon Sep 17 00:00:00 2001 From: Theia Vogel Date: Sat, 9 Mar 2024 20:22:51 -0800 Subject: [PATCH] control vector support in cli --- common/common.cpp | 70 +++++++++++++++++++++++++++++++++++++++++++++++ common/common.h | 3 ++ 2 files changed, 73 insertions(+) diff --git a/common/common.cpp b/common/common.cpp index 16ef4d7f74dd99..6f8d49cf1259ee 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -562,6 +562,35 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } params.lora_base = argv[i]; + } else if (arg == "--control-vector") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.control_vectors.push_back(std::make_tuple(argv[i], 1.0f)); + } else if (arg == "--control-vector-scaled") { + if (++i >= argc) { + invalid_param = true; + break; + } + const char * control_vector = argv[i]; + if (++i >= argc) { + invalid_param = true; + break; + } + params.control_vectors.push_back(std::make_tuple(control_vector, std::stof(argv[i]))); + } else if (arg == "--control-vector-layer-range") { + if (++i >= argc) { + invalid_param = true; + break; + } + int32_t start = std::stoi(argv[i]); + if (++i >= argc) { + invalid_param = true; + break; + } + int32_t end = std::stoi(argv[i]); + params.control_vector_layer_range = std::make_tuple(start, end); } else if (arg == "--mmproj") { if (++i >= argc) { invalid_param = true; @@ -1087,6 +1116,12 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n"); printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n"); + printf(" --control-vector FNAME\n"); + printf(" add a control vector\n"); + printf(" --control-vector-scaled FNAME S\n"); + printf(" add a control vector with user defined scaling S\n"); + 
printf(" --control-vector-layer-range START END\n"); + printf(" layer range to apply the control vector(s) to, start and end inclusive\n"); printf(" -m FNAME, --model FNAME\n"); printf(" model path (default: %s)\n", params.model.c_str()); printf(" -md FNAME, --model-draft FNAME\n"); @@ -1351,6 +1386,41 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par return std::make_tuple(nullptr, nullptr); } + if (!params.control_vectors.empty()) { + int32_t layer_start, layer_end; + std::tie(layer_start, layer_end) = params.control_vector_layer_range; + + if (layer_start == 0) layer_start = 1; + if (layer_end == 0) layer_end = 31; + + struct llama_control_vector * vector = nullptr; + + for (const auto& t : params.control_vectors) { + std::string path; + float strength; + std::tie(path, strength) = t; + + fprintf(stderr, "%s: loading control vector from %s\n", __func__, path.c_str()); + struct llama_control_vector * temp = llama_control_vector_load(path.c_str()); + if (temp == nullptr) { + fprintf(stderr, "%s: error: failed to load control vector from %s\n", __func__, path.c_str()); + llama_free(lctx); + llama_free_model(model); + return std::make_tuple(nullptr, nullptr); + } + llama_control_vector_scale(temp, strength); + + if (vector == nullptr) { + vector = temp; + } else { + llama_control_vector_add(vector, temp); + llama_control_vector_free(temp); + } + } + + llama_apply_control_vector(lctx, vector, layer_start, layer_end); + } + for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) { const std::string& lora_adapter = std::get<0>(params.lora_adapter[i]); float lora_scale = std::get<1>(params.lora_adapter[i]); diff --git a/common/common.h b/common/common.h index f8d82b8713c871..28f7ccccfa393e 100644 --- a/common/common.h +++ b/common/common.h @@ -102,6 +102,9 @@ struct gpt_params { std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale std::string lora_base = ""; // base model path for the lora adapter + std::vector<std::tuple<std::string, float>> control_vectors; // control vector with user
defined scale + std::tuple<int32_t, int32_t> control_vector_layer_range; // layer range for control vector + int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line // (which is more convenient to use for plotting)