From 39fbaf9f5081ac3180316f6cc39d232af41e25ba Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 Jul 2024 16:56:20 +0300 Subject: [PATCH] llama : redirect external API to internal APIs ggml-ci --- common/sampling.cpp | 6 +- include/llama.h | 14 +- src/llama-grammar.cpp | 26 +- src/llama-grammar.h | 22 ++ src/llama-sampling.cpp | 210 ++++++++-------- src/llama-sampling.h | 41 +++- src/llama-vocab.cpp | 446 +++++----------------------------- src/llama-vocab.h | 52 ++++ src/llama.cpp | 540 +++++++++++++++++++++++++++++++++++++++++ 9 files changed, 838 insertions(+), 519 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 96d5fb0212cf6..079e405168dff 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -330,7 +330,7 @@ static llama_token llama_sampling_sample_impl( llama_token_data_array single_token_data_array = { &single_token_data, 1, false }; // Apply grammar constraints to the single token - llama_grammar_sample(ctx_main, &single_token_data_array, ctx_sampling->grammar); + llama_grammar_sample(ctx_sampling->grammar, ctx_main, &single_token_data_array); // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY bool is_valid = single_token_data_array.data[0].logit != -INFINITY; @@ -421,7 +421,7 @@ static llama_token_data_array llama_sampling_prepare_impl( // apply grammar checks before sampling logic if (apply_grammar && ctx_sampling->grammar != NULL) { - llama_grammar_sample(ctx_main, &cur_p, ctx_sampling->grammar); + llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p); } return cur_p; @@ -455,6 +455,6 @@ void llama_sampling_accept( ctx_sampling->prev.push_back(id); if (ctx_sampling->grammar != NULL && apply_grammar) { - llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id); + llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id); } } diff --git a/include/llama.h b/include/llama.h index fc5ddcfb15b80..e68cd807e63bd 100644 --- a/include/llama.h +++ b/include/llama.h @@ -965,6 +965,10 @@ extern "C" { bool remove_special, bool unparse_special); + // + // Chat templates + // + /// Apply chat template. Inspired by hf apply_chat_template() on python. /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model" /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template @@ -1005,10 +1009,10 @@ extern "C" { /// @details Apply constraints from grammar LLAMA_API void llama_grammar_sample( - struct llama_context * ctx, - llama_token_data_array * candidates, - const struct llama_grammar * grammar); - LLAMA_API DEPRECATED(bool llama_sample_grammar( + const struct llama_grammar * grammar, + const struct llama_context * ctx, + llama_token_data_array * candidates); + LLAMA_API DEPRECATED(void llama_sample_grammar( struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar), @@ -1016,8 +1020,8 @@ extern "C" { /// @details Accepts the sampled token into the grammar LLAMA_API void llama_grammar_accept_token( - struct llama_context * ctx, struct llama_grammar * grammar, + struct llama_context * ctx, llama_token token); // diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 6b57a5f206d9f..b724da1016e59 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -384,7 +384,7 @@ static bool llama_grammar_detect_left_recursion( // grammar - external // -struct llama_grammar * llama_grammar_init( +struct llama_grammar * llama_grammar_init_impl( const llama_grammar_element ** rules, size_t n_rules, size_t start_rule_index) { @@ -441,11 +441,11 @@ struct llama_grammar * llama_grammar_init( return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} }; } -void llama_grammar_free(struct llama_grammar * grammar) { +void llama_grammar_free_impl(struct llama_grammar * grammar) { delete grammar; } -struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) { +struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) { llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 }; // redirect elements in stacks to point to new rules @@ -464,8 +464,10 @@ struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) return result; } -void llama_grammar_sample(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) { - GGML_ASSERT(ctx); +void llama_grammar_sample(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) { + GGML_ASSERT(grammar); + GGML_ASSERT(vocab); + int64_t t_start_sample_us = ggml_time_us(); bool allow_eog = false; @@ -484,9 +486,9 @@ void llama_grammar_sample(struct llama_context * ctx, llama_token_data_array * c for (size_t i = 0; i < candidates->size; ++i) { const llama_token id = candidates->data[i].id; - const std::string & piece = llama_get_vocab(ctx)->cache_token_to_piece.at(id); + const std::string & piece = vocab->cache_token_to_piece.at(id); - if (llama_token_is_eog(llama_get_model(ctx), id)) { + if (llama_token_is_eog(*vocab, id)) { if (!allow_eog) { candidates->data[i].logit = -INFINITY; } @@ -503,13 +505,13 @@ void llama_grammar_sample(struct llama_context * ctx, llama_token_data_array * c candidates->data[reject.index].logit = -INFINITY; } - llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } -void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) { +void llama_grammar_accept_token(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) { const int64_t t_start_sample_us = ggml_time_us(); - if (llama_token_is_eog(llama_get_model(ctx), token)) { + if (llama_token_is_eog(*vocab, token)) { for (const auto & stack : grammar->stacks) { if (stack.empty()) { return; @@ -518,7 +520,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar GGML_ASSERT(false); } - const std::string & piece = llama_get_vocab(ctx)->cache_token_to_piece.at(token); + const std::string & piece = vocab->cache_token_to_piece.at(token); // Note terminating 0 in decoded string const auto decoded = decode_utf8(piece, grammar->partial_utf8); @@ -533,5 +535,5 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar grammar->partial_utf8 = decoded.second; GGML_ASSERT(!grammar->stacks.empty()); - llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } diff --git a/src/llama-grammar.h b/src/llama-grammar.h index d923076384e77..64e04509dae02 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -3,6 +3,7 @@ #include "llama-impl.h" struct llama_vocab; +struct llama_sampling; struct llama_grammar { const llama_grammar_rules rules; @@ -13,3 +14,24 @@ struct llama_grammar { }; struct llama_grammar * llama_get_grammar(struct llama_context * ctx); + +struct llama_grammar * llama_grammar_init_impl( + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index); + +void llama_grammar_free_impl(struct llama_grammar * grammar); + +struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar); + +void llama_grammar_sample( + const struct llama_grammar * grammar, + const struct llama_vocab * vocab, + const struct llama_sampling * smpl, + llama_token_data_array * candidates); + +void llama_grammar_accept_token( + struct llama_grammar * grammar, + const struct llama_vocab * vocab, + const struct llama_sampling * smpl, + llama_token token); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index d31d47f2eaeb0..d2651885401a3 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -7,15 +7,29 @@ #include #include -void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { +static void llama_log_softmax(float * array, size_t size) { + float max_l = *std::max_element(array, array + size); + float sum = 0.f; + for (size_t i = 0; i < size; ++i) { + float p = expf(array[i] - max_l); + sum += p; + array[i] = p; + } + + for (size_t i = 0; i < size; ++i) { + array[i] = logf(array[i] / sum); + } +} + +void llama_set_rng_seed(struct llama_sampling * smpl, uint32_t seed) { if (seed == LLAMA_DEFAULT_SEED) { seed = time(NULL); } - llama_get_sampling(ctx)->rng.seed(seed); + smpl->rng.seed(seed); } -void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { +void llama_sample_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) { GGML_ASSERT(candidates->size > 0); const int64_t t_start_sample_us = ggml_time_us(); @@ -39,12 +53,12 @@ void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * c candidates->data[i].p /= cum_sum; } - if (ctx) { - llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } } -void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) { +void llama_sample_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) { // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast // if (k >= (int32_t)candidates->size) { // return; @@ -120,17 +134,17 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can } candidates->size = k; - if (ctx) { - llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } } -void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_sample_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { if (p >= 1.0f) { return; } - llama_sample_softmax(ctx, candidates); + llama_sample_softmax(smpl, candidates); const int64_t t_start_sample_us = ggml_time_us(); @@ -152,12 +166,12 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can // Resize the output vector to keep only the top-p tokens candidates->size = last_idx; - if (ctx) { - llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } } -void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_sample_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { if (p <= 0.0f || !candidates->size) { return; } @@ -213,17 +227,17 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can candidates->size = i; } - if (ctx) { - llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } } -void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { +void llama_sample_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) { if (z >= 1.0f || candidates->size <= 2) { return; } - llama_sample_softmax(nullptr, candidates); + llama_sample_softmax((struct llama_sampling *) nullptr, candidates); const int64_t t_start_sample_us = ggml_time_us(); // Compute the first and second derivatives @@ -272,12 +286,12 @@ void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * // Resize the output vector to keep only the tokens above the tail location candidates->size = last_idx; - if (ctx) { - llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } } -void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_sample_typical(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { // Reference implementation: // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr if (p >= 1.0f) { @@ -285,7 +299,7 @@ void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * c } // Compute the softmax of logits and calculate entropy - llama_sample_softmax(nullptr, candidates); + llama_sample_softmax((struct llama_sampling *) nullptr, candidates); const int64_t t_start_sample_us = ggml_time_us(); @@ -336,34 +350,34 @@ void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * c candidates->size = new_candidates.size(); candidates->sorted = false; - if (ctx) { - llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } } -void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) { +void llama_sample_entropy(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) { const int64_t t_start_sample_us = ggml_time_us(); // no need to do anything if there is only one (or zero) candidates - if(candidates_p->size <= 1) { + if(candidates->size <= 1) { return; } // Calculate maximum possible entropy - float max_entropy = -logf(1.0f / candidates_p->size); + float max_entropy = -logf(1.0f / candidates->size); - llama_sample_softmax(nullptr, candidates_p); + llama_sample_softmax((struct llama_sampling *) nullptr, candidates); // Calculate entropy of the softmax probabilities float entropy = 0.0f; - for (size_t i = 0; i < candidates_p->size; ++i) { - float prob = candidates_p->data[i].p; + for (size_t i = 0; i < candidates->size; ++i) { + float prob = candidates->data[i].p; if (prob > 0.0f) { // Ensure no log(0) entropy -= prob * logf(prob); } } - // Normalize the entropy (max_entropy cannot be 0 here because we checked candidates_p->size != 1 above) + // Normalize the entropy (max_entropy cannot be 0 here because we checked candidates->size != 1 above) float normalized_entropy = entropy / max_entropy; // Map the normalized entropy to the desired temperature range using the power function @@ -379,55 +393,55 @@ void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * c #endif // Apply the dynamically calculated temperature scaling - for (size_t i = 0; i < candidates_p->size; ++i) { - candidates_p->data[i].logit /= dyn_temp; + for (size_t i = 0; i < candidates->size; ++i) { + candidates->data[i].logit /= dyn_temp; } // Re-compute softmax probabilities after scaling logits with dynamic temperature - double max_l_double = candidates_p->data[0].logit; + double max_l_double = candidates->data[0].logit; double cum_sum_double = 0.0; - for (size_t i = 0; i < candidates_p->size; ++i) { - double p = exp(candidates_p->data[i].logit - max_l_double); - candidates_p->data[i].p = p; // Store the scaled probability + for (size_t i = 0; i < candidates->size; ++i) { + double p = exp(candidates->data[i].logit - max_l_double); + candidates->data[i].p = p; // Store the scaled probability cum_sum_double += p; } - for (size_t i = 0; i < candidates_p->size; ++i) { - candidates_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities + for (size_t i = 0; i < candidates->size; ++i) { + candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities } #ifdef DEBUG // Print the updated top 25 probabilities after temperature scaling LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n"); - for (size_t i = 0; i < 25 && i < candidates_p->size; ++i) { - LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates_p->data[i].p * 100.0f); + for (size_t i = 0; i < 25 && i < candidates->size; ++i) { + LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f); } #endif - if (ctx) { - llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } } -void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { +void llama_sample_temp(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) { const int64_t t_start_sample_us = ggml_time_us(); - for (size_t i = 0; i < candidates_p->size; ++i) { - candidates_p->data[i].logit /= temp; + for (size_t i = 0; i < candidates->size; ++i) { + candidates->data[i].logit /= temp; } - if (ctx) { - llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } } void llama_sample_repetition_penalties( - struct llama_context * ctx, - llama_token_data_array * candidates, - const llama_token * last_tokens, - size_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present) { + struct llama_sampling * smpl, + llama_token_data_array * candidates, + const llama_token * last_tokens, + size_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present) { if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) { return; } @@ -462,34 +476,20 @@ void llama_sample_repetition_penalties( candidates->sorted = false; - if (ctx) { - llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - -static void llama_log_softmax(float * array, size_t size) { - float max_l = *std::max_element(array, array + size); - float sum = 0.f; - for (size_t i = 0; i < size; ++i) { - float p = expf(array[i] - max_l); - sum += p; - array[i] = p; - } - - for (size_t i = 0; i < size; ++i) { - array[i] = logf(array[i] / sum); + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } } void llama_sample_apply_guidance( - struct llama_context * ctx, - float * logits, - float * logits_guidance, - float scale) { - GGML_ASSERT(ctx); + struct llama_sampling * smpl, + float * logits, + float * logits_guidance, + float scale) { + GGML_ASSERT(smpl); const auto t_start_sample_us = ggml_time_us(); - const auto n_vocab = llama_get_sampling(ctx)->n_vocab; + const auto n_vocab = smpl->n_vocab; llama_log_softmax(logits, n_vocab); llama_log_softmax(logits_guidance, n_vocab); @@ -501,17 +501,17 @@ void llama_sample_apply_guidance( l = scale * (l - g) + g; } - llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } -llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { - GGML_ASSERT(ctx); +llama_token llama_sample_token_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { + GGML_ASSERT(smpl); - const int32_t n_vocab = float(llama_get_sampling(ctx)->n_vocab); + const int32_t n_vocab = float(smpl->n_vocab); int64_t t_start_sample_us = ggml_time_us(); - llama_sample_softmax(nullptr, candidates); + llama_sample_softmax((struct llama_sampling *) nullptr, candidates); // Estimate s_hat using the most probable m tokens float s_hat = 0.0; @@ -530,9 +530,9 @@ llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_ float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); // Sample the next word X using top-k sampling - llama_sample_top_k(nullptr, candidates, int(k), 1); - llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; - llama_token X = llama_sample_token(ctx, candidates); + llama_sample_top_k((struct llama_sampling *) nullptr, candidates, int(k), 1); + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + llama_token X = llama_sample_token(smpl, candidates); t_start_sample_us = ggml_time_us(); // Compute error as the difference between observed surprise and target surprise value @@ -545,15 +545,15 @@ llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_ // Update mu using the learning rate and error *mu = *mu - eta * e; - llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; return X; } -llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) { +llama_token llama_sample_token_mirostat_v2(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) { int64_t t_start_sample_us; t_start_sample_us = ggml_time_us(); - llama_sample_softmax(ctx, candidates); + llama_sample_softmax(smpl, candidates); // Truncate the words with surprise values greater than mu candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { @@ -564,15 +564,15 @@ llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_tok candidates->size = 1; } - if (ctx) { - llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } // Normalize the probabilities of the remaining words - llama_sample_softmax(ctx, candidates); + llama_sample_softmax(smpl, candidates); // Sample the next word X from the remaining words - llama_token X = llama_sample_token(ctx, candidates); + llama_token X = llama_sample_token(smpl, candidates); t_start_sample_us = ggml_time_us(); // Compute error as the difference between observed surprise and target surprise value @@ -585,13 +585,13 @@ llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_tok // Update mu using the learning rate and error *mu = *mu - eta * e; - if (ctx) { - llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } return X; } -llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) { +llama_token llama_sample_token_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) { const int64_t t_start_sample_us = ggml_time_us(); // Find max element @@ -600,18 +600,18 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da }); llama_token result = max_iter->id; - if (ctx) { - llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; - llama_get_sampling(ctx)->n_sample++; + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + smpl->n_sample++; } return result; } -llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) { - GGML_ASSERT(ctx); +llama_token llama_sample_token_with_rng(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) { + GGML_ASSERT(smpl); const int64_t t_start_sample_us = ggml_time_us(); - llama_sample_softmax(nullptr, candidates); + llama_sample_softmax((struct llama_sampling *) nullptr, candidates); std::vector probs; probs.reserve(candidates->size); @@ -624,12 +624,12 @@ llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_ llama_token result = candidates->data[idx].id; - llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; - llama_get_sampling(ctx)->n_sample++; + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + smpl->n_sample++; + return result; } -llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { - return llama_sample_token_with_rng(ctx, candidates, llama_get_sampling(ctx)->rng); +llama_token llama_sample_token(struct llama_sampling * smpl, llama_token_data_array * candidates) { + return llama_sample_token_with_rng(smpl, candidates, smpl->rng); } - diff --git a/src/llama-sampling.h b/src/llama-sampling.h index cf05e357e363c..d935c20227697 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -7,15 +7,48 @@ struct llama_sampling { std::mt19937 rng; - int64_t t_sample_us = 0; - - int32_t n_sample = 0; int32_t n_vocab = 0; - void reset_timings() { + mutable int64_t t_sample_us = 0; + mutable int32_t n_sample = 0; + + void reset_timings() const { t_sample_us = 0; n_sample = 0; } }; struct llama_sampling * llama_get_sampling(struct llama_context * ctx); + +void llama_set_rng_seed(struct llama_sampling * smpl, uint32_t seed); + +void llama_sample_softmax (struct llama_sampling * smpl, llama_token_data_array * candidates); +void llama_sample_top_k (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep); +void llama_sample_top_p (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sample_min_p (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sample_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep); +void llama_sample_typical (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sample_entropy (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val); +void llama_sample_temp (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp); + +void llama_sample_repetition_penalties( + struct llama_sampling * smpl, + llama_token_data_array * candidates, + const llama_token * last_tokens, + size_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present); + +void llama_sample_apply_guidance( + struct llama_sampling * smpl, + float * logits, + float * logits_guidance, + float scale); + +llama_token llama_sample_token_mirostat (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu); +llama_token llama_sample_token_mirostat_v2(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu); +llama_token llama_sample_token_greedy (struct llama_sampling * smpl, llama_token_data_array * candidates); +llama_token llama_sample_token_with_rng (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng); +llama_token llama_sample_token (struct llama_sampling * smpl, llama_token_data_array * candidates); + diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 443e26ef651d2..592ed06405381 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -12,29 +12,10 @@ #include #include -#if __cplusplus >= 202000L - #define LU8(x) (const char*)(u8##x) -#else - #define LU8(x) u8##x -#endif - // // helpers // -// trim whitespace from the beginning and end of a string -static std::string trim(const std::string & str) { - size_t start = 0; - size_t end = str.size(); - while (start < end && isspace(str[start])) { - start += 1; - } - while (end > start && isspace(str[end - 1])) { - end -= 1; - } - return str.substr(start, end - start); -} - static void replace_all(std::string & s, const std::string & search, const std::string & replace) { std::string result; for (size_t pos = 0; ; pos += search.length()) { @@ -1445,106 +1426,89 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, return output; } -const char * llama_token_get_text(const struct llama_model * model, llama_token token) { - const struct llama_vocab * vocab = llama_get_vocab(model); - GGML_ASSERT(vocab->type != LLAMA_VOCAB_TYPE_NONE); - return vocab->id_to_token[token].text.c_str(); +const char * llama_token_get_text(const struct llama_vocab & vocab, llama_token token) { + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); + return vocab.id_to_token[token].text.c_str(); } -float llama_token_get_score(const struct llama_model * model, llama_token token) { - const struct llama_vocab * vocab = llama_get_vocab(model); - GGML_ASSERT(vocab->type != LLAMA_VOCAB_TYPE_NONE); - return vocab->id_to_token[token].score; +float llama_token_get_score(const struct llama_vocab & vocab, llama_token token) { + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); + return vocab.id_to_token[token].score; } -llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) { - const struct llama_vocab * vocab = llama_get_vocab(model); - GGML_ASSERT(vocab->type != LLAMA_VOCAB_TYPE_NONE); - return vocab->id_to_token[token].attr; +llama_token_attr llama_token_get_attr(const struct llama_vocab & vocab, llama_token token) { + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); + return vocab.id_to_token[token].attr; } -bool llama_token_is_eog(const struct llama_model * model, llama_token token) { +bool llama_token_is_eog(const struct llama_vocab & vocab, llama_token token) { return token != -1 && ( - token == llama_token_eos(model) || - token == llama_token_eot(model) + token == llama_token_eos(vocab) || + token == llama_token_eot(vocab) ); } -bool llama_token_is_control(const struct llama_model * model, llama_token token) { - const struct llama_vocab * vocab = llama_get_vocab(model); - return llama_is_control_token(*vocab, token); +bool llama_token_is_control(const struct llama_vocab & vocab, llama_token token) { + return llama_is_control_token(vocab, token); } -llama_token llama_token_bos(const struct llama_model * model) { - const struct llama_vocab * vocab = llama_get_vocab(model); - return vocab->special_bos_id; +llama_token llama_token_bos(const struct llama_vocab & vocab) { + return vocab.special_bos_id; } -llama_token llama_token_eos(const struct llama_model * model) { - const struct llama_vocab * vocab = llama_get_vocab(model); - return vocab->special_eos_id; +llama_token llama_token_eos(const struct llama_vocab & vocab) { + return vocab.special_eos_id; } -llama_token llama_token_cls(const struct llama_model * model) { - const struct llama_vocab * vocab = llama_get_vocab(model); - return vocab->special_cls_id; +llama_token llama_token_cls(const struct llama_vocab & vocab) { + return vocab.special_cls_id; } -llama_token llama_token_sep(const struct llama_model * model) { - const struct llama_vocab * vocab = llama_get_vocab(model); - return vocab->special_sep_id; +llama_token llama_token_sep(const struct llama_vocab & vocab) { + return vocab.special_sep_id; } -llama_token llama_token_nl(const struct llama_model * model) { - const struct llama_vocab * vocab = llama_get_vocab(model); - return vocab->linefeed_id; +llama_token llama_token_nl(const struct llama_vocab & vocab) { + return vocab.linefeed_id; } -int32_t llama_add_bos_token(const struct llama_model * model) { - const struct llama_vocab * vocab = llama_get_vocab(model); - return vocab->tokenizer_add_bos; +llama_token llama_token_pad(const struct llama_vocab & vocab) { + return vocab.special_pad_id; } -int32_t llama_add_eos_token(const struct llama_model * model) { - const struct llama_vocab * vocab = llama_get_vocab(model); - return vocab->tokenizer_add_eos; +int32_t llama_add_bos_token(const struct llama_vocab & vocab) { + return vocab.tokenizer_add_bos; } -llama_token llama_token_prefix(const struct llama_model * model) { - const struct llama_vocab * vocab = llama_get_vocab(model); - return vocab->special_prefix_id; +int32_t llama_add_eos_token(const struct llama_vocab & vocab) { + return vocab.tokenizer_add_eos; } -llama_token llama_token_middle(const struct llama_model * model) { - const struct llama_vocab * vocab = llama_get_vocab(model); - return vocab->special_middle_id; +llama_token llama_token_prefix(const struct llama_vocab & vocab) { + return vocab.special_prefix_id; } -llama_token llama_token_suffix(const struct llama_model * model) { - const struct llama_vocab * vocab = llama_get_vocab(model); - return vocab->special_suffix_id; +llama_token llama_token_middle(const struct llama_vocab & vocab) { + return vocab.special_middle_id; } -llama_token llama_token_eot(const struct llama_model * model) { - const struct llama_vocab * vocab = llama_get_vocab(model); - return vocab->special_eot_id; +llama_token llama_token_suffix(const struct llama_vocab & vocab) { + return vocab.special_suffix_id; } -llama_token llama_token_pad(const struct llama_model * model) { - const struct llama_vocab * vocab = llama_get_vocab(model); - return vocab->special_pad_id; +llama_token llama_token_eot(const struct llama_vocab & vocab) { + return vocab.special_eot_id; } int32_t llama_tokenize( - const struct llama_model * model, + const struct llama_vocab & vocab, const char * text, int32_t text_len, llama_token * tokens, int32_t n_tokens_max, bool add_special, bool parse_special) { - const struct llama_vocab * vocab = llama_get_vocab(model); - auto res = llama_tokenize_internal(*vocab, std::string(text, text_len), add_special, parse_special); + auto res = llama_tokenize_internal(vocab, std::string(text, text_len), add_special, parse_special); if (n_tokens_max < (int) res.size()) { // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); return -((int) res.size()); @@ -1578,10 +1542,10 @@ static std::string llama_decode_text(const std::string & text) { } // does not write null-terminator to buf -int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) { +int32_t llama_token_to_piece(const struct llama_vocab & vocab, llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) { // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843 static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL; - const llama_token_attr attr = llama_token_get_attr(model, token); + const llama_token_attr attr = llama_token_get_attr(vocab, token); if (!special && (attr & attr_special)) { return 0; } @@ -1600,11 +1564,9 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token return (int32_t) size; }; - const struct llama_vocab * vocab = llama_get_vocab(model); - // if we have a cache - use it { - const auto & cache = vocab->cache_token_to_piece; + const auto & cache = vocab.cache_token_to_piece; if (!cache.empty()) { const auto & result = cache.at(token); @@ -1612,9 +1574,9 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token } } - if (0 <= token && token < llama_n_vocab(model)) { - const std::string & token_text = vocab->id_to_token[token].text; - switch (llama_vocab_get_type(*vocab)) { + if (0 <= token && token < (int32_t) vocab.id_to_token.size()) { + const std::string & token_text = vocab.id_to_token[token].text; + switch (llama_vocab_get_type(vocab)) { case LLAMA_VOCAB_TYPE_WPM: case LLAMA_VOCAB_TYPE_SPM: case LLAMA_VOCAB_TYPE_UGM: { @@ -1627,7 +1589,7 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token llama_unescape_whitespace(result); return _try_copy(result.data(), result.size()); } else if (attr & LLAMA_TOKEN_ATTR_BYTE) { - char byte = (char) llama_token_to_byte(*vocab, token); + char byte = (char) llama_token_to_byte(vocab, token); return _try_copy((char*) &byte, 1); } break; @@ -1647,11 +1609,12 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token GGML_ASSERT(false); } } + return 0; } int32_t llama_detokenize( - const struct llama_model * model, + const struct llama_vocab & vocab, const llama_token * tokens, int32_t n_tokens, char * text, @@ -1661,28 +1624,26 @@ int32_t llama_detokenize( int32_t avail = text_len_max; int32_t total = 0; - const struct llama_vocab * vocab = llama_get_vocab(model); - // remove the leading space - bool remove_space = vocab->tokenizer_add_space_prefix; + bool remove_space = vocab.tokenizer_add_space_prefix; - if (remove_special && vocab->tokenizer_add_bos) { - if (n_tokens > 0 && tokens[0] == vocab->special_bos_id) { + if (remove_special && vocab.tokenizer_add_bos) { + if (n_tokens > 0 && tokens[0] == vocab.special_bos_id) { remove_space = false; n_tokens--; tokens++; } } - if (remove_special && vocab->tokenizer_add_eos) { - if (n_tokens > 0 && tokens[n_tokens-1] == vocab->special_eos_id) { + if (remove_special && vocab.tokenizer_add_eos) { + if (n_tokens > 0 && tokens[n_tokens-1] == vocab.special_eos_id) { n_tokens--; } } for (int32_t i = 0; i < n_tokens; ++i) { GGML_ASSERT(avail >= 0); - int32_t n_chars = llama_token_to_piece(model, tokens[i], text, avail, remove_space, unparse_special); + int32_t n_chars = llama_token_to_piece(vocab, tokens[i], text, avail, remove_space, unparse_special); remove_space = false; if (n_chars < 0) { avail = 0; @@ -1698,7 +1659,7 @@ int32_t llama_detokenize( return -total; } - if (vocab->tokenizer_clean_spaces) { + if (vocab.tokenizer_clean_spaces) { text -= total; // restart text // first pass: characters ?!., //TODO: where do these characters come from? @@ -1758,298 +1719,3 @@ int32_t llama_detokenize( return total <= text_len_max ? total : -total; } - -// -// chat templates -// - -// Simple version of "llama_apply_chat_template" that only works with strings -// This function uses heuristic checks to determine commonly used template. It is not a jinja parser. -static int32_t llama_chat_apply_template_internal( - const std::string & tmpl, - const std::vector & chat, - std::string & dest, bool add_ass) { - // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527 - std::stringstream ss; - auto tmpl_contains = [&tmpl](std::string haystack) -> bool { - return tmpl.find(haystack) != std::string::npos; - }; - if (tmpl == "chatml" || tmpl_contains("<|im_start|>")) { - // chatml template - for (auto message : chat) { - ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n"; - } - if (add_ass) { - ss << "<|im_start|>assistant\n"; - } - } else if (tmpl == "llama2" || tmpl == "mistral" || tmpl_contains("[INST]")) { - // llama2 template and its variants - // [variant] support system message - bool support_system_message = tmpl_contains("<>") || tmpl == "mistral"; - // [variant] space before + after response - bool space_around_response = tmpl_contains("' ' + eos_token"); - // [variant] add BOS inside history - bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]"); - // [variant] trim spaces from the input message - bool strip_message = tmpl_contains("content.strip()"); - // construct the prompt - bool is_inside_turn = true; // skip BOS at the beginning - ss << "[INST] "; - for (auto message : chat) { - std::string content = strip_message ? trim(message->content) : message->content; - std::string role(message->role); - if (!is_inside_turn) { - is_inside_turn = true; - ss << (add_bos_inside_history ? "[INST] " : "[INST] "); - } - if (role == "system") { - if (support_system_message) { - ss << "<>\n" << content << "\n<>\n\n"; - } else { - // if the model does not support system message, we still include it in the first message, but without <> - ss << content << "\n"; - } - } else if (role == "user") { - ss << content << " [/INST]"; - } else { - ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << ""; - is_inside_turn = false; - } - } - // llama2 templates seem to not care about "add_generation_prompt" - } else if (tmpl == "phi3" || (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>"))) { - // Phi 3 - for (auto message : chat) { - std::string role(message->role); - ss << "<|" << role << "|>\n" << message->content << "<|end|>\n"; - } - if (add_ass) { - ss << "<|assistant|>\n"; - } - } else if (tmpl == "zephyr" || tmpl_contains("<|user|>")) { - // zephyr template - for (auto message : chat) { - ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n"; - } - if (add_ass) { - ss << "<|assistant|>\n"; - } - } else if (tmpl == "monarch" || tmpl_contains("bos_token + message['role']")) { - // mlabonne/AlphaMonarch-7B template (the is included inside history) - for (auto message : chat) { - std::string bos = (message == chat.front()) ? "" : ""; // skip BOS for first message - ss << bos << message->role << "\n" << message->content << "\n"; - } - if (add_ass) { - ss << "assistant\n"; - } - } else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl_contains("")) { - // google/gemma-7b-it - std::string system_prompt = ""; - for (auto message : chat) { - std::string role(message->role); - if (role == "system") { - // there is no system message for gemma, but we will merge it with user prompt, so nothing is broken - system_prompt = trim(message->content); - continue; - } - // in gemma, "assistant" is "model" - role = role == "assistant" ? "model" : message->role; - ss << "" << role << "\n"; - if (!system_prompt.empty() && role != "model") { - ss << system_prompt << "\n\n"; - system_prompt = ""; - } - ss << trim(message->content) << "\n"; - } - if (add_ass) { - ss << "model\n"; - } - } else if (tmpl == "orion" || tmpl_contains("'\\n\\nAssistant: ' + eos_token")) { - // OrionStarAI/Orion-14B-Chat - std::string system_prompt = ""; - for (auto message : chat) { - std::string role(message->role); - if (role == "system") { - // there is no system message support, we will merge it with user prompt - system_prompt = message->content; - continue; - } else if (role == "user") { - ss << "Human: "; - if (!system_prompt.empty()) { - ss << system_prompt << "\n\n"; - system_prompt = ""; - } - ss << message->content << "\n\nAssistant: "; - } else { - ss << message->content << ""; - } - } - } else if (tmpl == "openchat" || tmpl_contains("GPT4 Correct ")) { - // openchat/openchat-3.5-0106, - for (auto message : chat) { - std::string role(message->role); - if (role == "system") { - ss << message->content << "<|end_of_turn|>"; - } else { - role[0] = toupper(role[0]); - ss << "GPT4 Correct " << role << ": " << message->content << "<|end_of_turn|>"; - } - } - if (add_ass) { - ss << "GPT4 Correct Assistant:"; - } - } else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: "))) { - // eachadea/vicuna-13b-1.1 (and Orca variant) - for (auto message : chat) { - std::string role(message->role); - if (role == "system") { - // Orca-Vicuna variant uses a system prefix - if (tmpl == "vicuna-orca" || tmpl_contains("SYSTEM: ")) { - ss << "SYSTEM: " << message->content << "\n"; - } else { - ss << message->content << "\n\n"; - } - } else if (role == "user") { - ss << "USER: " << message->content << "\n"; - } else if (role == "assistant") { - ss << "ASSISTANT: " << message->content << "\n"; - } - } - if (add_ass) { - ss << "ASSISTANT:"; - } - } else if (tmpl == "deepseek" || (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>"))) { - // deepseek-ai/deepseek-coder-33b-instruct - for (auto message : chat) { - std::string role(message->role); - if (role == "system") { - ss << message->content; - } else if (role == "user") { - ss << "### Instruction:\n" << message->content << "\n"; - } else if (role == "assistant") { - ss << "### Response:\n" << message->content << "\n<|EOT|>\n"; - } - } - if (add_ass) { - ss << "### Response:\n"; - } - } else if (tmpl == "command-r" || (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>"))) { - // CohereForAI/c4ai-command-r-plus - for (auto message : chat) { - std::string role(message->role); - if (role == "system") { - ss << "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>"; - } else if (role == "user") { - ss << "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>"; - } else if (role == "assistant") { - ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>"; - } - } - if (add_ass) { - ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"; - } - } else if (tmpl == "llama3" || (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>"))) { - // Llama 3 - for (auto message : chat) { - std::string role(message->role); - ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n" << trim(message->content) << "<|eot_id|>"; - } - if (add_ass) { - ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; - } - } else if (tmpl == "chatglm3" || tmpl_contains("[gMASK]sop")) { - // chatglm3-6b - ss << "[gMASK]" << "sop"; - for (auto message : chat) { - std::string role(message->role); - ss << "<|" << role << "|>" << "\n " << message->content; - } - if (add_ass) { - ss << "<|assistant|>"; - } - } else if (tmpl == "chatglm4" || tmpl_contains("[gMASK]")) { - ss << "[gMASK]" << ""; - for (auto message : chat) { - std::string role(message->role); - ss << "<|" << role << "|>" << "\n" << message->content; - } - if (add_ass) { - ss << "<|assistant|>"; - } - } else if (tmpl == "minicpm" || tmpl_contains(LU8("<用户>"))) { - // MiniCPM-3B-OpenHermes-2.5-v2-GGUF - for (auto message : chat) { - std::string role(message->role); - if (role == "user") { - ss << LU8("<用户>"); - ss << trim(message->content); - ss << ""; - } else { - ss << trim(message->content); - } - } - } else if (tmpl == "deepseek2" || tmpl_contains("'Assistant: ' + message['content'] + eos_token")) { - // DeepSeek-V2 - for (auto message : chat) { - std::string role(message->role); - if (role == "system") { - ss << message->content << "\n\n"; - } else if (role == "user") { - ss << "User: " << message->content << "\n\n"; - } else if (role == "assistant") { - ss << "Assistant: " << message->content << LU8("<|end▁of▁sentence|>"); - } - } - if (add_ass) { - ss << "Assistant:"; - } - } else { - // template not supported - return -1; - } - dest = ss.str(); - return dest.size(); -} - -int32_t llama_chat_apply_template( - const struct llama_model * model, - const char * tmpl, - const struct llama_chat_message * chat, - size_t n_msg, - bool add_ass, - char * buf, - int32_t length) { - std::string curr_tmpl(tmpl == nullptr ? "" : tmpl); - if (tmpl == nullptr) { - GGML_ASSERT(model != nullptr); - // load template from model - std::vector model_template(2048, 0); // longest known template is about 1200 bytes - std::string template_key = "tokenizer.chat_template"; - int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size()); - if (res < 0) { - // worst case: there is no information about template, we will use chatml by default - curr_tmpl = "chatml"; // see llama_chat_apply_template_internal - } else { - curr_tmpl = std::string(model_template.data(), model_template.size()); - } - } - - // format the chat to string - std::vector chat_vec; - chat_vec.resize(n_msg); - for (size_t i = 0; i < n_msg; i++) { - chat_vec[i] = &chat[i]; - } - - std::string formatted_chat; - int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass); - if (res < 0) { - return res; - } - if (buf && length > 0) { - strncpy(buf, formatted_chat.c_str(), length); - } - return res; -} - diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 99b917fbfd639..873021125d1aa 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -72,3 +72,55 @@ std::vector llama_tokenize_internal( bool parse_special = false); llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch); + +const char * llama_token_get_text(const struct llama_vocab & vocab, llama_token token); + +float llama_token_get_score(const struct llama_vocab & vocab, llama_token token); + +llama_token_attr llama_token_get_attr(const struct llama_vocab & vocab, llama_token token); + +bool llama_token_is_eog(const struct llama_vocab & vocab, llama_token token); + +bool llama_token_is_control(const struct llama_vocab & vocab, llama_token token); + +llama_token llama_token_bos(const struct llama_vocab & vocab); +llama_token llama_token_eos(const struct llama_vocab & vocab); +llama_token llama_token_cls(const struct llama_vocab & vocab); +llama_token llama_token_sep(const struct llama_vocab & vocab); +llama_token llama_token_nl (const struct llama_vocab & vocab); +llama_token llama_token_pad(const struct llama_vocab & vocab); + +int32_t llama_add_bos_token(const struct llama_vocab & vocab); +int32_t llama_add_eos_token(const struct llama_vocab & vocab); + +llama_token llama_token_prefix(const struct llama_vocab & vocab); +llama_token llama_token_middle(const struct llama_vocab & vocab); +llama_token llama_token_suffix(const struct llama_vocab & vocab); +llama_token llama_token_eot (const struct llama_vocab & vocab); + +int32_t llama_tokenize( + const struct llama_vocab & vocab, + const char * text, + int32_t text_len, + llama_token * tokens, + int32_t n_tokens_max, + bool add_special, + bool parse_special); + +// does not write null-terminator to buf +int32_t llama_token_to_piece( + const struct llama_vocab & vocab, + llama_token token, + char * buf, + int32_t length, + int32_t lstrip, + bool special); + +int32_t llama_detokenize( + const struct llama_vocab & vocab, + const llama_token * tokens, + int32_t n_tokens, + char * text, + int32_t text_len_max, + bool remove_special, + bool unparse_special); diff --git a/src/llama.cpp b/src/llama.cpp index 9737911fdbfa2..fac782351d168 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -48,6 +48,12 @@ #include #endif +#if __cplusplus >= 202000L + #define LU8(x) (const char*)(u8##x) +#else + #define LU8(x) u8##x +#endif + #include #include #include @@ -85,6 +91,19 @@ // helpers // +// trim whitespace from the beginning and end of a string +static std::string trim(const std::string & str) { + size_t start = 0; + size_t end = str.size(); + while (start < end && isspace(str[start])) { + start += 1; + } + while (end > start && isspace(str[end - 1])) { + end -= 1; + } + return str.substr(start, end - start); +} + static void replace_all(std::string & s, const std::string & search, const std::string & replace) { std::string result; for (size_t pos = 0; ; pos += search.length()) { @@ -18487,6 +18506,527 @@ float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id return it->second.data(); } +// +// vocab +// + +const char * llama_token_get_text(const struct llama_model * model, llama_token token) { + return llama_token_get_text(model->vocab, token); +} + +float llama_token_get_score(const struct llama_model * model, llama_token token) { + return llama_token_get_score(model->vocab, token); +} + +enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) { + return llama_token_get_attr(model->vocab, token); +} + +bool llama_token_is_eog(const struct llama_model * model, llama_token token) { + return llama_token_is_eog(model->vocab, token); +} + +bool llama_token_is_control(const struct llama_model * model, llama_token token) { + return llama_token_is_control(model->vocab, token); +} + +llama_token llama_token_bos(const struct llama_model * model) { + return llama_token_bos(model->vocab); +} + +llama_token llama_token_eos(const struct llama_model * model) { + return llama_token_eos(model->vocab); +} + +llama_token llama_token_cls(const struct llama_model * model) { + return llama_token_cls(model->vocab); +} + +llama_token llama_token_sep(const struct llama_model * model) { + return llama_token_sep(model->vocab); +} + +llama_token llama_token_nl (const struct llama_model * model) { + return llama_token_nl (model->vocab); +} + +llama_token llama_token_pad(const struct llama_model * model) { + return llama_token_pad(model->vocab); +} + +int32_t llama_add_bos_token(const struct llama_model * model) { + return llama_add_bos_token(model->vocab); +} + +int32_t llama_add_eos_token(const struct llama_model * model) { + return llama_add_eos_token(model->vocab); +} + +llama_token llama_token_prefix(const struct llama_model * model) { + return llama_token_prefix(model->vocab); +} + +llama_token llama_token_middle(const struct llama_model * model) { + return llama_token_middle(model->vocab); +} + +llama_token llama_token_suffix(const struct llama_model * model) { + return llama_token_suffix(model->vocab); +} + +llama_token llama_token_eot(const struct llama_model * model) { + return llama_token_eot(model->vocab); +} + +// +// tokenization +// + +int32_t llama_tokenize( + const struct llama_model * model, + const char * text, + int32_t text_len, + llama_token * tokens, + int32_t n_tokens_max, + bool add_special, + bool parse_special) { + return llama_tokenize(model->vocab, text, text_len, tokens, n_tokens_max, add_special, parse_special); +} + +int32_t llama_token_to_piece( + const struct llama_model * model, + llama_token token, + char * buf, + int32_t length, + int32_t lstrip, + bool special) { + return llama_token_to_piece(model->vocab, token, buf, length, lstrip, special); +} + +int32_t llama_detokenize( + const struct llama_model * model, + const llama_token * tokens, + int32_t n_tokens, + char * text, + int32_t text_len_max, + bool remove_special, + bool unparse_special) { + return llama_detokenize(model->vocab, tokens, n_tokens, text, text_len_max, remove_special, unparse_special); +} + +// +// chat templates +// + +// Simple version of "llama_apply_chat_template" that only works with strings +// This function uses heuristic checks to determine commonly used template. It is not a jinja parser. +static int32_t llama_chat_apply_template_internal( + const std::string & tmpl, + const std::vector & chat, + std::string & dest, bool add_ass) { + // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527 + std::stringstream ss; + auto tmpl_contains = [&tmpl](std::string haystack) -> bool { + return tmpl.find(haystack) != std::string::npos; + }; + if (tmpl == "chatml" || tmpl_contains("<|im_start|>")) { + // chatml template + for (auto message : chat) { + ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n"; + } + if (add_ass) { + ss << "<|im_start|>assistant\n"; + } + } else if (tmpl == "llama2" || tmpl == "mistral" || tmpl_contains("[INST]")) { + // llama2 template and its variants + // [variant] support system message + bool support_system_message = tmpl_contains("<>") || tmpl == "mistral"; + // [variant] space before + after response + bool space_around_response = tmpl_contains("' ' + eos_token"); + // [variant] add BOS inside history + bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]"); + // [variant] trim spaces from the input message + bool strip_message = tmpl_contains("content.strip()"); + // construct the prompt + bool is_inside_turn = true; // skip BOS at the beginning + ss << "[INST] "; + for (auto message : chat) { + std::string content = strip_message ? trim(message->content) : message->content; + std::string role(message->role); + if (!is_inside_turn) { + is_inside_turn = true; + ss << (add_bos_inside_history ? "[INST] " : "[INST] "); + } + if (role == "system") { + if (support_system_message) { + ss << "<>\n" << content << "\n<>\n\n"; + } else { + // if the model does not support system message, we still include it in the first message, but without <> + ss << content << "\n"; + } + } else if (role == "user") { + ss << content << " [/INST]"; + } else { + ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << ""; + is_inside_turn = false; + } + } + // llama2 templates seem to not care about "add_generation_prompt" + } else if (tmpl == "phi3" || (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>"))) { + // Phi 3 + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>\n" << message->content << "<|end|>\n"; + } + if (add_ass) { + ss << "<|assistant|>\n"; + } + } else if (tmpl == "zephyr" || tmpl_contains("<|user|>")) { + // zephyr template + for (auto message : chat) { + ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n"; + } + if (add_ass) { + ss << "<|assistant|>\n"; + } + } else if (tmpl == "monarch" || tmpl_contains("bos_token + message['role']")) { + // mlabonne/AlphaMonarch-7B template (the is included inside history) + for (auto message : chat) { + std::string bos = (message == chat.front()) ? "" : ""; // skip BOS for first message + ss << bos << message->role << "\n" << message->content << "\n"; + } + if (add_ass) { + ss << "assistant\n"; + } + } else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl_contains("")) { + // google/gemma-7b-it + std::string system_prompt = ""; + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + // there is no system message for gemma, but we will merge it with user prompt, so nothing is broken + system_prompt = trim(message->content); + continue; + } + // in gemma, "assistant" is "model" + role = role == "assistant" ? "model" : message->role; + ss << "" << role << "\n"; + if (!system_prompt.empty() && role != "model") { + ss << system_prompt << "\n\n"; + system_prompt = ""; + } + ss << trim(message->content) << "\n"; + } + if (add_ass) { + ss << "model\n"; + } + } else if (tmpl == "orion" || tmpl_contains("'\\n\\nAssistant: ' + eos_token")) { + // OrionStarAI/Orion-14B-Chat + std::string system_prompt = ""; + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + // there is no system message support, we will merge it with user prompt + system_prompt = message->content; + continue; + } else if (role == "user") { + ss << "Human: "; + if (!system_prompt.empty()) { + ss << system_prompt << "\n\n"; + system_prompt = ""; + } + ss << message->content << "\n\nAssistant: "; + } else { + ss << message->content << ""; + } + } + } else if (tmpl == "openchat" || tmpl_contains("GPT4 Correct ")) { + // openchat/openchat-3.5-0106, + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << message->content << "<|end_of_turn|>"; + } else { + role[0] = toupper(role[0]); + ss << "GPT4 Correct " << role << ": " << message->content << "<|end_of_turn|>"; + } + } + if (add_ass) { + ss << "GPT4 Correct Assistant:"; + } + } else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: "))) { + // eachadea/vicuna-13b-1.1 (and Orca variant) + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + // Orca-Vicuna variant uses a system prefix + if (tmpl == "vicuna-orca" || tmpl_contains("SYSTEM: ")) { + ss << "SYSTEM: " << message->content << "\n"; + } else { + ss << message->content << "\n\n"; + } + } else if (role == "user") { + ss << "USER: " << message->content << "\n"; + } else if (role == "assistant") { + ss << "ASSISTANT: " << message->content << "\n"; + } + } + if (add_ass) { + ss << "ASSISTANT:"; + } + } else if (tmpl == "deepseek" || (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>"))) { + // deepseek-ai/deepseek-coder-33b-instruct + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << message->content; + } else if (role == "user") { + ss << "### Instruction:\n" << message->content << "\n"; + } else if (role == "assistant") { + ss << "### Response:\n" << message->content << "\n<|EOT|>\n"; + } + } + if (add_ass) { + ss << "### Response:\n"; + } + } else if (tmpl == "command-r" || (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>"))) { + // CohereForAI/c4ai-command-r-plus + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>"; + } else if (role == "user") { + ss << "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>"; + } else if (role == "assistant") { + ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>"; + } + } + if (add_ass) { + ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"; + } + } else if (tmpl == "llama3" || (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>"))) { + // Llama 3 + for (auto message : chat) { + std::string role(message->role); + ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n" << trim(message->content) << "<|eot_id|>"; + } + if (add_ass) { + ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; + } + } else if (tmpl == "chatglm3" || tmpl_contains("[gMASK]sop")) { + // chatglm3-6b + ss << "[gMASK]" << "sop"; + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>" << "\n " << message->content; + } + if (add_ass) { + ss << "<|assistant|>"; + } + } else if (tmpl == "chatglm4" || tmpl_contains("[gMASK]")) { + ss << "[gMASK]" << ""; + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>" << "\n" << message->content; + } + if (add_ass) { + ss << "<|assistant|>"; + } + } else if (tmpl == "minicpm" || tmpl_contains(LU8("<用户>"))) { + // MiniCPM-3B-OpenHermes-2.5-v2-GGUF + for (auto message : chat) { + std::string role(message->role); + if (role == "user") { + ss << LU8("<用户>"); + ss << trim(message->content); + ss << ""; + } else { + ss << trim(message->content); + } + } + } else if (tmpl == "deepseek2" || tmpl_contains("'Assistant: ' + message['content'] + eos_token")) { + // DeepSeek-V2 + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << message->content << "\n\n"; + } else if (role == "user") { + ss << "User: " << message->content << "\n\n"; + } else if (role == "assistant") { + ss << "Assistant: " << message->content << LU8("<|end▁of▁sentence|>"); + } + } + if (add_ass) { + ss << "Assistant:"; + } + } else { + // template not supported + return -1; + } + dest = ss.str(); + return dest.size(); +} + +int32_t llama_chat_apply_template( + const struct llama_model * model, + const char * tmpl, + const struct llama_chat_message * chat, + size_t n_msg, + bool add_ass, + char * buf, + int32_t length) { + std::string curr_tmpl(tmpl == nullptr ? "" : tmpl); + if (tmpl == nullptr) { + GGML_ASSERT(model != nullptr); + // load template from model + std::vector model_template(2048, 0); // longest known template is about 1200 bytes + std::string template_key = "tokenizer.chat_template"; + int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size()); + if (res < 0) { + // worst case: there is no information about template, we will use chatml by default + curr_tmpl = "chatml"; // see llama_chat_apply_template_internal + } else { + curr_tmpl = std::string(model_template.data(), model_template.size()); + } + } + + // format the chat to string + std::vector chat_vec; + chat_vec.resize(n_msg); + for (size_t i = 0; i < n_msg; i++) { + chat_vec[i] = &chat[i]; + } + + std::string formatted_chat; + int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass); + if (res < 0) { + return res; + } + if (buf && length > 0) { + strncpy(buf, formatted_chat.c_str(), length); + } + return res; +} + +// +// grammar +// + +struct llama_grammar * llama_grammar_init( + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index) { + return llama_grammar_init_impl(rules, n_rules, start_rule_index); +} + +void llama_grammar_free(struct llama_grammar * grammar) { + llama_grammar_free_impl(grammar); +} + +struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) { + return llama_grammar_copy_impl(grammar); +} + +void llama_grammar_sample( + const struct llama_grammar * grammar, + const struct llama_context * ctx, + llama_token_data_array * candidates) { + llama_grammar_sample(grammar, &ctx->model.vocab, &ctx->sampling, candidates); +} + +void llama_sample_grammar( + struct llama_context * ctx, + llama_token_data_array * candidates, + const struct llama_grammar * grammar) { + llama_grammar_sample(grammar, ctx, candidates); +} + +void llama_grammar_accept_token( + struct llama_grammar * grammar, + struct llama_context * ctx, + llama_token token) { + llama_grammar_accept_token(grammar, &ctx->model.vocab, &ctx->sampling, token); +} + +// +// sampling +// + +void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { + llama_set_rng_seed(&ctx->sampling, seed); +} + +void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { + llama_sample_softmax(ctx ? &ctx->sampling : nullptr, candidates); +} + +void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) { + llama_sample_top_k(ctx ? &ctx->sampling : nullptr, candidates, k, min_keep); +} + +void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { + llama_sample_top_p(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep); +} + +void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { + llama_sample_min_p(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep); +} + +void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { + llama_sample_tail_free(ctx ? &ctx->sampling : nullptr, candidates, z, min_keep); +} + +void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { + llama_sample_typical(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep); +} + +void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) { + llama_sample_entropy(ctx ? &ctx->sampling : nullptr, candidates_p, min_temp, max_temp, exponent_val); +} + +void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { + llama_sample_temp(ctx ? &ctx->sampling : nullptr, candidates_p, temp); +} + +void llama_sample_repetition_penalties( + struct llama_context * ctx, + llama_token_data_array * candidates, + const llama_token * last_tokens, + size_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present) { + llama_sample_repetition_penalties(ctx ? &ctx->sampling : nullptr, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present); +} + +void llama_sample_apply_guidance( + struct llama_context * ctx, + float * logits, + float * logits_guidance, + float scale) { + llama_sample_apply_guidance(&ctx->sampling, logits, logits_guidance, scale); +} + +llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { + return llama_sample_token_mirostat(&ctx->sampling, candidates, tau, eta, m, mu); +} + +llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) { + return llama_sample_token_mirostat_v2(ctx ? &ctx->sampling : nullptr, candidates, tau, eta, mu); +} + +llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) { + return llama_sample_token_greedy(ctx ? &ctx->sampling : nullptr, candidates); +} + +llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) { + return llama_sample_token_with_rng(&ctx->sampling, candidates, rng); +} + +llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { + return llama_sample_token_with_rng(&ctx->sampling, candidates, ctx->sampling.rng); +} + int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) { static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf"; if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) {