
clean up struct def
ngxson committed Aug 5, 2024
1 parent 21cb133 commit c58a332
Showing 3 changed files with 22 additions and 14 deletions.
12 changes: 7 additions & 5 deletions common/common.cpp
@@ -687,7 +687,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.lora_adapters.push_back({
std::string(argv[i]),
1.0,
- nullptr,
});
return true;
}
@@ -698,7 +697,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.lora_adapters.push_back({
lora_adapter,
std::stof(argv[i]),
- nullptr,
});
return true;
}
@@ -2106,16 +2104,20 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {

// load and optionally apply lora adapters
for (auto & la : params.lora_adapters) {
- la.adapter = llama_lora_adapter_init(model, la.path.c_str());
- if (la.adapter == nullptr) {
+ llama_lora_adapter_container loaded_la;
+ loaded_la.path = la.path;
+ loaded_la.scale = la.scale;
+ loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str());
+ if (loaded_la.adapter == nullptr) {
fprintf(stderr, "%s: error: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
llama_free(lctx);
llama_free_model(model);
return iparams;
}
+ iparams.lora_adapters.push_back(loaded_la); // copy to list of loaded adapters
}
if (!params.lora_init_without_apply) {
- llama_lora_adapters_apply(lctx, params.lora_adapters);
+ llama_lora_adapters_apply(lctx, iparams.lora_adapters);
}

if (params.ignore_eos) {
10 changes: 7 additions & 3 deletions common/common.h
@@ -33,9 +33,12 @@

#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"

- struct llama_lora_adapter_container {
+ struct llama_lora_adapter_info {
std::string path;
float scale;
+ };
+
+ struct llama_lora_adapter_container : llama_lora_adapter_info {
struct llama_lora_adapter * adapter;
};

Expand Down Expand Up @@ -133,7 +136,7 @@ struct gpt_params {
std::vector<llama_model_kv_override> kv_overrides;

bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply)
- std::vector<llama_lora_adapter_container> lora_adapters; // lora adapter path with user defined scale
+ std::vector<llama_lora_adapter_info> lora_adapters; // lora adapter path with user defined scale

std::vector<llama_control_vector_load_info> control_vectors; // control vector with user defined scale

@@ -315,8 +318,9 @@ std::string fs_get_cache_file(const std::string & filename);
//

struct llama_init_result {
- struct llama_model * model = nullptr;
+ struct llama_model * model = nullptr;
struct llama_context * context = nullptr;
+ std::vector<llama_lora_adapter_container> lora_adapters;
};

struct llama_init_result llama_init_from_gpt_params(gpt_params & params);
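The split above separates what the user requested (llama_lora_adapter_info: path and scale, stored in gpt_params) from what was actually loaded (llama_lora_adapter_container: the same fields plus the adapter handle, returned in llama_init_result). The following is a minimal sketch, not part of this commit, of how a caller might use the new types, assuming the definitions in this header and the helpers referenced in the diff (llama_lora_adapters_apply, llama_free, llama_free_model); the adapter path "my-adapter.gguf" is hypothetical.

    // Sketch only: exercises the refactored structs, not code from this commit.
    #include "common.h"

    int main() {
        gpt_params params;
        params.model = "models/7B/ggml-model-f16.gguf";             // DEFAULT_MODEL_PATH value
        params.lora_adapters.push_back({"my-adapter.gguf", 0.5f});  // llama_lora_adapter_info: path + scale only
        params.lora_init_without_apply = true;                      // load adapters, defer applying them

        llama_init_result init = llama_init_from_gpt_params(params);
        if (init.model == nullptr || init.context == nullptr) {
            return 1;
        }

        // init.lora_adapters holds llama_lora_adapter_container entries:
        // the requested path/scale plus the loaded llama_lora_adapter pointer.
        llama_lora_adapters_apply(init.context, init.lora_adapters);

        llama_free(init.context);
        llama_free_model(init.model);
        return 0;
    }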
14 changes: 8 additions & 6 deletions examples/server/server.cpp
@@ -623,6 +623,7 @@ struct server_response {
struct server_context {
llama_model * model = nullptr;
llama_context * ctx = nullptr;
+ std::vector<llama_lora_adapter_container> lora_adapters;

gpt_params params;

@@ -682,6 +683,7 @@ struct server_context {

model = llama_init.model;
ctx = llama_init.context;
+ lora_adapters = llama_init.lora_adapters;
params.n_parallel -= 1; // but be sneaky about it
if (model == nullptr) {
LOG_ERROR("unable to load model", {{"model", params.model}});
@@ -1853,7 +1855,7 @@ struct server_context {
} break;
case SERVER_TASK_TYPE_SET_LORA:
{
- llama_lora_adapters_apply(ctx, params.lora_adapters);
+ llama_lora_adapters_apply(ctx, lora_adapters);
server_task_result result;
result.id = task.id;
result.data = json{{ "success", true }};
@@ -3340,8 +3342,8 @@ int main(int argc, char ** argv) {
const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json result = json::array();
- for (size_t i = 0; i < ctx_server.params.lora_adapters.size(); ++i) {
- auto & la = ctx_server.params.lora_adapters[i];
+ for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) {
+ auto & la = ctx_server.lora_adapters[i];
result.push_back({
{"id", i},
{"path", la.path},
@@ -3356,10 +3358,10 @@ int main(int argc, char ** argv) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));

const std::vector<json> body = json::parse(req.body);
- int max_idx = ctx_server.params.lora_adapters.size();
+ int max_idx = ctx_server.lora_adapters.size();

// clear existing value
- for (auto & la : ctx_server.params.lora_adapters) {
+ for (auto & la : ctx_server.lora_adapters) {
la.scale = 0.0f;
}

@@ -3368,7 +3370,7 @@ int main(int argc, char ** argv) {
int id = entry.at("id");
float scale = entry.at("scale");
if (0 <= id && id < max_idx) {
- ctx_server.params.lora_adapters[id].scale = scale;
+ ctx_server.lora_adapters[id].scale = scale;
} else {
throw std::runtime_error("invalid adapter id");
}
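For reference, below is a small sketch, not part of this commit, of the request-body shape the handle_lora_adapters_apply lambda above parses: a JSON array of {id, scale} entries, where each id indexes ctx_server.lora_adapters in load order. It is built with the same nlohmann::json type the server uses; the concrete ids and scales are illustrative only.

    // Sketch only: shows the body accepted by handle_lora_adapters_apply.
    #include <iostream>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    int main() {
        // The handler first resets every loaded adapter's scale to 0.0f, then
        // applies each {id, scale} entry; ids outside [0, lora_adapters.size())
        // raise "invalid adapter id".
        json body = json::array({
            {{"id", 0}, {"scale", 0.5}},   // enable adapter 0 at half strength
            {{"id", 1}, {"scale", 0.0}},   // leave adapter 1 disabled
        });
        std::cout << body.dump(2) << std::endl;
        return 0;
    }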
