diff --git a/common/common.cpp b/common/common.cpp
index cb4d700f12840..2e8374d50cafa 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -687,7 +687,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.lora_adapters.push_back({
             std::string(argv[i]),
             1.0,
-            nullptr,
         });
         return true;
     }
@@ -698,7 +697,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.lora_adapters.push_back({
             lora_adapter,
             std::stof(argv[i]),
-            nullptr,
         });
         return true;
     }
@@ -2106,16 +2104,20 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
 
     // load and optionally apply lora adapters
     for (auto & la : params.lora_adapters) {
-        la.adapter = llama_lora_adapter_init(model, la.path.c_str());
-        if (la.adapter == nullptr) {
+        llama_lora_adapter_container loaded_la;
+        loaded_la.path = la.path;
+        loaded_la.scale = la.scale;
+        loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str());
+        if (loaded_la.adapter == nullptr) {
             fprintf(stderr, "%s: error: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
             llama_free(lctx);
             llama_free_model(model);
             return iparams;
         }
+        iparams.lora_adapters.push_back(loaded_la); // copy to list of loaded adapters
     }
     if (!params.lora_init_without_apply) {
-        llama_lora_adapters_apply(lctx, params.lora_adapters);
+        llama_lora_adapters_apply(lctx, iparams.lora_adapters);
     }
 
     if (params.ignore_eos) {
diff --git a/common/common.h b/common/common.h
index a0aa03ebbcc7b..d88966ece20aa 100644
--- a/common/common.h
+++ b/common/common.h
@@ -33,9 +33,12 @@
 
 #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
 
-struct llama_lora_adapter_container {
+struct llama_lora_adapter_info {
     std::string path;
     float scale;
+};
+
+struct llama_lora_adapter_container : llama_lora_adapter_info {
     struct llama_lora_adapter * adapter;
 };
 
@@ -133,7 +136,7 @@ struct gpt_params {
     std::vector<llama_model_kv_override> kv_overrides;
 
     bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply)
-    std::vector<llama_lora_adapter_container> lora_adapters; // lora adapter path with user defined scale
+    std::vector<llama_lora_adapter_info> lora_adapters; // lora adapter path with user defined scale
 
     std::vector<llama_control_vector_load_info> control_vectors; // control vector with user defined scale
 
@@ -315,8 +318,9 @@ std::string fs_get_cache_file(const std::string & filename);
 //
 
 struct llama_init_result {
-    struct llama_model * model = nullptr;
+    struct llama_model   * model   = nullptr;
     struct llama_context * context = nullptr;
+    std::vector<llama_lora_adapter_container> lora_adapters;
 };
 
 struct llama_init_result llama_init_from_gpt_params(gpt_params & params);
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 0467844de4fe4..898c83ea3522b 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -623,6 +623,7 @@ struct server_response {
 struct server_context {
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
+    std::vector<llama_lora_adapter_container> lora_adapters;
 
     gpt_params params;
 
@@ -682,6 +683,7 @@ struct server_context {
 
         model = llama_init.model;
         ctx = llama_init.context;
+        lora_adapters = llama_init.lora_adapters;
         params.n_parallel -= 1; // but be sneaky about it
         if (model == nullptr) {
             LOG_ERROR("unable to load model", {{"model", params.model}});
@@ -1853,7 +1855,7 @@ struct server_context {
                 } break;
             case SERVER_TASK_TYPE_SET_LORA:
                 {
-                    llama_lora_adapters_apply(ctx, params.lora_adapters);
+                    llama_lora_adapters_apply(ctx, lora_adapters);
                     server_task_result result;
                     result.id = task.id;
                     result.data = json{{ "success", true }};
@@ -3340,8 +3342,8 @@ int main(int argc, char ** argv) {
     const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         json result = json::array();
-        for (size_t i = 0; i < ctx_server.params.lora_adapters.size(); ++i) {
-            auto & la = ctx_server.params.lora_adapters[i];
+        for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) {
+            auto & la = ctx_server.lora_adapters[i];
             result.push_back({
                 {"id", i},
                 {"path", la.path},
@@ -3356,10 +3358,10 @@ int main(int argc, char ** argv) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
 
         const std::vector<json> body = json::parse(req.body);
-        int max_idx = ctx_server.params.lora_adapters.size();
+        int max_idx = ctx_server.lora_adapters.size();
 
         // clear existing value
-        for (auto & la : ctx_server.params.lora_adapters) {
+        for (auto & la : ctx_server.lora_adapters) {
            la.scale = 0.0f;
         }
 
@@ -3368,7 +3370,7 @@ int main(int argc, char ** argv) {
            int id = entry.at("id");
            float scale = entry.at("scale");
            if (0 <= id && id < max_idx) {
-                ctx_server.params.lora_adapters[id].scale = scale;
+                ctx_server.lora_adapters[id].scale = scale;
            } else {
                throw std::runtime_error("invalid adapter id");
            }
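
For context only (not part of the patch): a minimal caller-side sketch of the new split between `llama_lora_adapter_info` (user request: path + scale) and `llama_lora_adapter_container` (loaded adapter handle), assuming the existing `llama_backend_init`/`llama_free`/`llama_free_model` API; the model and adapter paths are placeholders.

```cpp
// Sketch of using llama_init_result::lora_adapters; paths are placeholders.
#include "common.h"

int main() {
    gpt_params params;
    params.model = "model.gguf";                          // placeholder model path
    params.lora_adapters.push_back({"lora.gguf", 1.0f});  // llama_lora_adapter_info {path, scale}
    params.lora_init_without_apply = true;                // load adapters, do not apply yet

    llama_backend_init();
    llama_init_result init = llama_init_from_gpt_params(params);
    if (init.model == nullptr || init.context == nullptr) {
        return 1;
    }

    // Later (this mirrors what the server does for SERVER_TASK_TYPE_SET_LORA):
    // adjust scales on the loaded adapters, then apply them to the context.
    init.lora_adapters[0].scale = 0.5f;
    llama_lora_adapters_apply(init.context, init.lora_adapters);

    llama_free(init.context);
    llama_free_model(init.model);
    llama_backend_free();
    return 0;
}
```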