Skip to content

Commit

Permalink
common : Changed tuple to struct (TODO fix) (ggerganov#8823)
Browse files Browse the repository at this point in the history
* common : Changed tuple to struct (TODO fix)

Use struct `llama_init_result` to replace the previous
std::tuple<struct llama_model *, struct llama_context *>

* delete llama_init_default_params()

* delete the extra whitespace
  • Loading branch information
Septa2112 authored and arthw committed Aug 7, 2024
1 parent cbc870a commit bc42ec9
Show file tree
Hide file tree
Showing 18 changed files with 82 additions and 59 deletions.
18 changes: 10 additions & 8 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2039,8 +2039,8 @@ std::string fs_get_cache_file(const std::string & filename) {
//
// Model utils
//

std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
llama_init_result iparams;
auto mparams = llama_model_params_from_gpt_params(params);

llama_model * model = nullptr;
Expand All @@ -2055,7 +2055,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par

if (model == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
return std::make_tuple(nullptr, nullptr);
return iparams;
}

auto cparams = llama_context_params_from_gpt_params(params);
Expand All @@ -2064,7 +2064,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
if (lctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
llama_free_model(model);
return std::make_tuple(nullptr, nullptr);
return iparams;
}

if (!params.control_vectors.empty()) {
Expand All @@ -2075,7 +2075,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
if (cvec.n_embd == -1) {
llama_free(lctx);
llama_free_model(model);
return std::make_tuple(nullptr, nullptr);
return iparams;
}

int err = llama_control_vector_apply(lctx,
Expand All @@ -2087,7 +2087,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
if (err) {
llama_free(lctx);
llama_free_model(model);
return std::make_tuple(nullptr, nullptr);
return iparams;
}
}

Expand All @@ -2099,7 +2099,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
llama_free(lctx);
llama_free_model(model);
return std::make_tuple(nullptr, nullptr);
return iparams;
}
llama_lora_adapter_set(lctx, adapter, lora_scale);
}
Expand Down Expand Up @@ -2135,7 +2135,9 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
llama_reset_timings(lctx);
}

return std::make_tuple(model, lctx);
iparams.model = model;
iparams.context = lctx;
return iparams;
}

struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) {
Expand Down
8 changes: 6 additions & 2 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,12 @@ std::string fs_get_cache_file(const std::string & filename);
// Model utils
//

// TODO: avoid tuplue, use struct
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params);
struct llama_init_result {
struct llama_model * model = nullptr;
struct llama_context * context = nullptr;
};

struct llama_init_result llama_init_from_gpt_params(gpt_params & params);

struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params);
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
Expand Down
7 changes: 4 additions & 3 deletions examples/cvector-generator/cvector-generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,9 +414,10 @@ int main(int argc, char ** argv) {
llama_numa_init(params.numa);

// load the model to get hparams
llama_model * model;
llama_context * ctx;
std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_init_result llama_init = llama_init_from_gpt_params(params);

llama_model * model = llama_init.model;
llama_context * ctx = llama_init.context;

// int n_ctx = llama_n_ctx(ctx);
int n_layers = llama_n_layer(model);
Expand Down
8 changes: 4 additions & 4 deletions examples/embedding/embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@ int main(int argc, char ** argv) {
llama_backend_init();
llama_numa_init(params.numa);

llama_model * model;
llama_context * ctx;

// load the model
std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_init_result llama_init = llama_init_from_gpt_params(params);

llama_model * model = llama_init.model;
llama_context * ctx = llama_init.context;
if (model == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return 1;
Expand Down
7 changes: 4 additions & 3 deletions examples/eval-callback/eval-callback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,10 @@ int main(int argc, char ** argv) {
params.warmup = false;

// init
llama_model * model;
llama_context * ctx;
std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_init_result llama_init = llama_init_from_gpt_params(params);

llama_model * model = llama_init.model;
llama_context * ctx = llama_init.context;
if (model == nullptr || ctx == nullptr) {
fprintf(stderr, "%s : failed to init\n", __func__);
return 1;
Expand Down
6 changes: 3 additions & 3 deletions examples/imatrix/imatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,10 +611,10 @@ int main(int argc, char ** argv) {
params.warmup = false;

// init
llama_model * model;
llama_context * ctx;
llama_init_result llama_init = llama_init_from_gpt_params(params);

std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_model * model = llama_init.model;
llama_context * ctx = llama_init.context;
if (model == nullptr || ctx == nullptr) {
fprintf(stderr, "%s : failed to init\n", __func__);
return 1;
Expand Down
5 changes: 4 additions & 1 deletion examples/infill/infill.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,10 @@ int main(int argc, char ** argv) {

// load the model and apply lora adapter, if any
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_init_result llama_init = llama_init_from_gpt_params(params);

model = llama_init.model;
ctx = llama_init.context;

if (model == NULL) {
LOG_TEE("%s: error: unable to load model\n", __func__);
Expand Down
8 changes: 4 additions & 4 deletions examples/lookahead/lookahead.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ int main(int argc, char ** argv) {
llama_backend_init();
llama_numa_init(params.numa);

llama_model * model = NULL;
llama_context * ctx = NULL;

// load the target model
std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_init_result llama_init = llama_init_from_gpt_params(params);

llama_model * model = llama_init.model;
llama_context * ctx = llama_init.context;

// Tokenize the prompt
std::vector<llama_token> inp;
Expand Down
8 changes: 4 additions & 4 deletions examples/lookup/lookup-create.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ int main(int argc, char ** argv){
llama_backend_init();
llama_numa_init(params.numa);

llama_model * model = NULL;
llama_context * ctx = NULL;

// load the model
std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_init_result llama_init = llama_init_from_gpt_params(params);

llama_model * model = llama_init.model;
llama_context * ctx = llama_init.context;
GGML_ASSERT(model != nullptr);

// tokenize the prompt
Expand Down
8 changes: 4 additions & 4 deletions examples/lookup/lookup-stats.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ int main(int argc, char ** argv){
llama_backend_init();
llama_numa_init(params.numa);

llama_model * model = NULL;
llama_context * ctx = NULL;

// load the model
std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_init_result llama_init = llama_init_from_gpt_params(params);

llama_model * model = llama_init.model;
llama_context * ctx = llama_init.context;

// tokenize the prompt
std::vector<llama_token> inp;
Expand Down
8 changes: 4 additions & 4 deletions examples/lookup/lookup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ int main(int argc, char ** argv){
llama_backend_init();
llama_numa_init(params.numa);

llama_model * model = NULL;
llama_context * ctx = NULL;

// load the model
std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_init_result llama_init = llama_init_from_gpt_params(params);

llama_model * model = llama_init.model;
llama_context * ctx = llama_init.context;

// tokenize the prompt
std::vector<llama_token> inp;
Expand Down
5 changes: 4 additions & 1 deletion examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,10 @@ int main(int argc, char ** argv) {

// load the model and apply lora adapter, if any
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_init_result llama_init = llama_init_from_gpt_params(params);

model = llama_init.model;
ctx = llama_init.context;
if (sparams.cfg_scale > 1.f) {
struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
ctx_guidance = llama_new_context_with_model(model, lparams);
Expand Down
8 changes: 4 additions & 4 deletions examples/parallel/parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,11 @@ int main(int argc, char ** argv) {
llama_backend_init();
llama_numa_init(params.numa);

llama_model * model = NULL;
llama_context * ctx = NULL;

// load the target model
std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_init_result llama_init = llama_init_from_gpt_params(params);

llama_model * model = llama_init.model;
llama_context * ctx = llama_init.context;

// load the prompts from an external file if there are any
if (params.prompt.empty()) {
Expand Down
8 changes: 4 additions & 4 deletions examples/perplexity/perplexity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2018,11 +2018,11 @@ int main(int argc, char ** argv) {
llama_backend_init();
llama_numa_init(params.numa);

llama_model * model;
llama_context * ctx;

// load the model and apply lora adapter, if any
std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_init_result llama_init = llama_init_from_gpt_params(params);

llama_model * model = llama_init.model;
llama_context * ctx = llama_init.context;
if (model == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return 1;
Expand Down
9 changes: 5 additions & 4 deletions examples/retrieval/retrieval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,12 @@ int main(int argc, char ** argv) {
llama_backend_init();
llama_numa_init(params.numa);

llama_model * model;
llama_context * ctx;

// load the model
std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_init_result llama_init = llama_init_from_gpt_params(params);

llama_model * model = llama_init.model;
llama_context * ctx = llama_init.context;

if (model == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return 1;
Expand Down
7 changes: 4 additions & 3 deletions examples/save-load-state/save-load-state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ int main(int argc, char ** argv) {
std::string result2;

// init
llama_model * model;
llama_context * ctx;
llama_init_result llama_init = llama_init_from_gpt_params(params);

llama_model * model = llama_init.model;
llama_context * ctx = llama_init.context;

std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (model == nullptr || ctx == nullptr) {
fprintf(stderr, "%s : failed to init\n", __func__);
return 1;
Expand Down
5 changes: 4 additions & 1 deletion examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,10 @@ struct server_context {
// dedicate one sequence to the system prompt
params.n_parallel += 1;

std::tie(model, ctx) = llama_init_from_gpt_params(params);
llama_init_result llama_init = llama_init_from_gpt_params(params);

model = llama_init.model;
ctx = llama_init.context;
params.n_parallel -= 1; // but be sneaky about it
if (model == nullptr) {
LOG_ERROR("unable to load model", {{"model", params.model}});
Expand Down
8 changes: 6 additions & 2 deletions examples/speculative/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ int main(int argc, char ** argv) {
llama_context * ctx_dft = NULL;

// load the target model
std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);
llama_init_result llama_init_tgt = llama_init_from_gpt_params(params);
model_tgt = llama_init_tgt.model;
ctx_tgt = llama_init_tgt.context;

// load the draft model
params.model = params.model_draft;
Expand All @@ -75,7 +77,9 @@ int main(int argc, char ** argv) {
params.n_threads = params.n_threads_draft;
}
params.n_threads_batch = params.n_threads_batch_draft;
std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
llama_init_result llama_init_dft = llama_init_from_gpt_params(params);
model_dft = llama_init_dft.model;
ctx_dft = llama_init_dft.context;

const bool vocab_type_tgt = llama_vocab_type(model_tgt);
LOG("vocab_type tgt: %d\n", vocab_type_tgt);
Expand Down

0 comments on commit bc42ec9

Please sign in to comment.