From 0e6d170a506e951f716afee816d1388b75102887 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sat, 7 Sep 2024 14:16:21 +0300
Subject: [PATCH] sampling : avoid llama_model in few samplers

ggml-ci
---
 common/sampling.cpp    |   8 ++-
 include/llama.h        |  11 ++--
 src/llama-sampling.cpp | 138 +++++++++++++++++++++++++++++++----------
 src/llama-sampling.h   |  21 -------
 src/llama.cpp          |  25 --------
 5 files changed, 117 insertions(+), 86 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index 5f27d5006044f..c81b4d233b04e 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -152,13 +152,15 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
 
     llama_sampler_chain_add(result->chain,
             llama_sampler_init_logit_bias(
-                model,
+                llama_n_vocab(model),
                 params.logit_bias.size(),
                 params.logit_bias.data()));
 
     llama_sampler_chain_add(result->chain,
             llama_sampler_init_penalties(
-                model,
+                llama_n_vocab (model),
+                llama_token_eos(model),
+                llama_token_nl (model),
                 params.penalty_last_n,
                 params.penalty_repeat,
                 params.penalty_freq,
@@ -196,7 +198,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
             llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
         } else if (params.mirostat == 1) {
             llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-            llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(model, params.seed, params.mirostat_tau, params.mirostat_eta, 100));
+            llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
         } else if (params.mirostat == 2) {
             llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
             llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
diff --git a/include/llama.h b/include/llama.h
index 5441d98f05f28..8bfb9e3b1532b 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1003,12 +1003,13 @@ extern "C" {
     //        // sample from the logits of the last token in the batch
     //        const llama_token id = llama_sampler_sample(smpl, ctx, -1);
     //
+    //        // accepting the token updates the internal state of certain samplers (e.g. grammar, repetition, etc.)
+    //        llama_sampler_accept(smpl, id);
     //        ...
     //    }
     //
     //    llama_sampler_free(smpl);
     //
-    //
     // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
     // TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab
     //
@@ -1086,7 +1087,7 @@ extern "C" {
     /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
     /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
     LLAMA_API struct llama_sampler * llama_sampler_init_mirostat(
-            const struct llama_model * model,
+                             int32_t   n_vocab,
                             uint32_t   seed,
                                float   tau,
                                float   eta,
@@ -1108,7 +1109,9 @@ extern "C" {
                           const char * grammar_root);
 
     LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
-            const struct llama_model * model,
+                             int32_t   n_vocab,          // llama_n_vocab()
+                         llama_token   special_eos_id,   // llama_token_eos()
+                         llama_token   linefeed_id,      // llama_token_nl()
                              int32_t   penalty_last_n,   // last n tokens to penalize (0 = disable penalty, -1 = context size)
                                float   penalty_repeat,   // 1.0 = disabled
                                float   penalty_freq,     // 0.0 = disabled
@@ -1117,7 +1120,7 @@ extern "C" {
                                 bool   ignore_eos);      // ignore the end-of-sequence token
 
     LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
-            const struct llama_model * model,
+                             int32_t   n_vocab,
                              int32_t   n_logit_bias,
               const llama_logit_bias * logit_bias);
 
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 15d5b5f8a44c8..e53b3d3a77edc 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -3,6 +3,7 @@
 #include "llama-vocab.h"
 #include "llama-grammar.h"
 
+#include <cassert>
 #include
 #include
 #include
@@ -926,7 +927,7 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
 
 // mirostat
 
 struct llama_sampler_mirostat {
-    const struct llama_vocab * vocab;
+    const int32_t n_vocab;
 
     const uint32_t seed;
@@ -964,7 +965,7 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
 
         // Compute k from the estimated s_hat and target surprise value
         float epsilon_hat = s_hat - 1;
-        float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->vocab->n_vocab, -epsilon_hat)), 1 / s_hat);
+        float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);
 
         llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
         llama_sampler_softmax_impl(cur_p);
@@ -986,25 +987,25 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
     },
     /* .clone = */ [](const struct llama_sampler * smpl) {
         const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
-        return llama_sampler_init_mirostat_impl(*ctx->vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
+        return llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
     },
     /* .free = */ [](struct llama_sampler * smpl) {
         delete (llama_sampler_mirostat *) smpl->ctx;
     },
 };
 
-struct llama_sampler * llama_sampler_init_mirostat_impl(const struct llama_vocab & vocab, uint32_t seed, float tau, float eta, int32_t m) {
+struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
     return new llama_sampler {
         /* .iface = */ &llama_sampler_mirostat_i,
         /* .ctx   = */ new llama_sampler_mirostat {
-            /* .vocab = */ &vocab,
-            /* .seed  = */ seed,
-            /* .tau   = */ tau,
-            /* .eta   = */ eta,
-            /* .m     = */ m,
-            /* .mu    = */ 2.0f*tau,
-            /* .rng   = */ std::mt19937(seed),
-            /* .probs = */ {},
+            /* .n_vocab = */ n_vocab,
+            /* .seed    = */ seed,
+            /* .tau     = */ tau,
+            /* .eta     = */ eta,
+            /* .m       = */ m,
+            /* .mu      = */ 2.0f*tau,
+            /* .rng     = */ std::mt19937(seed),
+            /* .probs   = */ {},
         },
     };
 }
@@ -1172,7 +1173,9 @@ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab
 // penalties
 
 struct llama_sampler_penalties {
-    const struct llama_vocab * vocab;
+    const int32_t     n_vocab;
+    const llama_token special_eos_id;
+    const llama_token linefeed_id;
 
     const int32_t penalty_last_n;
     const float   penalty_repeat;
@@ -1194,10 +1197,21 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
     /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
         auto * ctx = (llama_sampler_penalties *) smpl->ctx;
 
-        GGML_ASSERT(cur_p->size == ctx->vocab->n_vocab && cur_p->sorted == false && "the 'penalties' sampler must be applied on the full vocabulary");
-
         if (ctx->ignore_eos) {
-            cur_p->data[ctx->vocab->special_eos_id].logit = -INFINITY;
+            assert(ctx->special_eos_id >= 0);
+
+            // optimistically check if the candidates are not yet sorted/shuffled/truncated
+            if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
+                cur_p->data[ctx->special_eos_id].logit = -INFINITY;
+            } else {
+                // else, search for the special EOS token
+                for (size_t i = 0; i < cur_p->size; ++i) {
+                    if (cur_p->data[i].id == ctx->special_eos_id) {
+                        cur_p->data[i].logit = -INFINITY;
+                        break;
+                    }
+                }
+            }
         }
 
         if ((ctx->penalty_last_n == 0) ||
@@ -1205,7 +1219,29 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
             return;
         }
 
-        const float nl_logit = !ctx->penalize_nl ? cur_p->data[ctx->vocab->linefeed_id].logit : -INFINITY;
+        bool   nl_found = false;
+        size_t nl_idx = 0;
+        float  nl_logit = -INFINITY;
+        if (!ctx->penalize_nl) {
+            assert(ctx->linefeed_id >= 0);
+
+            // optimistically check if the candidates are not yet sorted/shuffled/truncated
+            if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
+                nl_found = true;
+                nl_idx = ctx->linefeed_id;
+                nl_logit = cur_p->data[ctx->linefeed_id].logit;
+            } else {
+                // else, search for the linefeed token
+                for (size_t i = 0; i < cur_p->size; ++i) {
+                    if (cur_p->data[i].id == ctx->linefeed_id) {
+                        nl_found = true;
+                        nl_idx = i;
+                        nl_logit = cur_p->data[i].logit;
+                        break;
+                    }
+                }
+            }
+        }
 
         // Create a frequency map to count occurrences of each token in last_tokens
         // TODO: optimize this by maintaining the token count in the sampler context
@@ -1216,9 +1252,9 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
 
         llama_sampler_penalties_impl(cur_p, token_count, ctx->penalty_repeat, ctx->penalty_freq, ctx->penalty_present);
 
-        if (!ctx->penalize_nl) {
+        if (!ctx->penalize_nl && nl_found) {
             // restore the logit of the newline token if it was penalized
-            cur_p->data[ctx->vocab->linefeed_id].logit = nl_logit;
+            cur_p->data[nl_idx].logit = nl_logit;
         }
     },
     /* .reset = */ [](struct llama_sampler * smpl) {
@@ -1227,8 +1263,10 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
     },
     /* .clone = */ [](const struct llama_sampler * smpl) {
         const auto * ctx_src = (const llama_sampler_penalties *) smpl->ctx;
-        auto * result = llama_sampler_init_penalties_impl(
-            *ctx_src->vocab,
+        auto * result = llama_sampler_init_penalties(
+            ctx_src->n_vocab,
+            ctx_src->special_eos_id,
+            ctx_src->linefeed_id,
             ctx_src->penalty_last_n,
             ctx_src->penalty_repeat,
             ctx_src->penalty_freq,
@@ -1246,14 +1284,30 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
     },
 };
 
-struct llama_sampler * llama_sampler_init_penalties_impl(const struct llama_vocab & vocab, int32_t penalty_last_n, float penalty_repeat, float penalty_freq, float penalty_present, bool penalize_nl, bool ignore_eos) {
-    GGML_ASSERT(penalize_nl || vocab.linefeed_id != LLAMA_TOKEN_NULL);
-    GGML_ASSERT(!ignore_eos || vocab.special_eos_id != LLAMA_TOKEN_NULL);
+struct llama_sampler * llama_sampler_init_penalties(
+        int32_t     n_vocab,
+        llama_token special_eos_id,
+        llama_token linefeed_id,
+        int32_t     penalty_last_n,
+        float       penalty_repeat,
+        float       penalty_freq,
+        float       penalty_present,
+        bool        penalize_nl,
+        bool        ignore_eos) {
+    if (linefeed_id == LLAMA_TOKEN_NULL) {
+        penalize_nl = false;
+    }
+
+    if (special_eos_id == LLAMA_TOKEN_NULL) {
+        ignore_eos = true;
+    }
 
     return new llama_sampler {
         /* .iface = */ &llama_sampler_penalties_i,
         /* .ctx   = */ new llama_sampler_penalties {
-            /* .vocab           = */ &vocab,
+            /* .n_vocab         = */ n_vocab,
+            /* .special_eos_id  = */ special_eos_id,
+            /* .linefeed_id     = */ linefeed_id,
             /* .penalty_last_n  = */ penalty_last_n,
             /* .penalty_repeat  = */ penalty_repeat,
             /* .penalty_freq    = */ penalty_freq,
@@ -1268,9 +1322,11 @@ struct llama_sampler * llama_sampler_init_penalties_impl(const struct llama_voca
 // logit-bias
 
 struct llama_sampler_logit_bias {
-    const struct llama_vocab * vocab;
+    const int32_t n_vocab;
+
+    const std::vector<llama_logit_bias> logit_bias;
 
-    std::vector<llama_logit_bias> logit_bias;
+    std::vector<llama_logit_bias> to_search;
 };
 
 static struct llama_sampler_i llama_sampler_logit_bias_i = {
@@ -1279,31 +1335,47 @@ static struct llama_sampler_i llama_sampler_logit_bias_i = {
     /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
         auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
 
-        GGML_ASSERT(cur_p->size == ctx->vocab->n_vocab && cur_p->sorted == false && "the 'logit_bias' sampler must be applied on the full vocabulary");
+        ctx->to_search.clear();
 
+        // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
         for (const auto & lb : ctx->logit_bias) {
-            cur_p->data[lb.token].logit += lb.bias;
+            if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
+                cur_p->data[lb.token].logit += lb.bias;
+            } else {
+                ctx->to_search.push_back(lb);
+            }
+        }
+
+        // search for the remaining candidates that were not found in the previous step
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            for (const auto & lb : ctx->to_search) {
+                if (cur_p->data[i].id == lb.token) {
+                    cur_p->data[i].logit += lb.bias;
+                    break;
+                }
+            }
         }
     },
     /* .reset = */ nullptr,
     /* .clone = */ [](const struct llama_sampler * smpl) {
         const auto * ctx_src = (const llama_sampler_logit_bias *) smpl->ctx;
-        return llama_sampler_init_logit_bias_impl(*ctx_src->vocab, ctx_src->logit_bias.size(), ctx_src->logit_bias.data());
+        return llama_sampler_init_logit_bias(ctx_src->n_vocab, ctx_src->logit_bias.size(), ctx_src->logit_bias.data());
     },
     /* .free = */ [](struct llama_sampler * smpl) {
         delete (llama_sampler_logit_bias *) smpl->ctx;
     },
 };
 
-struct llama_sampler * llama_sampler_init_logit_bias_impl(
-        const struct llama_vocab & vocab,
+struct llama_sampler * llama_sampler_init_logit_bias(
+        int32_t n_vocab,
         int32_t n_logit_bias,
         const llama_logit_bias * logit_bias) {
     return new llama_sampler {
         /* .iface = */ &llama_sampler_logit_bias_i,
         /* .ctx   = */ new llama_sampler_logit_bias {
-            /* .vocab      = */ &vocab,
+            /* .n_vocab    = */ n_vocab,
             /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
+            /* .to_search  = */ {},
         },
     };
 }
diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index ddc84a3900666..137c0025ce0d8 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -33,28 +33,7 @@ void llama_sampler_penalties_impl(
         float penalty_freq,
         float penalty_present);
 
-struct llama_sampler * llama_sampler_init_mirostat_impl(
-        const struct llama_vocab & vocab,
-        uint32_t seed,
-        float tau,
-        float eta,
-        int32_t m);
-
 struct llama_sampler * llama_sampler_init_grammar_impl(
         const struct llama_vocab & vocab,
         const char * grammar_str,
         const char * grammar_root);
-
-struct llama_sampler * llama_sampler_init_penalties_impl(
-        const struct llama_vocab & vocab,
-        int32_t penalty_last_n,
-        float penalty_repeat,
-        float penalty_freq,
-        float penalty_present,
-        bool penalize_nl,
-        bool ignore_eos);
-
-LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias_impl(
-        const struct llama_vocab & vocab,
-        int32_t n_logit_bias,
-        const llama_logit_bias * logit_bias);
diff --git a/src/llama.cpp b/src/llama.cpp
index c67f3638d337d..6bbaf9fc9bae7 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -20592,36 +20592,11 @@ int32_t llama_chat_apply_template(
 // sampling
 //
 
-// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
-struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, uint32_t seed, float tau, float eta, int32_t m) {
-    return llama_sampler_init_mirostat_impl(model->vocab, seed, tau, eta, m);
-}
-
 // TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
 struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
     return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
 }
 
-// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
-struct llama_sampler * llama_sampler_init_penalties(
-        const struct llama_model * model,
-        int32_t penalty_last_n,
-        float penalty_repeat,
-        float penalty_freq,
-        float penalty_present,
-        bool penalize_nl,
-        bool ignore_eos) {
-    return llama_sampler_init_penalties_impl(model->vocab, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalize_nl, ignore_eos);
-}
-
-// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
-struct llama_sampler * llama_sampler_init_logit_bias(
-        const struct llama_model * model,
-        int32_t n_logit_bias,
-        const llama_logit_bias * logit_bias) {
-    return llama_sampler_init_logit_bias_impl(model->vocab, n_logit_bias, logit_bias);
-}
-
 //
 // model split
 //
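Illustrative usage note (not part of the patch): after this change, callers pass plain vocabulary data instead of a llama_model pointer when constructing these samplers. A minimal sketch of the migration, mirroring the updated gpt_sampler_init() in common/sampling.cpp above; the chain object, the logit-bias array, the seed, and the numeric penalty values below are placeholders, not values prescribed by the patch:

    // before: llama_sampler_init_penalties(model, last_n, repeat, freq, present, penalize_nl, ignore_eos);
    // after : query the vocabulary-related data from the model once and pass it explicitly
    llama_sampler_chain_add(chain,
            llama_sampler_init_penalties(
                llama_n_vocab  (model),    // int32_t     n_vocab
                llama_token_eos(model),    // llama_token special_eos_id
                llama_token_nl (model),    // llama_token linefeed_id
                /*penalty_last_n  =*/ 64,
                /*penalty_repeat  =*/ 1.1f,
                /*penalty_freq    =*/ 0.0f,
                /*penalty_present =*/ 0.0f,
                /*penalize_nl     =*/ false,
                /*ignore_eos      =*/ true));

    // the logit-bias and mirostat samplers follow the same pattern: only n_vocab is needed
    llama_sampler_chain_add(chain,
            llama_sampler_init_logit_bias(llama_n_vocab(model), n_logit_bias, logit_bias));
    llama_sampler_chain_add(chain,
            llama_sampler_init_mirostat(llama_n_vocab(model), seed, /*tau=*/5.0f, /*eta=*/0.1f, /*m=*/100));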