From 516746a5968272444ea5783664d102971dcd217f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 12 Jul 2024 13:28:10 +0300 Subject: [PATCH] llama : move sampling code into llama-sampling ggml-ci --- Makefile | 9 + include/llama.h | 12 +- src/CMakeLists.txt | 1 + src/llama-impl.h | 50 +++ src/llama-sampling.cpp | 635 +++++++++++++++++++++++++++++++++++ src/llama-sampling.h | 21 ++ src/llama.cpp | 729 ++--------------------------------------- 7 files changed, 758 insertions(+), 699 deletions(-) create mode 100644 src/llama-impl.h create mode 100644 src/llama-sampling.cpp create mode 100644 src/llama-sampling.h diff --git a/Makefile b/Makefile index bec332ec57b55..421dc554783bc 100644 --- a/Makefile +++ b/Makefile @@ -868,6 +868,7 @@ OBJ_GGML += \ OBJ_LLAMA = \ src/llama.o \ + src/llama-sampling.o \ src/unicode.o \ src/unicode-data.o @@ -1047,6 +1048,7 @@ src/unicode-data.o: \ src/llama.o: \ src/llama.cpp \ + src/llama-impl.h \ src/unicode.h \ include/llama.h \ ggml/include/ggml-cuda.h \ @@ -1056,6 +1058,13 @@ src/llama.o: \ ggml/include/ggml-backend.h $(CXX) $(CXXFLAGS) -c $< -o $@ +src/llama-sampling.o: \ + src/llama-sampling.cpp \ + src/llama-sampling.h \ + src/llama-impl.h \ + include/llama.h + $(CXX) $(CXXFLAGS) -c $< -o $@ + $(LIB_LLAMA): \ $(OBJ_LLAMA) \ $(LIB_GGML) diff --git a/include/llama.h b/include/llama.h index c0fb53060eae4..54236cccb427a 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1081,12 +1081,6 @@ extern "C" { llama_token_data_array * candidates, float temp); - /// @details Apply constraints from grammar - LLAMA_API void llama_sample_grammar( - struct llama_context * ctx, - llama_token_data_array * candidates, - const struct llama_grammar * grammar); - /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. @@ -1124,6 +1118,12 @@ extern "C" { struct llama_context * ctx, llama_token_data_array * candidates); + /// @details Apply constraints from grammar + LLAMA_API void llama_sample_grammar( + struct llama_context * ctx, + llama_token_data_array * candidates, + const struct llama_grammar * grammar); + /// @details Accepts the sampled token into the grammar LLAMA_API void llama_grammar_accept_token( struct llama_context * ctx, diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c2049df79c212..9930a056e0fcf 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -14,6 +14,7 @@ endif() add_library(llama ../include/llama.h llama.cpp + llama-sampling.cpp unicode.h unicode.cpp unicode-data.cpp diff --git a/src/llama-impl.h b/src/llama-impl.h new file mode 100644 index 0000000000000..2a9110e54b141 --- /dev/null +++ b/src/llama-impl.h @@ -0,0 +1,50 @@ +#pragma once + +#define LLAMA_API_INTERNAL +#include "llama.h" + +#include +#include +#include +#include +#include + +#ifdef __has_include + #if __has_include() + #include + #if defined(_POSIX_MAPPED_FILES) + #include + #include + #endif + #if defined(_POSIX_MEMLOCK_RANGE) + #include + #endif + #endif +#endif + +// bump if necessary +#define LLAMA_MAX_NODES 8192 +#define LLAMA_MAX_LAYERS 256 +#define LLAMA_MAX_EXPERTS 160 // DeepSeekV2 + +#ifdef __GNUC__ +#ifdef __MINGW32__ +#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +#endif +#else +#define LLAMA_ATTRIBUTE_FORMAT(...) +#endif + +// +// logging +// + +LLAMA_ATTRIBUTE_FORMAT(2, 3) +void llama_log_internal (ggml_log_level level, const char * format, ...); +void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data); + +#define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) +#define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__) +#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp new file mode 100644 index 0000000000000..d31d47f2eaeb0 --- /dev/null +++ b/src/llama-sampling.cpp @@ -0,0 +1,635 @@ +#include "llama-sampling.h" + +#include +#include +#include +#include +#include +#include + +void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { + if (seed == LLAMA_DEFAULT_SEED) { + seed = time(NULL); + } + + llama_get_sampling(ctx)->rng.seed(seed); +} + +void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { + GGML_ASSERT(candidates->size > 0); + + const int64_t t_start_sample_us = ggml_time_us(); + + // Sort the logits in descending order + if (!candidates->sorted) { + std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + candidates->sorted = true; + } + + float max_l = candidates->data[0].logit; + float cum_sum = 0.0f; + for (size_t i = 0; i < candidates->size; ++i) { + float p = expf(candidates->data[i].logit - max_l); + candidates->data[i].p = p; + cum_sum += p; + } + for (size_t i = 0; i < candidates->size; ++i) { + candidates->data[i].p /= cum_sum; + } + + if (ctx) { + llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) { + // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast + // if (k >= (int32_t)candidates->size) { + // return; + // } + + const int64_t t_start_sample_us = ggml_time_us(); + + if (k <= 0) { + k = candidates->size; + } + + k = std::max(k, (int) min_keep); + k = std::min(k, (int) candidates->size); + + // Sort scores in descending order + if (!candidates->sorted) { + auto comp = [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }; + if (k <= 128) { + std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp); + } else { + constexpr int nbuckets = 128; + constexpr float bucket_low = -10.0f; + constexpr float bucket_high = 10.0f; + constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low); + constexpr float bucker_inter = -bucket_low * bucket_scale; + + std::vector bucket_idx(candidates->size); + std::vector histo(nbuckets, 0); + + for (int i = 0; i < (int)candidates->size; ++i) { + const float val = candidates->data[i].logit; + int ib = int(bucket_scale * val + bucker_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low); + ib = std::max(0, std::min(nbuckets-1, ib)); + bucket_idx[i] = ib; + ++histo[ib]; + } + int nhave = 0; + int ib = nbuckets - 1; + for ( ; ib >= 0; --ib) { + nhave += histo[ib]; + if (nhave >= k) break; + } + std::vector tmp_tokens(nhave); + auto ptr = tmp_tokens.data(); + std::vector bucket_ptrs; + bucket_ptrs.reserve(nbuckets - ib); + for (int j = nbuckets - 1; j >= ib; --j) { + bucket_ptrs.push_back(ptr); + ptr += histo[j]; + } + for (int i = 0; i < (int)candidates->size; ++i) { + int j = bucket_idx[i]; + if (j >= ib) { + *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i]; + } + } + + ptr = tmp_tokens.data(); + int ndone = 0; + for (int j = nbuckets-1; j > ib; --j) { + std::sort(ptr, ptr + histo[j], comp); + ptr += histo[j]; + ndone += histo[j]; + } + std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp); + + std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data)); + + } + candidates->sorted = true; + } + candidates->size = k; + + if (ctx) { + llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { + if (p >= 1.0f) { + return; + } + + llama_sample_softmax(ctx, candidates); + + const int64_t t_start_sample_us = ggml_time_us(); + + // Compute the cumulative probabilities + float cum_sum = 0.0f; + size_t last_idx = candidates->size; + + for (size_t i = 0; i < candidates->size; ++i) { + cum_sum += candidates->data[i].p; + + // Check if the running sum is at least p or if we have kept at least min_keep tokens + // we set the last index to i+1 to indicate that the current iterate should be included in the set + if (cum_sum >= p && i + 1 >= min_keep) { + last_idx = i + 1; + break; + } + } + + // Resize the output vector to keep only the top-p tokens + candidates->size = last_idx; + + if (ctx) { + llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { + if (p <= 0.0f || !candidates->size) { + return; + } + + const int64_t t_start_sample_us = ggml_time_us(); + + bool min_p_applied = false; + + // if the candidates aren't sorted, try the unsorted implementation first + if (!candidates->sorted) { + std::vector filtered_tokens; + + float max_logit = -FLT_MAX; + for (size_t i = 0; i < candidates->size; ++i) { + max_logit = std::max(max_logit, candidates->data[i].logit); + } + const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max + + for (size_t i = 0; i < candidates->size; ++i) { + if (candidates->data[i].logit >= min_logit) { + filtered_tokens.push_back(candidates->data[i]); + } + } + + // if we have enough values the operation was a success + if (filtered_tokens.size() >= min_keep) { + memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data)); + candidates->size = filtered_tokens.size(); + min_p_applied = true; + } + } + + // if the candidates are sorted or the unsorted implementation failed, use this implementation + if (!min_p_applied) { + // Sort the logits in descending order + if (!candidates->sorted) { + std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + candidates->sorted = true; + } + + const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max + size_t i = 1; // first token always matches + + for (; i < candidates->size; ++i) { + if (candidates->data[i].logit < min_logit && i >= min_keep) { + break; // prob too small + } + } + + // Resize the output vector to keep only the matching tokens + candidates->size = i; + } + + if (ctx) { + llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { + if (z >= 1.0f || candidates->size <= 2) { + return; + } + + llama_sample_softmax(nullptr, candidates); + const int64_t t_start_sample_us = ggml_time_us(); + + // Compute the first and second derivatives + std::vector first_derivatives(candidates->size - 1); + std::vector second_derivatives(candidates->size - 2); + + for (size_t i = 0; i < first_derivatives.size(); ++i) { + first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p; + } + for (size_t i = 0; i < second_derivatives.size(); ++i) { + second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1]; + } + + // Calculate absolute value of second derivatives + for (size_t i = 0; i < second_derivatives.size(); ++i) { + second_derivatives[i] = std::abs(second_derivatives[i]); + } + + // Normalize the second derivatives + { + const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f); + + if (second_derivatives_sum > 1e-6f) { + for (float & value : second_derivatives) { + value /= second_derivatives_sum; + } + } else { + for (float & value : second_derivatives) { + value = 1.0f / second_derivatives.size(); + } + } + } + + float cum_sum = 0.0f; + size_t last_idx = candidates->size; + for (size_t i = 0; i < second_derivatives.size(); ++i) { + cum_sum += second_derivatives[i]; + + // Check if the running sum is greater than z or if we have kept at least min_keep tokens + if (cum_sum > z && i >= min_keep) { + last_idx = i; + break; + } + } + + // Resize the output vector to keep only the tokens above the tail location + candidates->size = last_idx; + + if (ctx) { + llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { + // Reference implementation: + // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr + if (p >= 1.0f) { + return; + } + + // Compute the softmax of logits and calculate entropy + llama_sample_softmax(nullptr, candidates); + + const int64_t t_start_sample_us = ggml_time_us(); + + float entropy = 0.0f; + for (size_t i = 0; i < candidates->size; ++i) { + entropy += -candidates->data[i].p * logf(candidates->data[i].p); + } + + // Compute the absolute difference between negative log probability and entropy for each candidate + std::vector shifted_scores; + for (size_t i = 0; i < candidates->size; ++i) { + float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy); + shifted_scores.push_back(shifted_score); + } + + // Sort tokens based on the shifted_scores and their corresponding indices + std::vector indices(candidates->size); + std::iota(indices.begin(), indices.end(), 0); + + std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) { + return shifted_scores[a] < shifted_scores[b]; + }); + + // Compute the cumulative probabilities + float cum_sum = 0.0f; + size_t last_idx = indices.size(); + + for (size_t i = 0; i < indices.size(); ++i) { + size_t idx = indices[i]; + cum_sum += candidates->data[idx].p; + + // Check if the running sum is greater than typical or if we have kept at least min_keep tokens + if (cum_sum > p && i >= min_keep - 1) { + last_idx = i + 1; + break; + } + } + + // Resize the output vector to keep only the locally typical tokens + std::vector new_candidates; + for (size_t i = 0; i < last_idx; ++i) { + size_t idx = indices[i]; + new_candidates.push_back(candidates->data[idx]); + } + + // Replace the data in candidates with the new_candidates data + std::copy(new_candidates.begin(), new_candidates.end(), candidates->data); + candidates->size = new_candidates.size(); + candidates->sorted = false; + + if (ctx) { + llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) { + const int64_t t_start_sample_us = ggml_time_us(); + + // no need to do anything if there is only one (or zero) candidates + if(candidates_p->size <= 1) { + return; + } + + // Calculate maximum possible entropy + float max_entropy = -logf(1.0f / candidates_p->size); + + llama_sample_softmax(nullptr, candidates_p); + + // Calculate entropy of the softmax probabilities + float entropy = 0.0f; + for (size_t i = 0; i < candidates_p->size; ++i) { + float prob = candidates_p->data[i].p; + if (prob > 0.0f) { // Ensure no log(0) + entropy -= prob * logf(prob); + } + } + + // Normalize the entropy (max_entropy cannot be 0 here because we checked candidates_p->size != 1 above) + float normalized_entropy = entropy / max_entropy; + + // Map the normalized entropy to the desired temperature range using the power function + float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val); + +#ifdef DEBUG + LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp); + LLAMA_LOG_INFO("Entropy: %f\n", entropy); + LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy); + LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy); + LLAMA_LOG_INFO("Exponent: %f\n", exponent_val); + LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp); +#endif + + // Apply the dynamically calculated temperature scaling + for (size_t i = 0; i < candidates_p->size; ++i) { + candidates_p->data[i].logit /= dyn_temp; + } + + // Re-compute softmax probabilities after scaling logits with dynamic temperature + double max_l_double = candidates_p->data[0].logit; + double cum_sum_double = 0.0; + for (size_t i = 0; i < candidates_p->size; ++i) { + double p = exp(candidates_p->data[i].logit - max_l_double); + candidates_p->data[i].p = p; // Store the scaled probability + cum_sum_double += p; + } + for (size_t i = 0; i < candidates_p->size; ++i) { + candidates_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities + } + +#ifdef DEBUG + // Print the updated top 25 probabilities after temperature scaling + LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n"); + for (size_t i = 0; i < 25 && i < candidates_p->size; ++i) { + LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates_p->data[i].p * 100.0f); + } +#endif + + if (ctx) { + llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { + const int64_t t_start_sample_us = ggml_time_us(); + + for (size_t i = 0; i < candidates_p->size; ++i) { + candidates_p->data[i].logit /= temp; + } + + if (ctx) { + llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_repetition_penalties( + struct llama_context * ctx, + llama_token_data_array * candidates, + const llama_token * last_tokens, + size_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present) { + if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) { + return; + } + + const int64_t t_start_sample_us = ggml_time_us(); + + // Create a frequency map to count occurrences of each token in last_tokens + std::unordered_map token_count; + for (size_t i = 0; i < penalty_last_n; ++i) { + token_count[last_tokens[i]]++; + } + + // Apply frequency and presence penalties to the candidates + for (size_t i = 0; i < candidates->size; ++i) { + const auto token_iter = token_count.find(candidates->data[i].id); + if (token_iter == token_count.end()) { + continue; + } + + const int count = token_iter->second; + + // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. + // This is common fix for this problem, which is to multiply by the penalty instead of dividing. + if (candidates->data[i].logit <= 0) { + candidates->data[i].logit *= penalty_repeat; + } else { + candidates->data[i].logit /= penalty_repeat; + } + + candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present; + } + + candidates->sorted = false; + + if (ctx) { + llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +static void llama_log_softmax(float * array, size_t size) { + float max_l = *std::max_element(array, array + size); + float sum = 0.f; + for (size_t i = 0; i < size; ++i) { + float p = expf(array[i] - max_l); + sum += p; + array[i] = p; + } + + for (size_t i = 0; i < size; ++i) { + array[i] = logf(array[i] / sum); + } +} + +void llama_sample_apply_guidance( + struct llama_context * ctx, + float * logits, + float * logits_guidance, + float scale) { + GGML_ASSERT(ctx); + + const auto t_start_sample_us = ggml_time_us(); + const auto n_vocab = llama_get_sampling(ctx)->n_vocab; + + llama_log_softmax(logits, n_vocab); + llama_log_softmax(logits_guidance, n_vocab); + + for (int i = 0; i < n_vocab; ++i) { + auto & l = logits[i]; + const auto & g = logits_guidance[i]; + + l = scale * (l - g) + g; + } + + llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; +} + +llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { + GGML_ASSERT(ctx); + + const int32_t n_vocab = float(llama_get_sampling(ctx)->n_vocab); + + int64_t t_start_sample_us = ggml_time_us(); + + llama_sample_softmax(nullptr, candidates); + + // Estimate s_hat using the most probable m tokens + float s_hat = 0.0; + float sum_ti_bi = 0.0; + float sum_ti_sq = 0.0; + for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) { + float t_i = logf(float(i + 2) / float(i + 1)); + float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p); + sum_ti_bi += t_i * b_i; + sum_ti_sq += t_i * t_i; + } + s_hat = sum_ti_bi / sum_ti_sq; + + // Compute k from the estimated s_hat and target surprise value + float epsilon_hat = s_hat - 1; + float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); + + // Sample the next word X using top-k sampling + llama_sample_top_k(nullptr, candidates, int(k), 1); + llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + llama_token X = llama_sample_token(ctx, candidates); + t_start_sample_us = ggml_time_us(); + + // Compute error as the difference between observed surprise and target surprise value + size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + return candidate.id == X; + })); + float observed_surprise = -log2f(candidates->data[X_idx].p); + float e = observed_surprise - tau; + + // Update mu using the learning rate and error + *mu = *mu - eta * e; + + llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + return X; +} + +llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) { + int64_t t_start_sample_us; + t_start_sample_us = ggml_time_us(); + + llama_sample_softmax(ctx, candidates); + + // Truncate the words with surprise values greater than mu + candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + return -log2f(candidate.p) > *mu; + })); + + if (candidates->size == 0) { + candidates->size = 1; + } + + if (ctx) { + llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + } + + // Normalize the probabilities of the remaining words + llama_sample_softmax(ctx, candidates); + + // Sample the next word X from the remaining words + llama_token X = llama_sample_token(ctx, candidates); + t_start_sample_us = ggml_time_us(); + + // Compute error as the difference between observed surprise and target surprise value + size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + return candidate.id == X; + })); + float observed_surprise = -log2f(candidates->data[X_idx].p); + float e = observed_surprise - tau; + + // Update mu using the learning rate and error + *mu = *mu - eta * e; + + if (ctx) { + llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + } + return X; +} + +llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) { + const int64_t t_start_sample_us = ggml_time_us(); + + // Find max element + auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.logit < b.logit; + }); + + llama_token result = max_iter->id; + if (ctx) { + llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + llama_get_sampling(ctx)->n_sample++; + } + return result; +} + +llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) { + GGML_ASSERT(ctx); + + const int64_t t_start_sample_us = ggml_time_us(); + llama_sample_softmax(nullptr, candidates); + + std::vector probs; + probs.reserve(candidates->size); + for (size_t i = 0; i < candidates->size; ++i) { + probs.push_back(candidates->data[i].p); + } + + std::discrete_distribution<> dist(probs.begin(), probs.end()); + int idx = dist(rng); + + llama_token result = candidates->data[idx].id; + + llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us; + llama_get_sampling(ctx)->n_sample++; + return result; +} + +llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { + return llama_sample_token_with_rng(ctx, candidates, llama_get_sampling(ctx)->rng); +} + diff --git a/src/llama-sampling.h b/src/llama-sampling.h new file mode 100644 index 0000000000000..cf05e357e363c --- /dev/null +++ b/src/llama-sampling.h @@ -0,0 +1,21 @@ +#pragma once + +#include "llama-impl.h" + +struct llama_sampling { + llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {} + + std::mt19937 rng; + + int64_t t_sample_us = 0; + + int32_t n_sample = 0; + int32_t n_vocab = 0; + + void reset_timings() { + t_sample_us = 0; + n_sample = 0; + } +}; + +struct llama_sampling * llama_get_sampling(struct llama_context * ctx); diff --git a/src/llama.cpp b/src/llama.cpp index 4a9903cc33e47..6084f22f92518 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1,5 +1,5 @@ -#define LLAMA_API_INTERNAL -#include "llama.h" +#include "llama-impl.h" +#include "llama-sampling.h" #include "unicode.h" @@ -32,19 +32,6 @@ // TODO: replace with ggml API call #define QK_K 256 -#ifdef __has_include - #if __has_include() - #include - #if defined(_POSIX_MAPPED_FILES) - #include - #include - #endif - #if defined(_POSIX_MEMLOCK_RANGE) - #include - #endif - #endif -#endif - #if defined(_WIN32) #define WIN32_LEAN_AND_MEAN #ifndef NOMINMAX @@ -88,7 +75,6 @@ #include #include #include -#include #include #include #include @@ -100,33 +86,6 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif -#ifdef __GNUC__ -#ifdef __MINGW32__ -#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) -#else -#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) -#endif -#else -#define LLAMA_ATTRIBUTE_FORMAT(...) -#endif - -// bump if necessary -#define LLAMA_MAX_NODES 8192 -#define LLAMA_MAX_LAYERS 256 -#define LLAMA_MAX_EXPERTS 160 // DeepSeekV2 - -// -// logging -// - -LLAMA_ATTRIBUTE_FORMAT(2, 3) -static void llama_log_internal (ggml_log_level level, const char * format, ...); -static void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data); - -#define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) -#define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__) -#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) - // // helpers // @@ -2733,7 +2692,8 @@ struct llama_model { }; struct llama_context { - llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {} + llama_context(const llama_model & model) : model(model), sampling(llama_n_vocab(&model)), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {} + ~llama_context() { ggml_backend_sched_free(sched); @@ -2744,7 +2704,14 @@ struct llama_context { ggml_backend_buffer_free(buf_output); } - llama_cparams cparams; + const struct llama_model & model; + + struct llama_cparams cparams; + struct llama_sampling sampling; + struct llama_kv_cache kv_self; + struct llama_control_vector cvec; + + std::unordered_map lora_adapters; std::vector backends; #ifdef GGML_USE_METAL @@ -2755,26 +2722,16 @@ struct llama_context { #endif ggml_backend_t backend_cpu = nullptr; - - const llama_model & model; - - // key + value cache for the self attention - struct llama_kv_cache kv_self; - - std::mt19937 rng; - bool has_evaluated_once = false; int64_t t_start_us; int64_t t_load_us; - int64_t t_sample_us = 0; int64_t t_p_eval_us = 0; int64_t t_eval_us = 0; int64_t t_compute_start_us = 0; int64_t n_queued_tokens = 0; - int32_t n_sample = 0; // number of tokens sampled int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) int32_t n_eval = 0; // number of eval calls @@ -2830,12 +2787,6 @@ struct llama_context { struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] - - // control vectors - struct llama_control_vector cvec; - - // lora adapters and scales - std::unordered_map lora_adapters; }; struct llama_lora_weight { @@ -17015,469 +16966,7 @@ struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) return result; } -// -// sampling -// - -void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { - if (seed == LLAMA_DEFAULT_SEED) { - seed = time(NULL); - } - ctx->rng.seed(seed); -} - -void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { - GGML_ASSERT(candidates->size > 0); - - const int64_t t_start_sample_us = ggml_time_us(); - - // Sort the logits in descending order - if (!candidates->sorted) { - std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }); - candidates->sorted = true; - } - - float max_l = candidates->data[0].logit; - float cum_sum = 0.0f; - for (size_t i = 0; i < candidates->size; ++i) { - float p = expf(candidates->data[i].logit - max_l); - candidates->data[i].p = p; - cum_sum += p; - } - for (size_t i = 0; i < candidates->size; ++i) { - candidates->data[i].p /= cum_sum; - } - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - -void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) { - // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast - // if (k >= (int32_t)candidates->size) { - // return; - // } - - const int64_t t_start_sample_us = ggml_time_us(); - - if (k <= 0) { - k = candidates->size; - } - - k = std::max(k, (int) min_keep); - k = std::min(k, (int) candidates->size); - - // Sort scores in descending order - if (!candidates->sorted) { - auto comp = [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }; - if (k <= 128) { - std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp); - } else { - constexpr int nbuckets = 128; - constexpr float bucket_low = -10.0f; - constexpr float bucket_high = 10.0f; - constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low); - constexpr float bucker_inter = -bucket_low * bucket_scale; - - std::vector bucket_idx(candidates->size); - std::vector histo(nbuckets, 0); - - for (int i = 0; i < (int)candidates->size; ++i) { - const float val = candidates->data[i].logit; - int ib = int(bucket_scale * val + bucker_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low); - ib = std::max(0, std::min(nbuckets-1, ib)); - bucket_idx[i] = ib; - ++histo[ib]; - } - int nhave = 0; - int ib = nbuckets - 1; - for ( ; ib >= 0; --ib) { - nhave += histo[ib]; - if (nhave >= k) break; - } - std::vector tmp_tokens(nhave); - auto ptr = tmp_tokens.data(); - std::vector bucket_ptrs; - bucket_ptrs.reserve(nbuckets - ib); - for (int j = nbuckets - 1; j >= ib; --j) { - bucket_ptrs.push_back(ptr); - ptr += histo[j]; - } - for (int i = 0; i < (int)candidates->size; ++i) { - int j = bucket_idx[i]; - if (j >= ib) { - *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i]; - } - } - - ptr = tmp_tokens.data(); - int ndone = 0; - for (int j = nbuckets-1; j > ib; --j) { - std::sort(ptr, ptr + histo[j], comp); - ptr += histo[j]; - ndone += histo[j]; - } - std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp); - - std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data)); - - } - candidates->sorted = true; - } - candidates->size = k; - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - -void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - if (p >= 1.0f) { - return; - } - - llama_sample_softmax(ctx, candidates); - - const int64_t t_start_sample_us = ggml_time_us(); - - // Compute the cumulative probabilities - float cum_sum = 0.0f; - size_t last_idx = candidates->size; - - for (size_t i = 0; i < candidates->size; ++i) { - cum_sum += candidates->data[i].p; - - // Check if the running sum is at least p or if we have kept at least min_keep tokens - // we set the last index to i+1 to indicate that the current iterate should be included in the set - if (cum_sum >= p && i + 1 >= min_keep) { - last_idx = i + 1; - break; - } - } - - // Resize the output vector to keep only the top-p tokens - candidates->size = last_idx; - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - -void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - if (p <= 0.0f || !candidates->size) { - return; - } - - const int64_t t_start_sample_us = ggml_time_us(); - - bool min_p_applied = false; - - // if the candidates aren't sorted, try the unsorted implementation first - if (!candidates->sorted) { - std::vector filtered_tokens; - - float max_logit = -FLT_MAX; - for (size_t i = 0; i < candidates->size; ++i) { - max_logit = std::max(max_logit, candidates->data[i].logit); - } - const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max - - for (size_t i = 0; i < candidates->size; ++i) { - if (candidates->data[i].logit >= min_logit) { - filtered_tokens.push_back(candidates->data[i]); - } - } - - // if we have enough values the operation was a success - if (filtered_tokens.size() >= min_keep) { - memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data)); - candidates->size = filtered_tokens.size(); - min_p_applied = true; - } - } - - // if the candidates are sorted or the unsorted implementation failed, use this implementation - if (!min_p_applied) { - // Sort the logits in descending order - if (!candidates->sorted) { - std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }); - candidates->sorted = true; - } - - const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max - size_t i = 1; // first token always matches - - for (; i < candidates->size; ++i) { - if (candidates->data[i].logit < min_logit && i >= min_keep) { - break; // prob too small - } - } - - // Resize the output vector to keep only the matching tokens - candidates->size = i; - } - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - -void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { - if (z >= 1.0f || candidates->size <= 2) { - return; - } - - llama_sample_softmax(nullptr, candidates); - const int64_t t_start_sample_us = ggml_time_us(); - - // Compute the first and second derivatives - std::vector first_derivatives(candidates->size - 1); - std::vector second_derivatives(candidates->size - 2); - - for (size_t i = 0; i < first_derivatives.size(); ++i) { - first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p; - } - for (size_t i = 0; i < second_derivatives.size(); ++i) { - second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1]; - } - - // Calculate absolute value of second derivatives - for (size_t i = 0; i < second_derivatives.size(); ++i) { - second_derivatives[i] = std::abs(second_derivatives[i]); - } - - // Normalize the second derivatives - { - const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f); - - if (second_derivatives_sum > 1e-6f) { - for (float & value : second_derivatives) { - value /= second_derivatives_sum; - } - } else { - for (float & value : second_derivatives) { - value = 1.0f / second_derivatives.size(); - } - } - } - - float cum_sum = 0.0f; - size_t last_idx = candidates->size; - for (size_t i = 0; i < second_derivatives.size(); ++i) { - cum_sum += second_derivatives[i]; - - // Check if the running sum is greater than z or if we have kept at least min_keep tokens - if (cum_sum > z && i >= min_keep) { - last_idx = i; - break; - } - } - - // Resize the output vector to keep only the tokens above the tail location - candidates->size = last_idx; - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - -void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - // Reference implementation: - // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr - if (p >= 1.0f) { - return; - } - - // Compute the softmax of logits and calculate entropy - llama_sample_softmax(nullptr, candidates); - - const int64_t t_start_sample_us = ggml_time_us(); - - float entropy = 0.0f; - for (size_t i = 0; i < candidates->size; ++i) { - entropy += -candidates->data[i].p * logf(candidates->data[i].p); - } - - // Compute the absolute difference between negative log probability and entropy for each candidate - std::vector shifted_scores; - for (size_t i = 0; i < candidates->size; ++i) { - float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy); - shifted_scores.push_back(shifted_score); - } - - // Sort tokens based on the shifted_scores and their corresponding indices - std::vector indices(candidates->size); - std::iota(indices.begin(), indices.end(), 0); - - std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) { - return shifted_scores[a] < shifted_scores[b]; - }); - - // Compute the cumulative probabilities - float cum_sum = 0.0f; - size_t last_idx = indices.size(); - - for (size_t i = 0; i < indices.size(); ++i) { - size_t idx = indices[i]; - cum_sum += candidates->data[idx].p; - - // Check if the running sum is greater than typical or if we have kept at least min_keep tokens - if (cum_sum > p && i >= min_keep - 1) { - last_idx = i + 1; - break; - } - } - - // Resize the output vector to keep only the locally typical tokens - std::vector new_candidates; - for (size_t i = 0; i < last_idx; ++i) { - size_t idx = indices[i]; - new_candidates.push_back(candidates->data[idx]); - } - - // Replace the data in candidates with the new_candidates data - std::copy(new_candidates.begin(), new_candidates.end(), candidates->data); - candidates->size = new_candidates.size(); - candidates->sorted = false; - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - -void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) { - const int64_t t_start_sample_us = ggml_time_us(); - - // no need to do anything if there is only one (or zero) candidates - if(candidates_p->size <= 1) { - return; - } - - // Calculate maximum possible entropy - float max_entropy = -logf(1.0f / candidates_p->size); - - llama_sample_softmax(nullptr, candidates_p); - - // Calculate entropy of the softmax probabilities - float entropy = 0.0f; - for (size_t i = 0; i < candidates_p->size; ++i) { - float prob = candidates_p->data[i].p; - if (prob > 0.0f) { // Ensure no log(0) - entropy -= prob * logf(prob); - } - } - - // Normalize the entropy (max_entropy cannot be 0 here because we checked candidates_p->size != 1 above) - float normalized_entropy = entropy / max_entropy; - - // Map the normalized entropy to the desired temperature range using the power function - float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val); - -#ifdef DEBUG - LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp); - LLAMA_LOG_INFO("Entropy: %f\n", entropy); - LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy); - LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy); - LLAMA_LOG_INFO("Exponent: %f\n", exponent_val); - LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp); -#endif - - // Apply the dynamically calculated temperature scaling - for (size_t i = 0; i < candidates_p->size; ++i) { - candidates_p->data[i].logit /= dyn_temp; - } - - // Re-compute softmax probabilities after scaling logits with dynamic temperature - double max_l_double = candidates_p->data[0].logit; - double cum_sum_double = 0.0; - for (size_t i = 0; i < candidates_p->size; ++i) { - double p = exp(candidates_p->data[i].logit - max_l_double); - candidates_p->data[i].p = p; // Store the scaled probability - cum_sum_double += p; - } - for (size_t i = 0; i < candidates_p->size; ++i) { - candidates_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities - } - -#ifdef DEBUG - // Print the updated top 25 probabilities after temperature scaling - LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n"); - for (size_t i = 0; i < 25 && i < candidates_p->size; ++i) { - LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates_p->data[i].p * 100.0f); - } -#endif - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - -void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { - const int64_t t_start_sample_us = ggml_time_us(); - - for (size_t i = 0; i < candidates_p->size; ++i) { - candidates_p->data[i].logit /= temp; - } - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - -void llama_sample_repetition_penalties( - struct llama_context * ctx, - llama_token_data_array * candidates, - const llama_token * last_tokens, - size_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present) { - if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) { - return; - } - - const int64_t t_start_sample_us = ggml_time_us(); - - // Create a frequency map to count occurrences of each token in last_tokens - std::unordered_map token_count; - for (size_t i = 0; i < penalty_last_n; ++i) { - token_count[last_tokens[i]]++; - } - - // Apply frequency and presence penalties to the candidates - for (size_t i = 0; i < candidates->size; ++i) { - const auto token_iter = token_count.find(candidates->data[i].id); - if (token_iter == token_count.end()) { - continue; - } - - const int count = token_iter->second; - - // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. - // This is common fix for this problem, which is to multiply by the penalty instead of dividing. - if (candidates->data[i].logit <= 0) { - candidates->data[i].logit *= penalty_repeat; - } else { - candidates->data[i].logit /= penalty_repeat; - } - - candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present; - } - - candidates->sorted = false; - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - +// TODO: rename to llama_grammar_... void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) { GGML_ASSERT(ctx); int64_t t_start_sample_us = ggml_time_us(); @@ -17517,7 +17006,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c candidates->data[reject.index].logit = -INFINITY; } - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + // TODO: change to t_grammar_us + ctx->sampling.t_sample_us += ggml_time_us() - t_start_sample_us; } static void llama_log_softmax(float * array, size_t size) { @@ -17534,158 +17024,6 @@ static void llama_log_softmax(float * array, size_t size) { } } -void llama_sample_apply_guidance( - struct llama_context * ctx, - float * logits, - float * logits_guidance, - float scale) { - GGML_ASSERT(ctx); - - const auto t_start_sample_us = ggml_time_us(); - const auto n_vocab = llama_n_vocab(llama_get_model(ctx)); - - llama_log_softmax(logits, n_vocab); - llama_log_softmax(logits_guidance, n_vocab); - - for (int i = 0; i < n_vocab; ++i) { - auto & l = logits[i]; - const auto & g = logits_guidance[i]; - - l = scale * (l - g) + g; - } - - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; -} - -llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { - GGML_ASSERT(ctx); - - auto N = float(llama_n_vocab(llama_get_model(ctx))); - int64_t t_start_sample_us; - t_start_sample_us = ggml_time_us(); - - llama_sample_softmax(nullptr, candidates); - - // Estimate s_hat using the most probable m tokens - float s_hat = 0.0; - float sum_ti_bi = 0.0; - float sum_ti_sq = 0.0; - for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) { - float t_i = logf(float(i + 2) / float(i + 1)); - float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p); - sum_ti_bi += t_i * b_i; - sum_ti_sq += t_i * t_i; - } - s_hat = sum_ti_bi / sum_ti_sq; - - // Compute k from the estimated s_hat and target surprise value - float epsilon_hat = s_hat - 1; - float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat); - - // Sample the next word X using top-k sampling - llama_sample_top_k(nullptr, candidates, int(k), 1); - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - llama_token X = llama_sample_token(ctx, candidates); - t_start_sample_us = ggml_time_us(); - - // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { - return candidate.id == X; - })); - float observed_surprise = -log2f(candidates->data[X_idx].p); - float e = observed_surprise - tau; - - // Update mu using the learning rate and error - *mu = *mu - eta * e; - - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - return X; -} - -llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) { - int64_t t_start_sample_us; - t_start_sample_us = ggml_time_us(); - - llama_sample_softmax(ctx, candidates); - - // Truncate the words with surprise values greater than mu - candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { - return -log2f(candidate.p) > *mu; - })); - - if (candidates->size == 0) { - candidates->size = 1; - } - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } - - // Normalize the probabilities of the remaining words - llama_sample_softmax(ctx, candidates); - - // Sample the next word X from the remaining words - llama_token X = llama_sample_token(ctx, candidates); - t_start_sample_us = ggml_time_us(); - - // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { - return candidate.id == X; - })); - float observed_surprise = -log2f(candidates->data[X_idx].p); - float e = observed_surprise - tau; - - // Update mu using the learning rate and error - *mu = *mu - eta * e; - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } - return X; -} - -llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) { - const int64_t t_start_sample_us = ggml_time_us(); - - // Find max element - auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { - return a.logit < b.logit; - }); - - llama_token result = max_iter->id; - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - ctx->n_sample++; - } - return result; -} - -llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) { - GGML_ASSERT(ctx); - - const int64_t t_start_sample_us = ggml_time_us(); - llama_sample_softmax(nullptr, candidates); - - std::vector probs; - probs.reserve(candidates->size); - for (size_t i = 0; i < candidates->size; ++i) { - probs.push_back(candidates->data[i].p); - } - - std::discrete_distribution<> dist(probs.begin(), probs.end()); - int idx = dist(rng); - - llama_token result = candidates->data[idx].id; - - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - ctx->n_sample++; - return result; -} - -llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { - return llama_sample_token_with_rng(ctx, candidates, ctx->rng); -} - void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) { const int64_t t_start_sample_us = ggml_time_us(); @@ -17711,7 +17049,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar grammar->partial_utf8 = decoded.second; GGML_ASSERT(!grammar->stacks.empty()); - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + ctx->sampling.t_sample_us += ggml_time_us() - t_start_sample_us; } // @@ -19096,8 +18434,8 @@ struct llama_context * llama_new_context_with_model( ctx->abort_callback = params.abort_callback; ctx->abort_callback_data = params.abort_callback_data; - ctx->rng = std::mt19937(params.seed); - ctx->logits_all = params.logits_all; + ctx->sampling.rng = std::mt19937(params.seed); + ctx->logits_all = params.logits_all; uint32_t kv_size = cparams.n_ctx; ggml_type type_k = params.type_k; @@ -19349,10 +18687,14 @@ void llama_free(struct llama_context * ctx) { delete ctx; } -const llama_model * llama_get_model(const struct llama_context * ctx) { +const struct llama_model * llama_get_model(const struct llama_context * ctx) { return &ctx->model; } +struct llama_sampling * llama_get_sampling(struct llama_context * ctx) { + return &ctx->sampling; +} + uint32_t llama_n_ctx(const struct llama_context * ctx) { return ctx->cparams.n_ctx; } @@ -19941,7 +19283,7 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data // copy rng { std::ostringstream rng_ss; - rng_ss << ctx->rng; + rng_ss << ctx->sampling.rng; const std::string & rng_str = rng_ss.str(); const size_t rng_size = rng_str.size(); @@ -20107,7 +19449,7 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { std::string rng_str((const char *)inp, rng_size); inp += rng_size; std::istringstream rng_ss(rng_str); - rng_ss >> ctx->rng; + rng_ss >> ctx->sampling.rng; GGML_ASSERT(!rng_ss.fail()); } @@ -21678,11 +21020,11 @@ struct llama_timings llama_get_timings(struct llama_context * ctx) { /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, /*.t_end_ms =*/ 1.00 * ggml_time_ms(), /*.t_load_ms =*/ 1e-3 * ctx->t_load_us, - /*.t_sample_ms =*/ 1e-3 * ctx->t_sample_us, + /*.t_sample_ms =*/ 1e-3 * ctx->sampling.t_sample_us, /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us, /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us, - /*.n_sample =*/ std::max(1, ctx->n_sample), + /*.n_sample =*/ std::max(1, ctx->sampling.n_sample), /*.n_p_eval =*/ std::max(0, ctx->n_p_eval), /*.n_eval =*/ std::max(1, ctx->n_eval), }; @@ -21705,10 +21047,11 @@ void llama_print_timings(struct llama_context * ctx) { } void llama_reset_timings(struct llama_context * ctx) { - ctx->t_start_us = ggml_time_us(); - ctx->t_sample_us = ctx->n_sample = 0; + ctx->t_start_us = ggml_time_us(); ctx->t_eval_us = ctx->n_eval = 0; ctx->t_p_eval_us = ctx->n_p_eval = 0; + + ctx->sampling.reset_timings(); } const char * llama_print_system_info(void) { @@ -21755,20 +21098,20 @@ void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) { fprintf(stream, "mst_p_eval: %.2f # ms / token during prompt processing\n", 1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval); fprintf(stream, "mst_sample: %.2f # ms / token during sampling\n", - 1.0e-3 * ctx->t_sample_us / ctx->n_sample); + 1.0e-3 * ctx->sampling.t_sample_us / ctx->sampling.n_sample); fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval); fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval); - fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->n_sample); + fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->sampling.n_sample); fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us); fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us); fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us); - fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->t_sample_us); + fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->sampling.t_sample_us); fprintf(stream, "ts_eval: %.2f # tokens / second during generation\n", 1.0e6 * ctx->n_eval / ctx->t_eval_us); fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n", 1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us); fprintf(stream, "ts_sample: %.2f # tokens / second during sampling\n", - 1.0e6 * ctx->n_sample / ctx->t_sample_us); + 1.0e6 * ctx->sampling.n_sample / ctx->sampling.t_sample_us); } // For internal test use @@ -21805,14 +21148,14 @@ static void llama_log_internal_v(ggml_log_level level, const char * format, va_l va_end(args_copy); } -static void llama_log_internal(ggml_log_level level, const char * format, ...) { +void llama_log_internal(ggml_log_level level, const char * format, ...) { va_list args; va_start(args, format); llama_log_internal_v(level, format, args); va_end(args); } -static void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data) { +void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data) { (void) level; (void) user_data; fputs(text, stderr);