diff --git a/common/sampling.cpp b/common/sampling.cpp index 739f629339653..4ceb7918aa551 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -31,10 +31,8 @@ std::string gpt_sampling_params::print_samplers() const { return result; } -struct llama_sampling_context * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params) { - struct llama_sampling_context * result = new llama_sampling_context(); - - result->params = params; +struct llama_sampling * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params) { + struct llama_sampling * result = nullptr; { auto lparams = llama_sampling_default_params(); @@ -66,35 +64,25 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_model * m lparams.samplers[i] = params.samplers[i]; } - result->smpl = llama_sampling_init(model, lparams); + result = llama_sampling_init(model, lparams); - llama_sampling_set_grammar (result->smpl, params.grammar.c_str(), "root"); - llama_sampling_set_logit_bias(result->smpl, params.logit_bias.size(), params.logit_bias.data()); + llama_sampling_set_grammar (result, params.grammar.c_str(), "root"); + llama_sampling_set_logit_bias(result, params.logit_bias.size(), params.logit_bias.data()); } return result; } -void llama_sampling_free(struct llama_sampling_context * ctx) { - llama_sampling_free(ctx->smpl); - - delete ctx; -} - -void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) { - if (dst->smpl) { - llama_sampling_free(dst->smpl); +void llama_sampling_cp(llama_sampling * src, llama_sampling * dst) { + if (dst) { + llama_sampling_free(dst); } - dst->smpl = llama_sampling_cp(src->smpl); -} - -llama_token llama_sampling_last(llama_sampling_context * ctx) { - return llama_sampling_prev(ctx->smpl, 0); + dst = llama_sampling_cp(src); } -std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) { - n = std::min(n, llama_sampling_n_prev(ctx_sampling->smpl)); +std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx_main, int n) { + n = std::min(n, llama_sampling_n_prev(smpl)); if (n <= 0) { return ""; @@ -104,7 +92,7 @@ std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab for (int i = n - 1; i >= 0; i--) { - const llama_token id = llama_sampling_prev(ctx_sampling->smpl, i); + const llama_token id = llama_sampling_prev(smpl, i); GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen"); @@ -206,14 +194,14 @@ std::vector llama_sampling_types_from_chars(const std::strin } llama_token llama_sampling_sample( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, + struct llama_sampling * smpl, + struct llama_context * ctx, int idx) { - llama_sampling_set_logits(ctx_sampling->smpl, llama_get_logits_ith(ctx_main, idx)); + llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx)); - auto * cur_p = llama_sampling_get_candidates(ctx_sampling->smpl); + auto * cur_p = llama_sampling_get_candidates(smpl); - llama_sampling_grammar(ctx_sampling->smpl, cur_p); + llama_sampling_grammar(smpl, cur_p); - return llama_sampling_sample(ctx_sampling->smpl, cur_p); + return llama_sampling_sample(smpl, cur_p); } diff --git a/common/sampling.h b/common/sampling.h index 90141897784aa..8a2e595ba2f97 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -7,7 +7,7 @@ // sampling parameters typedef struct gpt_sampling_params { - uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context + uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling int32_t n_prev = 64; // number of previous tokens to remember int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. @@ -30,7 +30,7 @@ typedef struct gpt_sampling_params { bool penalize_nl = false; // consider newlines as a repeatable token bool ignore_eos = false; - std::vector samplers = { + std::vector samplers = { LLAMA_SAMPLER_TYPE_TOP_K, LLAMA_SAMPLER_TYPE_TFS_Z, LLAMA_SAMPLER_TYPE_TYPICAL_P, @@ -50,36 +50,21 @@ typedef struct gpt_sampling_params { std::string print_samplers() const; } gpt_sampling_params; -// general sampler context -// TODO: move to llama.h -struct llama_sampling_context { - // parameters that will be used for sampling - gpt_sampling_params params; +// overload of llama_sampling_init using gpt_sampling_params +struct llama_sampling * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params); - llama_sampling * smpl; -}; +void llama_sampling_cp(llama_sampling * src, llama_sampling * dst); -// Create a new sampling context instance. -struct llama_sampling_context * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params); +// get a string representation of the last accepted tokens +std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx, int n); -void llama_sampling_free(struct llama_sampling_context * ctx); +char llama_sampling_type_to_chr(enum llama_sampler_type sampler_type); +std::string llama_sampling_type_to_str(enum llama_sampler_type sampler_type); -// Copy the sampler context -void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst); - -// Get the last accepted token -llama_token llama_sampling_last(llama_sampling_context * ctx); - -// Get a string representation of the last accepted tokens -std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n); - -char llama_sampling_type_to_chr(llama_sampler_type sampler_type); -std::string llama_sampling_type_to_str(llama_sampler_type sampler_type); - -std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names); -std::vector llama_sampling_types_from_chars(const std::string & names_string); +std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names); +std::vector llama_sampling_types_from_chars(const std::string & names_string); llama_token llama_sampling_sample( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, + struct llama_sampling * smpl, + struct llama_context * ctx, int idx = -1); diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 80dcbb5fb6b8f..2f661f00ebd69 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -33,7 +33,7 @@ static llama_context ** g_ctx; static llama_model ** g_model; -static llama_sampling_context ** g_ctx_sampling; +static llama_sampling ** g_smpl; static gpt_params * g_params; static std::vector * g_input_tokens; static std::ostringstream * g_output_ss; @@ -93,7 +93,7 @@ static void sigint_handler(int signo) { } else { console::cleanup(); printf("\n"); - llama_print_timings(*g_ctx, (*g_ctx_sampling)->smpl); + llama_print_timings(*g_ctx, *g_smpl); write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); _exit(130); } @@ -167,11 +167,11 @@ int main(int argc, char ** argv) { llama_model * model = nullptr; llama_context * ctx = nullptr; - llama_sampling_context * ctx_sampling = nullptr; + llama_sampling * smpl = nullptr; g_model = &model; g_ctx = &ctx; - g_ctx_sampling = &ctx_sampling; + g_smpl = &smpl; // load the model and apply lora adapter, if any LOG("%s: load the model and apply lora adapter, if any\n", __func__); @@ -345,7 +345,7 @@ int main(int argc, char ** argv) { std::vector embd; - ctx_sampling = llama_sampling_init(model, sparams); + smpl = llama_sampling_init(model, sparams); while (n_remain != 0 || params.interactive) { // predict @@ -417,11 +417,11 @@ int main(int argc, char ** argv) { embd.clear(); if ((int) embd_inp.size() <= n_consumed && !is_interacting) { - const llama_token id = llama_sampling_sample(ctx_sampling, ctx); + const llama_token id = llama_sampling_sample(smpl, ctx); - llama_sampling_accept(ctx_sampling->smpl, id, true); + llama_sampling_accept(smpl, id, true); - // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str()); + // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str()); embd.push_back(id); @@ -440,7 +440,7 @@ int main(int argc, char ** argv) { // push the prompt in the sampling context in order to apply repetition penalties later // for the prompt, we don't apply grammar rules - llama_sampling_accept(ctx_sampling->smpl, embd_inp[n_consumed], false); + llama_sampling_accept(smpl, embd_inp[n_consumed], false); ++n_consumed; if ((int) embd.size() >= params.n_batch) { @@ -472,7 +472,7 @@ int main(int argc, char ** argv) { // if not currently processing queued inputs; if ((int) embd_inp.size() <= n_consumed) { // deal with eot token in infill mode - if ((llama_sampling_last(ctx_sampling) == llama_token_eot(model) || is_interacting) && params.interactive){ + if ((llama_sampling_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){ if (is_interacting && !params.interactive_first) { // print an eot token printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str()); @@ -538,7 +538,7 @@ int main(int argc, char ** argv) { is_interacting = false; } // deal with end of generation tokens in interactive mode - else if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) { + else if (llama_token_is_eog(model, llama_sampling_last(smpl))) { LOG("found EOS token\n"); if (params.interactive) { @@ -611,7 +611,7 @@ int main(int argc, char ** argv) { if (n_past > 0) { if (is_interacting) { - llama_sampling_reset(ctx_sampling->smpl); + llama_sampling_reset(smpl); } is_interacting = false; } @@ -634,13 +634,13 @@ int main(int argc, char ** argv) { fflush(stdout); } - llama_print_timings(ctx, ctx_sampling->smpl); + llama_print_timings(ctx, smpl); write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); llama_free(ctx); llama_free_model(model); - llama_sampling_free(ctx_sampling); + llama_sampling_free(smpl); llama_backend_free(); #ifndef LOG_DISABLE_LOGS diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 51d9df071920c..774b7b779d4cd 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -40,11 +40,11 @@ static bool eval_string(struct llama_context * ctx_llama, const char* str, int n return true; } -static const char * sample(struct llama_sampling_context * ctx_sampling, +static const char * sample(struct llama_sampling * smpl, struct llama_context * ctx_llama, int * n_past) { - const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama); - llama_sampling_accept(ctx_sampling->smpl, id, true); + const llama_token id = llama_sampling_sample(smpl, ctx_llama); + llama_sampling_accept(smpl, id, true); static std::string ret; if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { ret = ""; @@ -191,15 +191,15 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ LOG_TEE("\n"); - struct llama_sampling_context * ctx_sampling = llama_sampling_init(ctx_llava->model, params->sparams); - if (!ctx_sampling) { + struct llama_sampling * smpl = llama_sampling_init(ctx_llava->model, params->sparams); + if (!smpl) { fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); exit(1); } std::string response = ""; for (int i = 0; i < max_tgt_len; i++) { - const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past); + const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past); response += tmp; if (strcmp(tmp, "") == 0) break; if (strstr(tmp, "###")) break; // Yi-VL behavior @@ -211,7 +211,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ fflush(stdout); } - llama_sampling_free(ctx_sampling); + llama_sampling_free(smpl); printf("\n"); } diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index 59107ea0d936c..c056cadc63ae2 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -163,11 +163,11 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e LOG_TEE("%s: image token past: %d\n", __func__, n_past); } -static const char * sample(struct llama_sampling_context * ctx_sampling, +static const char * sample(struct llama_sampling * smpl, struct llama_context * ctx_llama, int * n_past) { - const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama); - llama_sampling_accept(ctx_sampling->smpl, id, true); + const llama_token id = llama_sampling_sample(smpl, ctx_llama); + llama_sampling_accept(smpl, id, true); static std::string ret; if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { ret = ""; @@ -214,7 +214,7 @@ static struct llava_context * minicpmv_init(gpt_params * params, const std::stri return ctx_llava; } -static struct llama_sampling_context * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){ +static struct llama_sampling * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){ std::string user_prompt = prompt; int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip); if (!is_first) { @@ -238,13 +238,13 @@ static struct llama_sampling_context * llama_init(struct llava_context * ctx_lla LOG_TEE("\n"); - struct llama_sampling_context * ctx_sampling = llama_sampling_init(ctx_llava->model, params->sparams); - return ctx_sampling; + struct llama_sampling * smpl = llama_sampling_init(ctx_llava->model, params->sparams); + return smpl; } -static const char * llama_loop(struct llava_context * ctx_llava,struct llama_sampling_context * ctx_sampling, int &n_past){ +static const char * llama_loop(struct llava_context * ctx_llava,struct llama_sampling * smpl, int &n_past){ - const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past); + const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past); return tmp; } @@ -278,12 +278,12 @@ int main(int argc, char ** argv) { if (!params.prompt.empty()) { LOG_TEE("%s\n", params.prompt.c_str()); LOG_TEE(""); - auto ctx_sampling = llama_init(ctx_llava, ¶ms, params.prompt.c_str(), n_past, true); + auto smpl = llama_init(ctx_llava, ¶ms, params.prompt.c_str(), n_past, true); const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict; std::string response = ""; bool have_tmp = false; for (int i = 0; i < max_tgt_len; i++) { - auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past); + auto tmp = llama_loop(ctx_llava, smpl, n_past); response += tmp; if (strcmp(tmp, "") == 0){ if(!have_tmp)continue; @@ -296,18 +296,18 @@ int main(int argc, char ** argv) { fflush(stdout); } - llama_sampling_free(ctx_sampling); + llama_sampling_free(smpl); }else { while (true) { LOG_TEE(""); std::string prompt; std::getline(std::cin, prompt); LOG_TEE(""); - auto ctx_sampling = llama_init(ctx_llava, ¶ms, prompt, n_past, true); + auto smpl = llama_init(ctx_llava, ¶ms, prompt, n_past, true); const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict; std::string response = ""; for (int i = 0; i < max_tgt_len; i++) { - auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past); + auto tmp = llama_loop(ctx_llava, smpl, n_past); response += tmp; if (strcmp(tmp, "") == 0) break; if (strstr(tmp, "###")) break; // Yi-VL behavior @@ -315,7 +315,7 @@ int main(int argc, char ** argv) { if (strstr(response.c_str(), "")) break; // minicpm-v fflush(stdout); } - llama_sampling_free(ctx_sampling); + llama_sampling_free(smpl); } } printf("\n"); diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 2ddceb1e3b4f3..2bd31d00268a2 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -117,7 +117,7 @@ int main(int argc, char ** argv) { llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1); // target model sampling context - struct llama_sampling_context * ctx_sampling = llama_sampling_init(model, params.sparams); + struct llama_sampling * smpl = llama_sampling_init(model, params.sparams); // verification n-grams std::vector ngrams_cur(G); @@ -158,9 +158,9 @@ int main(int argc, char ** argv) { // sample first token { - id = llama_sampling_sample(ctx_sampling, ctx, 0); + id = llama_sampling_sample(smpl, ctx, 0); - llama_sampling_accept(ctx_sampling->smpl, id, true); + llama_sampling_accept(smpl, id, true); { const std::string token_str = llama_token_to_piece(ctx, id); @@ -283,9 +283,9 @@ int main(int argc, char ** argv) { } // sample the next token - id = llama_sampling_sample(ctx_sampling, ctx, i_batch); + id = llama_sampling_sample(smpl, ctx, i_batch); - llama_sampling_accept(ctx_sampling->smpl, id, true); + llama_sampling_accept(smpl, id, true); // print { @@ -360,7 +360,7 @@ int main(int argc, char ** argv) { if (v == 0) { // sample from the last level for (int i = 0; i < W; i++) { - tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i); + tokens_j[N - 2][i] = llama_sampling_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i); } } else { for (int i = 0; i < W; i++) { @@ -467,10 +467,10 @@ int main(int argc, char ** argv) { LOG_TEE("n_predict = %d\n", n_predict); LOG_TEE("n_accept = %d\n", n_accept); - llama_print_timings(ctx, ctx_sampling->smpl); + llama_print_timings(ctx, smpl); llama_kv_cache_view_free(&kvc_view); - llama_sampling_free(ctx_sampling); + llama_sampling_free(smpl); llama_batch_free(batch); diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 91a750ad0bf6d..da4d57a518754 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -104,7 +104,7 @@ int main(int argc, char ** argv){ bool has_eos = false; - struct llama_sampling_context * ctx_sampling = llama_sampling_init(model, params.sparams); + struct llama_sampling * smpl = llama_sampling_init(model, params.sparams); std::vector draft; @@ -128,9 +128,9 @@ int main(int argc, char ** argv){ int i_dft = 0; while (true) { // sample from the target model - llama_token id = llama_sampling_sample(ctx_sampling, ctx, i_dft); + llama_token id = llama_sampling_sample(smpl, ctx, i_dft); - llama_sampling_accept(ctx_sampling->smpl, id, true); + llama_sampling_accept(smpl, id, true); const std::string token_str = llama_token_to_piece(ctx, id); @@ -239,9 +239,9 @@ int main(int argc, char ** argv){ LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); LOG_TEE("\ntarget:\n"); - llama_print_timings(ctx, ctx_sampling->smpl); + llama_print_timings(ctx, smpl); - llama_sampling_free(ctx_sampling); + llama_sampling_free(smpl); llama_batch_free(batch_tgt); llama_free(ctx); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 4e7456d2cd7ad..834d2279834b7 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -33,7 +33,7 @@ static llama_context ** g_ctx; static llama_model ** g_model; -static llama_sampling_context ** g_ctx_sampling; +static llama_sampling ** g_smpl; static gpt_params * g_params; static std::vector * g_input_tokens; static std::ostringstream * g_output_ss; @@ -106,7 +106,7 @@ static void sigint_handler(int signo) { } else { console::cleanup(); printf("\n"); - llama_print_timings(*g_ctx, (*g_ctx_sampling)->smpl); + llama_print_timings(*g_ctx, *g_smpl); write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); _exit(130); } @@ -193,13 +193,13 @@ int main(int argc, char ** argv) { llama_model * model = nullptr; llama_context * ctx = nullptr; - llama_sampling_context * ctx_sampling = nullptr; + llama_sampling * smpl = nullptr; std::vector chat_msgs; g_model = &model; g_ctx = &ctx; - g_ctx_sampling = &ctx_sampling; + g_smpl = &smpl; // load the model and apply lora adapter, if any LOG("%s: load the model and apply lora adapter, if any\n", __func__); @@ -494,8 +494,8 @@ int main(int argc, char ** argv) { antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true)); } - ctx_sampling = llama_sampling_init(model, sparams); - if (!ctx_sampling) { + smpl = llama_sampling_init(model, sparams); + if (!smpl) { fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); exit(1); } @@ -650,11 +650,11 @@ int main(int argc, char ** argv) { LOG("saved session to %s\n", path_session.c_str()); } - const llama_token id = llama_sampling_sample(ctx_sampling, ctx); + const llama_token id = llama_sampling_sample(smpl, ctx); - llama_sampling_accept(ctx_sampling->smpl, id, /* apply_grammar= */ true); + llama_sampling_accept(smpl, id, /* apply_grammar= */ true); - // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str()); + // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str()); embd.push_back(id); @@ -673,7 +673,7 @@ int main(int argc, char ** argv) { // push the prompt in the sampling context in order to apply repetition penalties later // for the prompt, we don't apply grammar rules - llama_sampling_accept(ctx_sampling->smpl, embd_inp[n_consumed], /* apply_grammar= */ false); + llama_sampling_accept(smpl, embd_inp[n_consumed], /* apply_grammar= */ false); ++n_consumed; if ((int) embd.size() >= params.n_batch) { @@ -716,7 +716,7 @@ int main(int argc, char ** argv) { // check for reverse prompt in the last n_prev tokens if (!params.antiprompt.empty()) { const int n_prev = 32; - const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev); + const std::string last_output = llama_sampling_prev_str(smpl, ctx, n_prev); is_antiprompt = false; // Check if each of the reverse prompts appears at the end of the output. @@ -738,7 +738,7 @@ int main(int argc, char ** argv) { } // check for reverse prompt using special tokens - llama_token last_token = llama_sampling_last(ctx_sampling); + llama_token last_token = llama_sampling_last(smpl); for (std::vector ids : antiprompt_ids) { if (ids.size() == 1 && last_token == ids[0]) { if (params.interactive) { @@ -755,7 +755,7 @@ int main(int argc, char ** argv) { } // deal with end of generation tokens in interactive mode - if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) { + if (llama_token_is_eog(model, llama_sampling_last(smpl))) { LOG("found an EOG token\n"); if (params.interactive) { @@ -776,7 +776,7 @@ int main(int argc, char ** argv) { // if current token is not EOG, we add it to current assistant message if (params.conversation) { - auto id = llama_sampling_last(ctx_sampling); + auto id = llama_sampling_last(smpl); assistant_ss << llama_token_to_piece(ctx, id, false); } @@ -872,7 +872,7 @@ int main(int argc, char ** argv) { if (n_past > 0) { if (is_interacting) { - llama_sampling_reset(ctx_sampling->smpl); + llama_sampling_reset(smpl); } is_interacting = false; } @@ -897,13 +897,13 @@ int main(int argc, char ** argv) { llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); } - llama_print_timings(ctx, ctx_sampling->smpl); + llama_print_timings(ctx, smpl); write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); llama_free(ctx); llama_free_model(model); - llama_sampling_free(ctx_sampling); + llama_sampling_free(smpl); llama_backend_free(); #ifndef LOG_DISABLE_LOGS diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index e7bf819a04b78..7ce982a92e497 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -50,8 +50,8 @@ static std::vector k_prompts = { struct client { ~client() { - if (ctx_sampling) { - llama_sampling_free(ctx_sampling); + if (smpl) { + llama_sampling_free(smpl); } } @@ -72,7 +72,7 @@ struct client { std::string prompt; std::string response; - struct llama_sampling_context * ctx_sampling = nullptr; + struct llama_sampling * smpl = nullptr; }; static void print_date_time() { @@ -161,7 +161,7 @@ int main(int argc, char ** argv) { for (size_t i = 0; i < clients.size(); ++i) { auto & client = clients[i]; client.id = i; - client.ctx_sampling = llama_sampling_init(model, params.sparams); + client.smpl = llama_sampling_init(model, params.sparams); } std::vector tokens_system; @@ -253,7 +253,7 @@ int main(int argc, char ** argv) { client.prompt = client.input + "\nAssistant:"; client.response = ""; - llama_sampling_reset(client.ctx_sampling->smpl); + llama_sampling_reset(client.smpl); // do not prepend BOS because we have a system prompt! std::vector tokens_prompt; @@ -341,9 +341,9 @@ int main(int argc, char ** argv) { //printf("client %d, seq %d, token %d, pos %d, batch %d\n", // client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch); - const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, client.i_batch - i); + const llama_token id = llama_sampling_sample(client.smpl, ctx, client.i_batch - i); - llama_sampling_accept(client.ctx_sampling->smpl, id, true); + llama_sampling_accept(client.smpl, id, true); if (client.n_decoded == 1) { // start measuring generation time after the first token to make sure all concurrent clients diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 3e749c8846f60..edcd5cbe4ea3b 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -176,7 +176,7 @@ struct server_slot { struct gpt_sampling_params sparams; llama_token sampled; - llama_sampling_context * ctx_sampling = nullptr; + llama_sampling * smpl = nullptr; int32_t ga_i = 0; // group-attention state int32_t ga_n = 1; // group-attention factor @@ -668,8 +668,8 @@ struct server_context { // Clear any sampling context for (server_slot & slot : slots) { - if (slot.ctx_sampling != nullptr) { - llama_sampling_free(slot.ctx_sampling); + if (slot.smpl != nullptr) { + llama_sampling_free(slot.smpl); } } @@ -1054,12 +1054,12 @@ struct server_context { } { - if (slot.ctx_sampling != nullptr) { - llama_sampling_free(slot.ctx_sampling); + if (slot.smpl != nullptr) { + llama_sampling_free(slot.smpl); } - slot.ctx_sampling = llama_sampling_init(model, slot.sparams); - if (slot.ctx_sampling == nullptr) { + slot.smpl = llama_sampling_init(model, slot.sparams); + if (slot.smpl == nullptr) { // for now, the only error that may happen here is invalid grammar send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); return false; @@ -2098,7 +2098,7 @@ struct server_context { GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); } - llama_sampling_reset(slot.ctx_sampling->smpl); + llama_sampling_reset(slot.smpl); if (!slot.params.cache_prompt) { slot.n_past_se = 0; @@ -2111,7 +2111,7 @@ struct server_context { // push the prompt into the sampling context (do not apply grammar) for (int i = 0; i < slot.n_past; ++i) { - llama_sampling_accept(slot.ctx_sampling->smpl, slot.cache_tokens[i], false); + llama_sampling_accept(slot.smpl, slot.cache_tokens[i], false); } } } @@ -2164,7 +2164,7 @@ struct server_context { slot.n_past_se = 0; slot.ga_i = 0; // TODO: is the system prompt ever in the sampling context? - llama_sampling_reset(slot.ctx_sampling->smpl); + llama_sampling_reset(slot.smpl); } // remove the non-common part from the cache @@ -2341,9 +2341,9 @@ struct server_context { } completion_token_output result; - const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, slot.i_batch - i); + const llama_token id = llama_sampling_sample(slot.smpl, ctx, slot.i_batch - i); - llama_sampling_accept(slot.ctx_sampling->smpl, id, true); + llama_sampling_accept(slot.smpl, id, true); slot.n_decoded += 1; if (slot.n_decoded == 1) { @@ -2354,7 +2354,7 @@ struct server_context { result.tok = id; - const auto * cur_p = llama_sampling_get_candidates(slot.ctx_sampling->smpl); + const auto * cur_p = llama_sampling_get_candidates(slot.smpl); for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) { result.probs.push_back({ diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 7e4bcfb9b703c..e78927936ab52 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -21,7 +21,7 @@ struct seq_draft { std::vector tokens; std::vector> dists; - struct llama_sampling_context * ctx_sampling; + struct llama_sampling * smpl; }; int main(int argc, char ** argv) { @@ -176,7 +176,7 @@ int main(int argc, char ** argv) { bool has_eos = false; // target model sampling context (reuse the llama_context's sampling instance) - struct llama_sampling_context * ctx_sampling = llama_sampling_init(model_tgt, params.sparams); + struct llama_sampling * smpl = llama_sampling_init(model_tgt, params.sparams); // draft sequence data std::vector drafts(n_seq_dft); @@ -187,7 +187,7 @@ int main(int argc, char ** argv) { for (int s = 0; s < n_seq_dft; ++s) { // allocate llama_sampling for each draft sequence - drafts[s].ctx_sampling = llama_sampling_init(model_dft, params.sparams); + drafts[s].smpl = llama_sampling_init(model_dft, params.sparams); } llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1); @@ -230,12 +230,12 @@ int main(int argc, char ** argv) { if (params.sparams.temp > 0) { // stochastic verification - llama_sampling_set_logits(ctx_sampling->smpl, llama_get_logits_ith(ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft])); + llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft])); - auto & dist_tgt = *llama_sampling_get_candidates(ctx_sampling->smpl); + auto & dist_tgt = *llama_sampling_get_candidates(smpl); - llama_sampling_grammar(ctx_sampling->smpl, &dist_tgt); - llama_sampling_softmax(ctx_sampling->smpl, &dist_tgt); + llama_sampling_grammar(smpl, &dist_tgt); + llama_sampling_softmax(smpl, &dist_tgt); float p_tgt = 0.0f; float p_dft = 0.0f; @@ -280,7 +280,7 @@ int main(int argc, char ** argv) { accept = true; token_id = drafts[s].tokens[i_dft]; token_str = llama_token_to_piece(ctx_tgt, token_id); - llama_sampling_accept(ctx_sampling->smpl, token_id, true); + llama_sampling_accept(smpl, token_id, true); LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str()); break; @@ -334,8 +334,8 @@ int main(int argc, char ** argv) { // all drafted tokens were rejected // sample from the target model LOG("all drafted tokens were rejected, sampling from residual distribution\n"); - token_id = llama_sampling_sample_dist(ctx_sampling->smpl, &dist_tgt); - llama_sampling_accept(ctx_sampling->smpl, token_id, true); + token_id = llama_sampling_sample_dist(smpl, &dist_tgt); + llama_sampling_accept(smpl, token_id, true); token_str = llama_token_to_piece(ctx_tgt, token_id); } @@ -344,11 +344,11 @@ int main(int argc, char ** argv) { // sample from the target model LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]); - token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]); + token_id = llama_sampling_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]); - llama_sampling_accept(ctx_sampling->smpl, token_id, true); + llama_sampling_accept(smpl, token_id, true); - //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str()); + //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, smpl->prev).c_str()); token_str = llama_token_to_piece(ctx_tgt, token_id); @@ -436,7 +436,7 @@ int main(int argc, char ** argv) { break; } - llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling); + llama_sampling_cp(smpl, drafts[0].smpl); int n_seq_cur = 1; int n_past_cur = n_past_dft; @@ -465,9 +465,9 @@ int main(int argc, char ** argv) { continue; } - llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, drafts[s].i_batch_dft); + llama_sampling_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft); - const auto * cur_p = llama_sampling_get_candidates(drafts[s].ctx_sampling->smpl); + const auto * cur_p = llama_sampling_get_candidates(drafts[s].smpl); for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) { LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n", @@ -505,7 +505,7 @@ int main(int argc, char ** argv) { drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft; drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt; - llama_sampling_cp(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling); + llama_sampling_cp(drafts[s].smpl, drafts[n_seq_cur].smpl); sa.push_back(n_seq_cur); @@ -521,7 +521,7 @@ int main(int argc, char ** argv) { const int s = sa[is]; - llama_sampling_accept(drafts[s].ctx_sampling->smpl, id, true); + llama_sampling_accept(drafts[s].smpl, id, true); drafts[s].tokens.push_back(id); // save cur_p.data into drafts[s].dists @@ -600,11 +600,11 @@ int main(int argc, char ** argv) { llama_print_timings(ctx_dft, nullptr); LOG_TEE("\ntarget:\n"); - llama_print_timings(ctx_tgt, ctx_sampling->smpl); + llama_print_timings(ctx_tgt, smpl); - llama_sampling_free(ctx_sampling); + llama_sampling_free(smpl); for (int s = 0; s < n_seq_dft; ++s) { - llama_sampling_free(drafts[s].ctx_sampling); + llama_sampling_free(drafts[s].smpl); } llama_batch_free(batch_dft); diff --git a/include/llama.h b/include/llama.h index e808f59f2e14c..d6897e06d43de 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1126,6 +1126,11 @@ extern "C" { const struct llama_sampling * smpl, int32_t ith); + /// @details Get the last accepted token + /// Same as llama_sampling_prev(smpl, 0) + /// returns LLAMA_TOKEN_NULL if there are no accepted tokens + LLAMA_API llama_token llama_sampling_last(const struct llama_sampling * smpl); + /// @details Get the number of accepted tokens (max of n_prev) LLAMA_API int llama_sampling_n_prev(const struct llama_sampling * smpl); diff --git a/src/llama.cpp b/src/llama.cpp index 76754835c00ce..564aa6db41958 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20344,6 +20344,10 @@ llama_token llama_sampling_prev(const struct llama_sampling * smpl, int32_t ith) return llama_sampling_prev_impl(*smpl, ith); } +llama_token llama_sampling_last(const struct llama_sampling * smpl) { + return llama_sampling_prev_impl(*smpl, 0); +} + int llama_sampling_n_prev(const struct llama_sampling * smpl) { return llama_sampling_n_prev_impl(*smpl); }