From 7a9bf68124e45ae42b3ddd4b4ae14d3485d78caa Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 13 Aug 2024 14:32:54 +0300 Subject: [PATCH] cont : update sampling API ggml-ci --- common/sampling.cpp | 50 +++++-------------- common/sampling.h | 3 -- examples/batched/batched.cpp | 18 ++++--- examples/server/server.cpp | 7 +-- include/llama.h | 65 +++++------------------- src/llama-sampling.cpp | 87 ++++++++++++--------------------- src/llama-sampling.h | 55 ++++++++++++--------- src/llama-vocab.cpp | 12 ----- src/llama-vocab.h | 2 - src/llama.cpp | 95 +++++++++++++++++++----------------- tests/test-sampling.cpp | 39 ++++++--------- 11 files changed, 170 insertions(+), 263 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index d369f6c4a63e7..596bd24608916 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -29,6 +29,7 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_pa lp.mirostat = params.mirostat; lp.mirostat_tau = params.mirostat_tau; lp.mirostat_eta = params.mirostat_eta; + lp.cfg_scale = params.cfg_scale; lp.penalize_nl = params.penalize_nl; lp.ignore_eos = params.ignore_eos; @@ -36,7 +37,6 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_pa llama_sampling_set_rng_seed (result->smpl, params.seed); llama_sampling_set_grammar (result->smpl, params.grammar.c_str(), "root"); - llama_sampling_set_cfg (result->smpl, params.cfg_negative_prompt.c_str(), params.cfg_scale); llama_sampling_set_logit_bias(result->smpl, params.logit_bias.size(), params.logit_bias.data()); } @@ -202,38 +202,21 @@ std::vector llama_sampling_types_from_chars(const std::strin // no reasons to expose this function in header static void sampler_queue( struct llama_sampling_context * ctx_sampling, - llama_token_data_array & cur_p, - size_t min_keep) { + llama_token_data_array & cur_p) { llama_sampling * smpl = ctx_sampling->smpl; const gpt_sampling_params & params = ctx_sampling->params; - const float temp = params.temp; - const float dynatemp_range = params.dynatemp_range; - const float dynatemp_exponent = params.dynatemp_exponent; - const int32_t top_k = params.top_k; - const float top_p = params.top_p; - const float min_p = params.min_p; - const float tfs_z = params.tfs_z; - const float typical_p = params.typical_p; const std::vector & samplers_sequence = params.samplers_sequence; for (auto sampler_type : samplers_sequence) { switch (sampler_type) { - case llama_sampler_type::TOP_K : llama_sampling_top_k (smpl, &cur_p, top_k, min_keep); break; - case llama_sampler_type::TFS_Z : llama_sampling_tail_free(smpl, &cur_p, tfs_z, min_keep); break; - case llama_sampler_type::TYPICAL_P: llama_sampling_typical (smpl, &cur_p, typical_p, min_keep); break; - case llama_sampler_type::TOP_P : llama_sampling_top_p (smpl, &cur_p, top_p, min_keep); break; - case llama_sampler_type::MIN_P : llama_sampling_min_p (smpl, &cur_p, min_p, min_keep); break; - case llama_sampler_type::TEMPERATURE: - if (dynatemp_range > 0) { - float dynatemp_min = std::max(0.0f, temp - dynatemp_range); - float dynatemp_max = std::max(0.0f, temp + dynatemp_range); - llama_sampling_entropy(smpl, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent); - } else { - llama_sampling_temp(smpl, &cur_p, temp); - } - break; + case llama_sampler_type::TOP_K: llama_sampling_top_k (smpl, &cur_p); break; + case llama_sampler_type::TFS_Z: llama_sampling_tail_free(smpl, &cur_p); break; + case llama_sampler_type::TYPICAL_P: llama_sampling_typical (smpl, &cur_p); break; + case llama_sampler_type::TOP_P: llama_sampling_top_p (smpl, &cur_p); break; + case llama_sampler_type::MIN_P: llama_sampling_min_p (smpl, &cur_p); break; + case llama_sampler_type::TEMPERATURE: llama_sampling_temp (smpl, &cur_p); break; default : break; } } @@ -269,18 +252,11 @@ static llama_token llama_sampling_sample_impl( // greedy sampling, no probs id = llama_sampling_sample_greedy(smpl, &cur_p); } else { - if (mirostat == 1) { - const int mirostat_m = 100; - llama_sampling_temp(smpl, &cur_p, temp); - id = llama_sampling_sample_mirostat(smpl, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu); - } else if (mirostat == 2) { - llama_sampling_temp(smpl, &cur_p, temp); - id = llama_sampling_sample_mirostat_v2(smpl, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu); + if (mirostat != 0) { + llama_sampling_temp(smpl, &cur_p); + id = llama_sampling_sample_mirostat(smpl, &cur_p); } else { - // temperature sampling - size_t min_keep = std::max(1, params.min_keep); - - sampler_queue(ctx_sampling, cur_p, min_keep); + sampler_queue(ctx_sampling, cur_p); id = llama_sampling_sample(smpl, &cur_p); @@ -372,7 +348,7 @@ static llama_token_data_array llama_sampling_prepare_impl( if (ctx_cfg) { float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx); - llama_sampling_cfg(smpl, logits, logits_guidance, params.cfg_scale); + llama_sampling_cfg(smpl, logits, logits_guidance); } cur.resize(n_vocab); diff --git a/common/sampling.h b/common/sampling.h index 3e36e90ec6208..e31d8756d7c17 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -165,9 +165,6 @@ struct llama_sampling_context { // parameters that will be used for sampling gpt_sampling_params params; - // mirostat sampler state - float mirostat_mu; - llama_sampling * smpl; ring_buffer prev; diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 7e76bc9a8f518..f381d92075e4a 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -64,7 +64,13 @@ int main(int argc, char ** argv) { ctx_params.n_batch = std::max(n_predict, n_parallel); llama_context * ctx = llama_new_context_with_model(model, ctx_params); - llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params()); + + auto sparams = llama_sampling_default_params(); + sparams.top_k = 40; + sparams.top_p = 0.9f; + sparams.temp = 0.4f; + + llama_sampling * smpl = llama_sampling_init(model, sparams); if (ctx == NULL) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); @@ -177,13 +183,9 @@ int main(int argc, char ** argv) { llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - const int top_k = 40; - const float top_p = 0.9f; - const float temp = 0.4f; - - llama_sampling_top_k(smpl, &candidates_p, top_k, 1); - llama_sampling_top_p(smpl, &candidates_p, top_p, 1); - llama_sampling_temp (smpl, &candidates_p, temp); + llama_sampling_top_k(smpl, &candidates_p); + llama_sampling_top_p(smpl, &candidates_p); + llama_sampling_temp (smpl, &candidates_p); const llama_token new_token_id = llama_sampling_sample(smpl, &candidates_p); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 623b0074339bb..e4eedd23fe637 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2358,9 +2358,10 @@ struct server_context { const size_t n_valid = slot.ctx_sampling->n_valid; // Make sure at least n_probs top tokens are at the front of the vector: - if (slot.sparams.temp == 0.0f && n_probs > n_valid) { - llama_sampling_top_k(slot.ctx_sampling->smpl, &cur_p, n_probs, 0); - } + // TODO: decide to how to handle this after the refactoring + //if (slot.sparams.temp == 0.0f && n_probs > n_valid) { + // llama_sampling_top_k(slot.ctx_sampling->smpl, &cur_p, n_probs, 0); + //} if (slot.sparams.temp == 0.0f) { // With greedy sampling the probabilities have possibly not been calculated. diff --git a/include/llama.h b/include/llama.h index 3d53347a6ba4f..cdd68d6a52031 100644 --- a/include/llama.h +++ b/include/llama.h @@ -382,6 +382,7 @@ extern "C" { int32_t mirostat; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 float mirostat_tau; // target entropy float mirostat_eta; // learning rate + float cfg_scale; // classifier-free guidance scale // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. bool penalize_nl; // consider newlines as a repeatable token @@ -1012,7 +1013,6 @@ extern "C" { // Sets the current rng seed. LLAMA_API void llama_sampling_set_rng_seed (struct llama_sampling * smpl, uint32_t seed); LLAMA_API void llama_sampling_set_grammar (struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root); - LLAMA_API void llama_sampling_set_cfg (struct llama_sampling * smpl, const char * cfg_prompt, float cfg_scale); LLAMA_API void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias); /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. @@ -1023,50 +1023,32 @@ extern "C" { /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 LLAMA_API void llama_sampling_top_k( struct llama_sampling * smpl, - llama_token_data_array * candidates, - int32_t k, - size_t min_keep); + llama_token_data_array * candidates); /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 LLAMA_API void llama_sampling_top_p( struct llama_sampling * smpl, - llama_token_data_array * candidates, - float p, - size_t min_keep); + llama_token_data_array * candidates); /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 LLAMA_API void llama_sampling_min_p( struct llama_sampling * smpl, - llama_token_data_array * candidates, - float p, - size_t min_keep); + llama_token_data_array * candidates); /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. LLAMA_API void llama_sampling_tail_free( struct llama_sampling * smpl, - llama_token_data_array * candidates, - float z, - size_t min_keep); + llama_token_data_array * candidates); /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. LLAMA_API void llama_sampling_typical( struct llama_sampling * smpl, - llama_token_data_array * candidates, - float p, - size_t min_keep); - - /// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772. - LLAMA_API void llama_sampling_entropy( - struct llama_sampling * smpl, - llama_token_data_array * candidates_p, - float min_temp, - float max_temp, - float exponent_val); + llama_token_data_array * candidates); + /// @details Apply temperature and entropy LLAMA_API void llama_sampling_temp( struct llama_sampling * smpl, - llama_token_data_array * candidates, - float temp); + llama_token_data_array * candidates); /// @details Apply constraints from grammar LLAMA_API void llama_sampling_grammar( @@ -1075,6 +1057,7 @@ extern "C" { /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. + // TODO: update signature LLAMA_API void llama_sampling_repetition_penalties( struct llama_sampling * smpl, llama_token_data_array * candidates, @@ -1091,34 +1074,12 @@ extern "C" { LLAMA_API void llama_sampling_cfg( struct llama_sampling * smpl, float * logits, - float * logits_guidance, - float scale); - - /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. - /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. - /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. - /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. - /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. + float * logits_guidance); + + /// @details Mirostat algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. LLAMA_API llama_token llama_sampling_sample_mirostat( struct llama_sampling * smpl, - llama_token_data_array * candidates, - float tau, - float eta, - int32_t m, - float * mu); - - /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. - /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. - /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. - /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - LLAMA_API llama_token llama_sampling_sample_mirostat_v2( - struct llama_sampling * smpl, - llama_token_data_array * candidates, - float tau, - float eta, - float * mu); + llama_token_data_array * candidates); /// @details Selects the token with the highest probability. /// Does not compute the token probabilities. Use llama_sampling_softmax() instead. diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 5ba41e949fb3d..f4d666f731133 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -55,9 +55,6 @@ struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smp result->grammar_str = smpl.grammar_str; result->grammar_root = smpl.grammar_root; - result->cfg_prompt = smpl.cfg_prompt; - result->cfg_scale = smpl.cfg_scale; - result->logit_bias = smpl.logit_bias; if (smpl.grammar) { @@ -103,16 +100,6 @@ void llama_sampling_set_grammar_impl(struct llama_sampling & smpl, const char * } } -void llama_sampling_set_cfg_impl(struct llama_sampling & smpl, const char * cfg_prompt, float cfg_scale) { - if (cfg_prompt != nullptr && cfg_prompt[0] != '\0') { - smpl.cfg_prompt = cfg_prompt; - } else { - smpl.cfg_prompt.clear(); - } - - smpl.cfg_scale = cfg_scale; -} - void llama_sampling_set_logit_bias_impl(struct llama_sampling & smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { smpl.logit_bias.clear(); smpl.logit_bias.reserve(n_logit_bias); @@ -122,7 +109,7 @@ void llama_sampling_set_logit_bias_impl(struct llama_sampling & smpl, int32_t n_ } } -void llama_sampling_softmax_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) { +void llama_sampling_softmax_impl(llama_token_data_array * candidates) { GGML_ASSERT(candidates->size > 0); // Sort the logits in descending order @@ -147,7 +134,7 @@ void llama_sampling_softmax_impl(struct llama_sampling & /*smpl*/, llama_token_d } } -void llama_sampling_top_k_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, int32_t k, size_t min_keep) { +void llama_sampling_top_k_impl(llama_token_data_array * candidates, int32_t k, size_t min_keep) { // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast // if (k >= (int32_t)candidates->size) { // return; @@ -222,12 +209,12 @@ void llama_sampling_top_k_impl(struct llama_sampling & /*smpl*/, llama_token_dat candidates->size = k; } -void llama_sampling_top_p_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_sampling_top_p_impl(llama_token_data_array * candidates, float p, size_t min_keep) { if (p >= 1.0f) { return; } - llama_sampling_softmax_impl(smpl, candidates); + llama_sampling_softmax_impl(candidates); // Compute the cumulative probabilities float cum_sum = 0.0f; @@ -248,7 +235,7 @@ void llama_sampling_top_p_impl(struct llama_sampling & smpl, llama_token_data_ar candidates->size = last_idx; } -void llama_sampling_min_p_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_sampling_min_p_impl(llama_token_data_array * candidates, float p, size_t min_keep) { if (p <= 0.0f || !candidates->size) { return; } @@ -303,12 +290,12 @@ void llama_sampling_min_p_impl(struct llama_sampling & /*smpl*/, llama_token_dat } } -void llama_sampling_tail_free_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float z, size_t min_keep) { +void llama_sampling_tail_free_impl(llama_token_data_array * candidates, float z, size_t min_keep) { if (z >= 1.0f || candidates->size <= 2) { return; } - llama_sampling_softmax_impl(smpl, candidates); + llama_sampling_softmax_impl(candidates); // Compute the first and second derivatives std::vector first_derivatives(candidates->size - 1); @@ -357,7 +344,7 @@ void llama_sampling_tail_free_impl(struct llama_sampling & smpl, llama_token_dat candidates->size = last_idx; } -void llama_sampling_typical_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_sampling_typical_impl(llama_token_data_array * candidates, float p, size_t min_keep) { // Reference implementation: // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr if (p >= 1.0f) { @@ -365,7 +352,7 @@ void llama_sampling_typical_impl(struct llama_sampling & smpl, llama_token_data_ } // Compute the softmax of logits and calculate entropy - llama_sampling_softmax_impl(smpl, candidates); + llama_sampling_softmax_impl(candidates); float entropy = 0.0f; for (size_t i = 0; i < candidates->size; ++i) { @@ -415,7 +402,7 @@ void llama_sampling_typical_impl(struct llama_sampling & smpl, llama_token_data_ candidates->sorted = false; } -void llama_sampling_entropy_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) { +void llama_sampling_entropy_impl(llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) { // no need to do anything if there is only one (or zero) candidates if(candidates->size <= 1) { return; @@ -424,7 +411,7 @@ void llama_sampling_entropy_impl(struct llama_sampling & smpl, llama_token_data_ // Calculate maximum possible entropy float max_entropy = -logf(1.0f / candidates->size); - llama_sampling_softmax_impl(smpl, candidates); + llama_sampling_softmax_impl(candidates); // Calculate entropy of the softmax probabilities float entropy = 0.0f; @@ -476,20 +463,17 @@ void llama_sampling_entropy_impl(struct llama_sampling & smpl, llama_token_data_ #endif } -void llama_sampling_temp_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, float temp) { +void llama_sampling_temp_impl(llama_token_data_array * candidates, float temp) { for (size_t i = 0; i < candidates->size; ++i) { candidates->data[i].logit /= temp; } } -void llama_sampling_grammar_impl(struct llama_sampling & smpl, llama_token_data_array * candidates) { - if (smpl.grammar) { - llama_grammar_apply_impl(*smpl.grammar, candidates); - } +void llama_sampling_grammar_impl(llama_token_data_array * candidates, const struct llama_grammar & grammar) { + llama_grammar_apply_impl(grammar, candidates); } void llama_sampling_repetition_penalties_impl( - struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, const llama_token * last_tokens, size_t penalty_last_n, @@ -529,11 +513,10 @@ void llama_sampling_repetition_penalties_impl( candidates->sorted = false; } -void llama_sampling_apply_guidance_impl( +void llama_sampling_cfg_impl( struct llama_sampling & smpl, float * logits, - float * logits_guidance, - float scale) { + float * logits_guidance) { const auto n_vocab = smpl.vocab.n_vocab; llama_log_softmax(logits, n_vocab); @@ -543,14 +526,12 @@ void llama_sampling_apply_guidance_impl( auto & l = logits[i]; const auto & g = logits_guidance[i]; - l = scale * (l - g) + g; + l = smpl.params.cfg_scale * (l - g) + g; } } -llama_token llama_sampling_sample_mirostat_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { - const float n_vocab = float(smpl.vocab.n_vocab); - - llama_sampling_softmax_impl(smpl, candidates); +llama_token llama_sampling_sample_mirostat_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu) { + llama_sampling_softmax_impl(candidates); // Estimate s_hat using the most probable m tokens float s_hat = 0.0; @@ -566,11 +547,11 @@ llama_token llama_sampling_sample_mirostat_impl(struct llama_sampling & smpl, ll // Compute k from the estimated s_hat and target surprise value float epsilon_hat = s_hat - 1; - float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); + float k = powf((epsilon_hat * powf(2, mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); // Sample the next word X using top-k sampling - llama_sampling_top_k_impl(smpl, candidates, int(k), 1); - llama_token X = llama_sampling_sample_impl(smpl, candidates); + llama_sampling_top_k_impl(candidates, int(k), 1); + llama_token X = llama_sampling_sample_impl(candidates, rng); // Compute error as the difference between observed surprise and target surprise value size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { @@ -580,17 +561,17 @@ llama_token llama_sampling_sample_mirostat_impl(struct llama_sampling & smpl, ll float e = observed_surprise - tau; // Update mu using the learning rate and error - *mu = *mu - eta * e; + mu = mu - eta * e; return X; } -llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) { - llama_sampling_softmax_impl(smpl, candidates); +llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu) { + llama_sampling_softmax_impl(candidates); // Truncate the words with surprise values greater than mu candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { - return -log2f(candidate.p) > *mu; + return -log2f(candidate.p) > mu; })); if (candidates->size == 0) { @@ -598,10 +579,10 @@ llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_sampling & smpl, } // Normalize the probabilities of the remaining words - llama_sampling_softmax_impl(smpl, candidates); + llama_sampling_softmax_impl(candidates); // Sample the next word X from the remaining words - llama_token X = llama_sampling_sample_impl(smpl, candidates); + llama_token X = llama_sampling_sample_impl(candidates, rng); // Compute error as the difference between observed surprise and target surprise value size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { @@ -611,12 +592,12 @@ llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_sampling & smpl, float e = observed_surprise - tau; // Update mu using the learning rate and error - *mu = *mu - eta * e; + mu = mu - eta * e; return X; } -llama_token llama_sampling_sample_greedy_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) { +llama_token llama_sampling_sample_greedy_impl(llama_token_data_array * candidates) { // Find max element auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { return a.logit < b.logit; @@ -627,8 +608,8 @@ llama_token llama_sampling_sample_greedy_impl(struct llama_sampling & /*smpl*/, return result; } -llama_token llama_sampling_sample_with_rng_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, std::mt19937 & rng) { - llama_sampling_softmax_impl(smpl, candidates); +llama_token llama_sampling_sample_impl(struct llama_token_data_array * candidates, std::mt19937 & rng) { + llama_sampling_softmax_impl(candidates); std::vector probs; probs.reserve(candidates->size); @@ -644,10 +625,6 @@ llama_token llama_sampling_sample_with_rng_impl(struct llama_sampling & smpl, ll return result; } -llama_token llama_sampling_sample_impl(struct llama_sampling & smpl, llama_token_data_array * candidates) { - return llama_sampling_sample_with_rng_impl(smpl, candidates, smpl.rng); -} - void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token) { // TODO: implement token storing in history diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 5d142b2ce0ad2..d66f5e4bf6373 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -14,9 +14,6 @@ struct llama_sampling { std::string grammar_str; std::string grammar_root; - std::string cfg_prompt; - float cfg_scale = 1.0f; - std::vector logit_bias; // logit biases to apply // state @@ -27,6 +24,9 @@ struct llama_sampling { struct llama_grammar * grammar = nullptr; + // mirostat sampler state + float mirostat_mu; + mutable int64_t t_total_us = 0; mutable int32_t n_sample = 0; @@ -47,21 +47,19 @@ void llama_sampling_reset_impl(struct llama_sampling & smpl); // TODO: move the API below as member functions of llama_sampling void llama_sampling_set_rng_seed_impl (struct llama_sampling & smpl, uint32_t seed); void llama_sampling_set_grammar_impl (struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root); -void llama_sampling_set_cfg_impl (struct llama_sampling & smpl, const char * cfg_prompt, float cfg_scale); void llama_sampling_set_logit_bias_impl(struct llama_sampling & smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias); -void llama_sampling_softmax_impl (struct llama_sampling & smpl, llama_token_data_array * candidates); -void llama_sampling_top_k_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep); -void llama_sampling_top_p_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sampling_min_p_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sampling_tail_free_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float z, size_t min_keep); -void llama_sampling_typical_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sampling_entropy_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val); -void llama_sampling_temp_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float temp); -void llama_sampling_grammar_impl (struct llama_sampling & smpl, llama_token_data_array * candidates); +void llama_sampling_softmax_impl (struct llama_token_data_array * candidates); +void llama_sampling_top_k_impl (struct llama_token_data_array * candidates, int32_t k, size_t min_keep); +void llama_sampling_top_p_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sampling_min_p_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sampling_tail_free_impl(struct llama_token_data_array * candidates, float z, size_t min_keep); +void llama_sampling_typical_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sampling_entropy_impl (struct llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val); +void llama_sampling_temp_impl (struct llama_token_data_array * candidates, float temp); +void llama_sampling_grammar_impl (struct llama_token_data_array * candidates, const struct llama_grammar & grammar); void llama_sampling_repetition_penalties_impl( - struct llama_sampling & smpl, llama_token_data_array * candidates, const llama_token * last_tokens, size_t penalty_last_n, @@ -69,16 +67,27 @@ void llama_sampling_repetition_penalties_impl( float penalty_freq, float penalty_present); -void llama_sampling_apply_guidance_impl( +void llama_sampling_cfg_impl( struct llama_sampling & smpl, float * logits, - float * logits_guidance, - float scale); - -llama_token llama_sampling_sample_mirostat_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu); -llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, float * mu); -llama_token llama_sampling_sample_greedy_impl (struct llama_sampling & smpl, llama_token_data_array * candidates); -llama_token llama_sampling_sample_with_rng_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, std::mt19937 & rng); -llama_token llama_sampling_sample_impl (struct llama_sampling & smpl, llama_token_data_array * candidates); + float * logits_guidance); + +/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. +/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. +/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. +/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. +/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. +/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. +llama_token llama_sampling_sample_mirostat_impl (struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu); + +/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. +/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. +/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. +/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. +/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. +llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu); + +llama_token llama_sampling_sample_greedy_impl (struct llama_token_data_array * candidates); +llama_token llama_sampling_sample_impl (struct llama_token_data_array * candidates, std::mt19937 & rng); void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token); diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 3b1de47b01bfa..11fffce9386d7 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -93,18 +93,6 @@ int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string return it->second; } -llama_vocab llama_vocab::for_tests(uint32_t n_vocab) { - llama_vocab vocab; - vocab.n_vocab = n_vocab; - vocab.token_to_id.reserve(n_vocab); - vocab.id_to_token.reserve(n_vocab); - for (uint32_t i = 0; i < n_vocab; i++) { - vocab.token_to_id[format("token_%u", i)] = i; - vocab.id_to_token.push_back({ format("token_%u", i), 0.0f, LLAMA_TOKEN_ATTR_NORMAL }); - } - return vocab; -} - static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) { return vocab.type; } diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 1cd2737277a7a..dc4b5f12f7860 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -62,8 +62,6 @@ struct llama_vocab { std::vector precompiled_charsmap; int find_bpe_rank(const std::string & token_left, const std::string & token_right) const; - - static llama_vocab for_tests(uint32_t n_vocab); }; // diff --git a/src/llama.cpp b/src/llama.cpp index 02c131bb6548c..2807058cf9e44 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -16512,6 +16512,7 @@ struct llama_sampling_params llama_sampling_default_params() { /*.mirostat =*/ 0, /*.mirostat_tau =*/ 5.00f, /*.mirostat_eta =*/ 0.10f, + /*.cfg_scale =*/ 1.00f, /*.penalize_nl =*/ false, /*.ignore_eos =*/ false, }; @@ -19116,10 +19117,6 @@ void llama_sampling_set_grammar(struct llama_sampling * smpl, const char * gramm llama_sampling_set_grammar_impl(*smpl, grammar_str, grammar_root); } -void llama_sampling_set_cfg(struct llama_sampling * smpl, const char * cfg_prompt, float cfg_scale) { - llama_sampling_set_cfg_impl(*smpl, cfg_prompt, cfg_scale); -} - void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { llama_sampling_set_logit_bias_impl(*smpl, n_logit_bias, logit_bias); } @@ -19127,57 +19124,58 @@ void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_total_us); - llama_sampling_softmax_impl(*smpl, candidates); + llama_sampling_softmax_impl(candidates); } -void llama_sampling_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) { +void llama_sampling_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_total_us); - llama_sampling_top_k_impl(*smpl, candidates, k, min_keep); + llama_sampling_top_k_impl(candidates, smpl->params.top_k, smpl->params.min_keep); } -void llama_sampling_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_sampling_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_total_us); - llama_sampling_top_p_impl(*smpl, candidates, p, min_keep); + llama_sampling_top_p_impl(candidates, smpl->params.top_p, smpl->params.min_keep); } -void llama_sampling_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_sampling_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_total_us); - llama_sampling_min_p_impl(*smpl, candidates, p, min_keep); + llama_sampling_min_p_impl(candidates, smpl->params.min_p, smpl->params.min_keep); } -void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) { +void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_total_us); - llama_sampling_tail_free_impl(*smpl, candidates, z, min_keep); + llama_sampling_tail_free_impl(candidates, smpl->params.tfs_z, smpl->params.min_keep); } -void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_total_us); - llama_sampling_typical_impl(*smpl, candidates, p, min_keep); + llama_sampling_typical_impl(candidates, smpl->params.typical_p, smpl->params.min_keep); } -void llama_sampling_entropy(struct llama_sampling * smpl, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) { +void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_total_us); - llama_sampling_entropy_impl(*smpl, candidates_p, min_temp, max_temp, exponent_val); -} - -void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates_p, float temp) { - time_meas tm(smpl->t_total_us); + if (smpl->params.dynatemp_range > 0) { + const float dynatemp_min = std::max(0.0f, smpl->params.temp - smpl->params.dynatemp_range); + const float dynatemp_max = std::max(0.0f, smpl->params.temp + smpl->params.dynatemp_range); - llama_sampling_temp_impl(*smpl, candidates_p, temp); + llama_sampling_entropy_impl(candidates, dynatemp_min, dynatemp_max, smpl->params.dynatemp_exponent); + } else { + llama_sampling_temp_impl(candidates, smpl->params.temp); + } } -void llama_sampling_grammar( - struct llama_sampling * smpl, - llama_token_data_array * candidates) { +void llama_sampling_grammar(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_total_us); // TODO: measure grammar time separately from sampling - llama_sampling_grammar_impl(*smpl, candidates); + if (smpl->grammar) { + llama_sampling_grammar_impl(candidates, *smpl->grammar); + } } void llama_sampling_repetition_penalties( @@ -19190,33 +19188,42 @@ void llama_sampling_repetition_penalties( float penalty_present) { time_meas tm(smpl->t_total_us); - llama_sampling_repetition_penalties_impl(*smpl, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present); + llama_sampling_repetition_penalties_impl(candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present); } void llama_sampling_cfg( struct llama_sampling * smpl, float * logits, - float * logits_guidance, - float scale) { + float * logits_guidance) { time_meas tm(smpl->t_total_us); - llama_sampling_apply_guidance_impl(*smpl, logits, logits_guidance, scale); + llama_sampling_cfg_impl(*smpl, logits, logits_guidance); } -llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { +llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_total_us); - auto res = llama_sampling_sample_mirostat_impl(*smpl, candidates, tau, eta, m, mu); - - smpl->n_sample++; - - return res; -} - -llama_token llama_sampling_sample_mirostat_v2(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) { - time_meas tm(smpl->t_total_us); - - auto res = llama_sampling_sample_mirostat_v2_impl(*smpl, candidates, tau, eta, mu); + const auto type = smpl->params.mirostat; + + llama_token res; + + if (type == 1) { + res = llama_sampling_sample_mirostat_impl(candidates, + smpl->rng, + smpl->params.mirostat_tau, + smpl->params.mirostat_eta, + 100, + smpl->vocab.n_vocab, + smpl->mirostat_mu); + } else if (type == 2) { + res = llama_sampling_sample_mirostat_v2_impl(candidates, + smpl->rng, + smpl->params.mirostat_tau, + smpl->params.mirostat_eta, + smpl->mirostat_mu); + } else { + GGML_ABORT("invalid mirostat type: %d", type); + } smpl->n_sample++; @@ -19226,7 +19233,7 @@ llama_token llama_sampling_sample_mirostat_v2(struct llama_sampling * smpl, llam llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_total_us); - auto res = llama_sampling_sample_greedy_impl(*smpl, candidates); + auto res = llama_sampling_sample_greedy_impl(candidates); smpl->n_sample++; @@ -19236,7 +19243,7 @@ llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_tok llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_total_us); - auto res = llama_sampling_sample_impl(*smpl, candidates); + auto res = llama_sampling_sample_impl(candidates, smpl->rng); smpl->n_sample++; diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 371b7e511996f..ae328a90454ae 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -1,6 +1,5 @@ #include "ggml.h" #include "llama.h" -#include "llama-vocab.h" #include "llama-sampling.h" #ifdef NDEBUG @@ -22,7 +21,6 @@ static void dump(const llama_token_data_array * candidates) { static void test_top_k(const std::vector & probs, const std::vector & expected_probs, int k) { const size_t n_vocab = probs.size(); - llama_sampling smpl(llama_vocab::for_tests(n_vocab)); std::vector candidates; candidates.reserve(n_vocab); @@ -32,9 +30,9 @@ static void test_top_k(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector & expected_probs, float p) { const size_t n_vocab = probs.size(); - llama_sampling smpl(llama_vocab::for_tests(n_vocab)); std::vector candidates; candidates.reserve(n_vocab); @@ -55,9 +52,9 @@ static void test_top_p(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector & expected_probs, float z) { const size_t n_vocab = probs.size(); - llama_sampling smpl(llama_vocab::for_tests(n_vocab)); std::vector candidates; candidates.reserve(n_vocab); @@ -79,7 +75,7 @@ static void test_tfs(const std::vector & probs, const std::vector llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; DUMP(&candidates_p); - llama_sampling_tail_free_impl(smpl, &candidates_p, z, 1); + llama_sampling_tail_free_impl(&candidates_p, z, 1); DUMP(&candidates_p); GGML_ASSERT(candidates_p.size == expected_probs.size()); @@ -90,7 +86,6 @@ static void test_tfs(const std::vector & probs, const std::vector static void test_min_p(const std::vector & probs, const std::vector & expected_probs, float p) { const size_t n_vocab = probs.size(); - llama_sampling smpl(llama_vocab::for_tests(n_vocab)); std::vector candidates; candidates.reserve(n_vocab); @@ -101,9 +96,9 @@ static void test_min_p(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector & expected_probs, float p) { const size_t n_vocab = probs.size(); - llama_sampling smpl(llama_vocab::for_tests(n_vocab)); std::vector candidates; candidates.reserve(n_vocab); @@ -124,7 +118,7 @@ static void test_typical(const std::vector & probs, const std::vector candidates; candidates.reserve(n_vocab); @@ -150,10 +143,10 @@ static void test_repetition_penalties( } llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - llama_sampling_softmax_impl(smpl, &candidates_p); + llama_sampling_softmax_impl(&candidates_p); DUMP(&candidates_p); - llama_sampling_repetition_penalties_impl(smpl, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence); - llama_sampling_softmax_impl(smpl, &candidates_p); + llama_sampling_repetition_penalties_impl(&candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence); + llama_sampling_softmax_impl(&candidates_p); DUMP(&candidates_p); GGML_ASSERT(candidates_p.size == expected_probs.size()); @@ -164,8 +157,6 @@ static void test_repetition_penalties( static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p ) { - llama_sampling smpl(llama_vocab::for_tests(n_vocab)); - std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { @@ -180,16 +171,16 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler for (auto s : samplers_sequence) { switch (s){ - case 'k': llama_sampling_top_k_impl(smpl, &candidates_p, top_k, 1); break; + case 'k': llama_sampling_top_k_impl(&candidates_p, top_k, 1); break; case 'f': GGML_ABORT("tail_free test not implemented"); case 'y': GGML_ABORT("typical test not implemented"); - case 'p': llama_sampling_top_p_impl(smpl, &candidates_p, top_p, 1); break; - case 'm': llama_sampling_min_p_impl(smpl, &candidates_p, min_p, 1); break; + case 'p': llama_sampling_top_p_impl(&candidates_p, top_p, 1); break; + case 'm': llama_sampling_min_p_impl(&candidates_p, min_p, 1); break; case 't': GGML_ABORT("temperature test not implemented"); default : GGML_ABORT("Unknown sampler"); } - llama_sampling_softmax_impl(smpl, &candidates_p); // make sure tokens are sorted for tests + llama_sampling_softmax_impl(&candidates_p); // make sure tokens are sorted for tests const int size = candidates_p.size;