Skip to content

Commit

Permalink
control vector support in cli
Browse files Browse the repository at this point in the history
  • Loading branch information
vgel committed Mar 10, 2024
1 parent 7ec24b4 commit c82301c
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 0 deletions.
70 changes: 70 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,35 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.lora_base = argv[i];
} else if (arg == "--control-vector") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.control_vectors.push_back(std::make_tuple(argv[i], 1.0f));
} else if (arg == "--control-vector-scaled") {
if (++i >= argc) {
invalid_param = true;
break;
}
const char * control_vector = argv[i];
if (++i >= argc) {
invalid_param = true;
break;
}
params.control_vectors.push_back(std::make_tuple(control_vector, std::stof(argv[i])));
} else if (arg == "--control-vector-layer-range") {
if (++i >= argc) {
invalid_param = true;
break;
}
int32_t start = std::stoi(argv[i]);
if (++i >= argc) {
invalid_param = true;
break;
}
int32_t end = std::stoi(argv[i]);
params.control_vector_layer_range = std::make_tuple(start, end);
} else if (arg == "--mmproj") {
if (++i >= argc) {
invalid_param = true;
Expand Down Expand Up @@ -1087,6 +1116,12 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
printf(" --control-vector FNAME\n");
printf(" add a control vector\n");
printf(" --control-vector-scaled FNAME S\n");
printf(" add a control vector with user defined scaling S\n");
printf(" --control-vector-layer-range START END\n");
printf(" layer range to apply the control vector(s) to, start and end inclusive\n");
printf(" -m FNAME, --model FNAME\n");
printf(" model path (default: %s)\n", params.model.c_str());
printf(" -md FNAME, --model-draft FNAME\n");
Expand Down Expand Up @@ -1351,6 +1386,41 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
return std::make_tuple(nullptr, nullptr);
}

if (!params.control_vectors.empty()) {
int32_t layer_start, layer_end;
std::tie(layer_start, layer_end) = params.control_vector_layer_range;

if (layer_start == 0) layer_start = 1;
if (layer_end == 0) layer_end = 31;

struct llama_control_vector * vector = nullptr;

for (const auto& t : params.control_vectors) {
std::string path;
float strength;
std::tie(path, strength) = t;

fprintf(stderr, "%s: loading control vector from %s\n", __func__, path.c_str());
struct llama_control_vector * temp = llama_control_vector_load(path.c_str());
if (temp == nullptr) {
fprintf(stderr, "%s: error: failed to load control vector from %s\n", __func__, path.c_str());
llama_free(lctx);
llama_free_model(model);
return std::make_tuple(nullptr, nullptr);
}
llama_control_vector_scale(temp, strength);

if (vector == nullptr) {
vector = temp;
} else {
llama_control_vector_add(vector, temp);
llama_control_vector_free(temp);
}
}

llama_apply_control_vector(lctx, vector, layer_start, layer_end);
}

for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
const std::string& lora_adapter = std::get<0>(params.lora_adapter[i]);
float lora_scale = std::get<1>(params.lora_adapter[i]);
Expand Down
3 changes: 3 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ struct gpt_params {
std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
std::string lora_base = ""; // base model path for the lora adapter

std::vector<std::tuple<std::string, float>> control_vectors; // control vector with user defined scale
std::tuple<int32_t, int32_t> control_vector_layer_range; // layer range for control vector

int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
// (which is more convenient to use for plotting)
Expand Down

0 comments on commit c82301c

Please sign in to comment.