From 3ee401c66dfa38d4b467ef8afe4e3e57ce9d3d52 Mon Sep 17 00:00:00 2001 From: slaren Date: Sun, 8 Sep 2024 15:52:07 +0200 Subject: [PATCH] llama : refactor samplers internal implementation (#9370) --- src/llama-impl.h | 4 + src/llama-sampling.cpp | 1502 ++++++++++++++++++++++----------------- src/llama-sampling.h | 10 - tests/test-sampling.cpp | 10 +- 4 files changed, 841 insertions(+), 685 deletions(-) diff --git a/src/llama-impl.h b/src/llama-impl.h index fa2e09e1f688e..87012617feed1 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -101,6 +101,10 @@ struct ring_buffer { } void push_back(const T & value) { + if (capacity == 0) { + throw std::runtime_error("ring buffer: capacity is zero"); + } + if (sz == capacity) { // advance the start when buffer is full first = (first + 1) % capacity; diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 1661d9a83ec80..41f48ec286779 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -13,12 +13,40 @@ #include static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng, std::vector & probs) { +#if 1 probs.resize(cur_p->size); for (size_t i = 0; i < cur_p->size; ++i) { probs[i] = cur_p->data[i].p; } std::discrete_distribution dist(probs.begin(), probs.end()); +#else + // avoid the copy with a custom iterator + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wunused-local-typedefs" + + struct probs_iterator { + typedef std::input_iterator_tag iterator_category; + typedef float value_type; + typedef float * pointer; + typedef float & reference; + typedef size_t difference_type; + + const llama_token_data_array * data; + size_t i; + + bool operator==(const probs_iterator & other) const { return data + i == other.data + other.i; } + bool operator!=(const probs_iterator & other) const { return data + i != other.data + other.i; } + float operator*() const { return data->data[i].p; } + probs_iterator & operator++() { ++i; return *this; } + probs_iterator operator++(int) { probs_iterator tmp = *this; ++i; return tmp; } + }; + #pragma GCC diagnostic pop + + std::discrete_distribution dist(probs_iterator{cur_p, 0}, probs_iterator{cur_p, cur_p->size}); + + GGML_UNUSED(probs); +#endif return dist(rng); } @@ -138,301 +166,6 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) cur_p->size = k; } -static void llama_sampler_top_p_impl(llama_token_data_array * cur_p, float p, size_t min_keep) { - if (p >= 1.0f) { - return; - } - - llama_sampler_softmax_impl(cur_p); - - // Compute the cumulative probabilities - float cum_sum = 0.0f; - size_t last_idx = cur_p->size; - - for (size_t i = 0; i < cur_p->size; ++i) { - cum_sum += cur_p->data[i].p; - - // Check if the running sum is at least p or if we have kept at least min_keep tokens - // we set the last index to i+1 to indicate that the current iterate should be included in the set - if (cum_sum >= p && i + 1 >= min_keep) { - last_idx = i + 1; - break; - } - } - - // Resize the output vector to keep only the top-p tokens - cur_p->size = last_idx; -} - -static void llama_sampler_min_p_impl(llama_token_data_array * cur_p, float p, size_t min_keep) { - if (p <= 0.0f || !cur_p->size) { - return; - } - - bool min_p_applied = false; - - // if the cur_p aren't sorted, try the unsorted implementation first - if (!cur_p->sorted) { - std::vector filtered_tokens; - - float max_logit = -FLT_MAX; - for (size_t i = 0; i < cur_p->size; ++i) { - max_logit = std::max(max_logit, cur_p->data[i].logit); - } - const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max - - for (size_t i = 0; i < cur_p->size; ++i) { - if (cur_p->data[i].logit >= min_logit) { - filtered_tokens.push_back(cur_p->data[i]); - } - } - - // if we have enough values the operation was a success - if (filtered_tokens.size() >= min_keep) { - memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data)); - cur_p->size = filtered_tokens.size(); - min_p_applied = true; - } - } - - // if the cur_p are sorted or the unsorted implementation failed, use this implementation - if (!min_p_applied) { - // Sort the logits in descending order - if (!cur_p->sorted) { - std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }); - cur_p->sorted = true; - } - - const float min_logit = cur_p->data[0].logit + logf(p); // min logit for p_i >= p * p_max - size_t i = 1; // first token always matches - - for (; i < cur_p->size; ++i) { - if (cur_p->data[i].logit < min_logit && i >= min_keep) { - break; // prob too small - } - } - - // Resize the output vector to keep only the matching tokens - cur_p->size = i; - } -} - -static void llama_sampler_tail_free_impl(llama_token_data_array * cur_p, float z, size_t min_keep) { - if (z >= 1.0f || cur_p->size <= 2) { - return; - } - - llama_sampler_softmax_impl(cur_p); - - // Compute the first and second derivatives - std::vector first_derivatives(cur_p->size - 1); - std::vector second_derivatives(cur_p->size - 2); - - for (size_t i = 0; i < first_derivatives.size(); ++i) { - first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p; - } - for (size_t i = 0; i < second_derivatives.size(); ++i) { - second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1]; - } - - // Calculate absolute value of second derivatives - for (size_t i = 0; i < second_derivatives.size(); ++i) { - second_derivatives[i] = std::abs(second_derivatives[i]); - } - - // Normalize the second derivatives - { - const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f); - - if (second_derivatives_sum > 1e-6f) { - for (float & value : second_derivatives) { - value /= second_derivatives_sum; - } - } else { - for (float & value : second_derivatives) { - value = 1.0f / second_derivatives.size(); - } - } - } - - float cum_sum = 0.0f; - size_t last_idx = cur_p->size; - for (size_t i = 0; i < second_derivatives.size(); ++i) { - cum_sum += second_derivatives[i]; - - // Check if the running sum is greater than z or if we have kept at least min_keep tokens - if (cum_sum > z && i >= min_keep) { - last_idx = i; - break; - } - } - - // Resize the output vector to keep only the tokens above the tail location - cur_p->size = last_idx; -} - -static void llama_sampler_typical_impl(llama_token_data_array * cur_p, float p, size_t min_keep) { - // Reference implementation: - // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr - if (p >= 1.0f) { - return; - } - - // Compute the softmax of logits and calculate entropy - llama_sampler_softmax_impl(cur_p); - - float entropy = 0.0f; - for (size_t i = 0; i < cur_p->size; ++i) { - entropy += -cur_p->data[i].p * logf(cur_p->data[i].p); - } - - // Compute the absolute difference between negative log probability and entropy for each candidate - std::vector shifted_scores; - for (size_t i = 0; i < cur_p->size; ++i) { - float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy); - shifted_scores.push_back(shifted_score); - } - - // Sort tokens based on the shifted_scores and their corresponding indices - std::vector indices(cur_p->size); - std::iota(indices.begin(), indices.end(), 0); - - std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) { - return shifted_scores[a] < shifted_scores[b]; - }); - - // Compute the cumulative probabilities - float cum_sum = 0.0f; - size_t last_idx = indices.size(); - - for (size_t i = 0; i < indices.size(); ++i) { - size_t idx = indices[i]; - cum_sum += cur_p->data[idx].p; - - // Check if the running sum is greater than typical or if we have kept at least min_keep tokens - if (cum_sum > p && i >= min_keep - 1) { - last_idx = i + 1; - break; - } - } - - // Resize the output vector to keep only the locally typical tokens - std::vector cur_p_new; - for (size_t i = 0; i < last_idx; ++i) { - size_t idx = indices[i]; - cur_p_new.push_back(cur_p->data[idx]); - } - - // Replace the data in cur_p with the cur_p_new data - std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data); - cur_p->size = cur_p_new.size(); - cur_p->sorted = false; -} - -static void llama_sampler_entropy_impl(llama_token_data_array * cur_p, float min_temp, float max_temp, float exponent_val) { - // no need to do anything if there is only one (or zero) candidates - if (cur_p->size <= 1) { - return; - } - - // Calculate maximum possible entropy - float max_entropy = -logf(1.0f / cur_p->size); - - llama_sampler_softmax_impl(cur_p); - - // Calculate entropy of the softmax probabilities - float entropy = 0.0f; - for (size_t i = 0; i < cur_p->size; ++i) { - float prob = cur_p->data[i].p; - if (prob > 0.0f) { // Ensure no log(0) - entropy -= prob * logf(prob); - } - } - - // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above) - float normalized_entropy = entropy / max_entropy; - - // Map the normalized entropy to the desired temperature range using the power function - float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val); - -#ifdef DEBUG - LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp); - LLAMA_LOG_INFO("Entropy: %f\n", entropy); - LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy); - LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy); - LLAMA_LOG_INFO("Exponent: %f\n", exponent_val); - LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp); -#endif - - // Apply the dynamically calculated temperature scaling - for (size_t i = 0; i < cur_p->size; ++i) { - cur_p->data[i].logit /= dyn_temp; - } - - // Re-compute softmax probabilities after scaling logits with dynamic temperature - const double max_l_double = cur_p->data[0].logit; - - double cum_sum_double = 0.0; - for (size_t i = 0; i < cur_p->size; ++i) { - double p = exp(cur_p->data[i].logit - max_l_double); - cur_p->data[i].p = p; // Store the scaled probability - cum_sum_double += p; - } - - for (size_t i = 0; i < cur_p->size; ++i) { - cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities - } - -#ifdef DEBUG - // Print the updated top 25 probabilities after temperature scaling - LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n"); - for (size_t i = 0; i < 25 && i < cur_p->size; ++i) { - LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f); - } -#endif -} - -static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) { - for (size_t i = 0; i < cur_p->size; ++i) { - cur_p->data[i].logit /= temp; - } -} - -static void llama_sampler_grammar_impl(llama_token_data_array * cur_p, const struct llama_grammar & grammar) { - llama_grammar_apply_impl(grammar, cur_p); -} - -void llama_sampler_penalties_impl( - llama_token_data_array * cur_p, - const llama_token_cnt & token_count, - float penalty_repeat, - float penalty_freq, - float penalty_present) { - // Apply frequency and presence penalties to the cur_p - for (size_t i = 0; i < cur_p->size; ++i) { - const auto token_iter = token_count.find(cur_p->data[i].id); - if (token_iter == token_count.end()) { - continue; - } - - const int count = token_iter->second; - - // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. - // This is common fix for this problem, which is to multiply by the penalty instead of dividing. - if (cur_p->data[i].logit <= 0) { - cur_p->data[i].logit *= penalty_repeat; - } else { - cur_p->data[i].logit /= penalty_repeat; - } - - cur_p->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present; - } - - cur_p->sorted = false; -} - // llama_sampler API const char * llama_sampler_name(const struct llama_sampler * smpl) { @@ -600,17 +333,23 @@ int llama_sampler_chain_n(const struct llama_sampler * chain) { // greedy +static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) { + return "greedy"; +} + +static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) { + cur_p->selected = 0; + for (size_t i = 1; i < cur_p->size; ++i) { + if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) { + cur_p->selected = i; + } + } +} + static struct llama_sampler_i llama_sampler_greedy_i = { - /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "greedy"; }, + /* .name = */ llama_sampler_greedy_name, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) { - cur_p->selected = 0; - for (size_t i = 1; i < cur_p->size; ++i) { - if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) { - cur_p->selected = i; - } - } - }, + /* .apply = */ llama_sampler_greedy_apply, /* .reset = */ nullptr, /* .clone = */ nullptr, /* .free = */ nullptr, @@ -633,30 +372,45 @@ struct llama_sampler_dist { std::vector probs; // work array }; -static struct llama_sampler_i llama_sampler_dist_i = { - /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "dist"; }, - /* .accept = */ nullptr, - /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - auto * ctx = (llama_sampler_dist *) smpl->ctx; - cur_p->selected = llama_sample_dist(cur_p, ctx->rng, ctx->probs); - }, - /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_dist *) smpl->ctx; - auto * result = llama_sampler_init_dist(ctx->seed); +static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) { + return "dist"; +} - // copy the state - { - auto * result_ctx = (llama_sampler_dist *) result->ctx; +static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_dist *) smpl->ctx; + cur_p->selected = llama_sample_dist(cur_p, ctx->rng, ctx->probs); +} - result_ctx->rng = ctx->rng; - } +static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_dist *) smpl->ctx; + auto * result = llama_sampler_init_dist(ctx->seed); - return result; - }, - /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_dist *) smpl->ctx; - }, + // copy the state + { + auto * result_ctx = (llama_sampler_dist *) result->ctx; + + result_ctx->rng = ctx->rng; + } + + return result; +} + +static void llama_sampler_dist_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_dist *) smpl->ctx; + ctx->rng = std::mt19937(ctx->seed); +} + +static void llama_sampler_dist_free(struct llama_sampler * smpl) { + delete (llama_sampler_dist *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_dist_i = { + /* .name = */ llama_sampler_dist_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_dist_apply, + /* .reset = */ llama_sampler_dist_reset, + /* .clone = */ llama_sampler_dist_clone, + /* .free = */ llama_sampler_dist_free, }; struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { @@ -672,12 +426,18 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { // softmax +static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) { + return "softmax"; +} + +static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) { + llama_sampler_softmax_impl(cur_p); +} + static struct llama_sampler_i llama_sampler_softmax_i = { - /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "softmax"; }, + /* .name = */ llama_sampler_softmax_name, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) { - llama_sampler_softmax_impl(cur_p); - }, + /* .apply = */ llama_sampler_softmax_apply, /* .reset = */ nullptr, /* .clone = */ nullptr, /* .free = */ nullptr, @@ -696,21 +456,31 @@ struct llama_sampler_top_k { const int32_t k; }; +static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) { + return "top-k"; +} + +static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_top_k *) smpl->ctx; + llama_sampler_top_k_impl(cur_p, ctx->k); +} + +static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_top_k *) smpl->ctx; + return llama_sampler_init_top_k(ctx->k); +} + +static void llama_sampler_top_k_free(struct llama_sampler * smpl) { + delete (llama_sampler_top_k *) smpl->ctx; +} + static struct llama_sampler_i llama_sampler_top_k_i = { - /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "top-k"; }, + /* .name = */ llama_sampler_top_k_name, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_top_k *) smpl->ctx; - llama_sampler_top_k_impl(cur_p, ctx->k); - }, + /* .apply = */ llama_sampler_top_k_apply, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_top_k *) smpl->ctx; - return llama_sampler_init_top_k(ctx->k); - }, - /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_top_k *) smpl->ctx; - }, + /* .clone = */ llama_sampler_top_k_clone, + /* .free = */ llama_sampler_top_k_free, }; struct llama_sampler * llama_sampler_init_top_k(int32_t k) { @@ -729,21 +499,54 @@ struct llama_sampler_top_p { const size_t min_keep; }; +static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) { + return "top-p"; +} + +static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_top_p *) smpl->ctx; + + if (ctx->p >= 1.0f) { + return; + } + + llama_sampler_softmax_impl(cur_p); + + // Compute the cumulative probabilities + float cum_sum = 0.0f; + size_t last_idx = cur_p->size; + + for (size_t i = 0; i < cur_p->size; ++i) { + cum_sum += cur_p->data[i].p; + + // Check if the running sum is at least p or if we have kept at least min_keep tokens + // we set the last index to i+1 to indicate that the current iterate should be included in the set + if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) { + last_idx = i + 1; + break; + } + } + + // Resize the output vector to keep only the top-p tokens + cur_p->size = last_idx; +} + +static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_top_p *) smpl->ctx; + return llama_sampler_init_top_p(ctx->p, ctx->min_keep); +} + +static void llama_sampler_top_p_free(struct llama_sampler * smpl) { + delete (llama_sampler_top_p *) smpl->ctx; +} + static struct llama_sampler_i llama_sampler_top_p_i = { - /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "top-p"; }, + /* .name = */ llama_sampler_top_p_name, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_top_p *) smpl->ctx; - llama_sampler_top_p_impl(cur_p, ctx->p, ctx->min_keep); - }, + /* .apply = */ llama_sampler_top_p_apply, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_top_p *) smpl->ctx; - return llama_sampler_init_top_p(ctx->p, ctx->min_keep); - }, - /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_top_p *) smpl->ctx; - }, + /* .clone = */ llama_sampler_top_p_clone, + /* .free = */ llama_sampler_top_p_free, }; struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { @@ -763,21 +566,83 @@ struct llama_sampler_min_p { const size_t min_keep; }; +static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) { + return "min-p"; +} + +static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_min_p *) smpl->ctx; + + if (ctx->p <= 0.0f || !cur_p->size) { + return; + } + + bool min_p_applied = false; + + // if the cur_p aren't sorted, try the unsorted implementation first + if (!cur_p->sorted) { + std::vector filtered_tokens; + + float max_logit = -FLT_MAX; + for (size_t i = 0; i < cur_p->size; ++i) { + max_logit = std::max(max_logit, cur_p->data[i].logit); + } + const float min_logit = max_logit + logf(ctx->p); // min logit for p_i >= p * p_max + + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].logit >= min_logit) { + filtered_tokens.push_back(cur_p->data[i]); + } + } + + // if we have enough values the operation was a success + if (filtered_tokens.size() >= ctx->min_keep) { + memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data)); + cur_p->size = filtered_tokens.size(); + min_p_applied = true; + } + } + + // if the cur_p are sorted or the unsorted implementation failed, use this implementation + if (!min_p_applied) { + // Sort the logits in descending order + if (!cur_p->sorted) { + std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + cur_p->sorted = true; + } + + const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max + size_t i = 1; // first token always matches + + for (; i < cur_p->size; ++i) { + if (cur_p->data[i].logit < min_logit && i >= ctx->min_keep) { + break; // prob too small + } + } + + // Resize the output vector to keep only the matching tokens + cur_p->size = i; + } +} + +static struct llama_sampler * llama_sampler_min_p_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_min_p *) smpl->ctx; + return llama_sampler_init_min_p(ctx->p, ctx->min_keep); +} + +static void llama_sampler_min_p_free(struct llama_sampler * smpl) { + delete (llama_sampler_min_p *) smpl->ctx; +} + static struct llama_sampler_i llama_sampler_min_p_i = { - /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "min-p"; }, + /* .name = */ llama_sampler_min_p_name, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_min_p *) smpl->ctx; - llama_sampler_min_p_impl(cur_p, ctx->p, ctx->min_keep); - }, + /* .apply = */ llama_sampler_min_p_apply, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_min_p *) smpl->ctx; - return llama_sampler_init_min_p(ctx->p, ctx->min_keep); - }, - /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_min_p *) smpl->ctx; - }, + /* .clone = */ llama_sampler_min_p_clone, + /* .free = */ llama_sampler_min_p_free, }; struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { @@ -797,21 +662,82 @@ struct llama_sampler_tail_free { const size_t min_keep; }; +static const char * llama_sampler_tail_free_name(const struct llama_sampler * /*smpl*/) { + return "tail-free"; +} + +static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_tail_free *) smpl->ctx; + + if (ctx->z >= 1.0f || cur_p->size <= 2) { + return; + } + + llama_sampler_softmax_impl(cur_p); + + // Compute the first and second derivatives + std::vector first_derivatives(cur_p->size - 1); + std::vector second_derivatives(cur_p->size - 2); + + for (size_t i = 0; i < first_derivatives.size(); ++i) { + first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p; + } + for (size_t i = 0; i < second_derivatives.size(); ++i) { + second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1]; + } + + // Calculate absolute value of second derivatives + for (size_t i = 0; i < second_derivatives.size(); ++i) { + second_derivatives[i] = std::abs(second_derivatives[i]); + } + + // Normalize the second derivatives + { + const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f); + + if (second_derivatives_sum > 1e-6f) { + for (float & value : second_derivatives) { + value /= second_derivatives_sum; + } + } else { + for (float & value : second_derivatives) { + value = 1.0f / second_derivatives.size(); + } + } + } + + float cum_sum = 0.0f; + size_t last_idx = cur_p->size; + for (size_t i = 0; i < second_derivatives.size(); ++i) { + cum_sum += second_derivatives[i]; + + // Check if the running sum is greater than z or if we have kept at least min_keep tokens + if (cum_sum > ctx->z && i >= ctx->min_keep) { + last_idx = i; + break; + } + } + + // Resize the output vector to keep only the tokens above the tail location + cur_p->size = last_idx; +} + +static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx; + return llama_sampler_init_tail_free(ctx->z, ctx->min_keep); +} + +static void llama_sampler_tail_free_free(struct llama_sampler * smpl) { + delete (llama_sampler_tail_free *) smpl->ctx; +} + static struct llama_sampler_i llama_sampler_tail_free_i = { - /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "tail-free"; }, + /* .name = */ llama_sampler_tail_free_name, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_tail_free *) smpl->ctx; - llama_sampler_tail_free_impl(cur_p, ctx->z, ctx->min_keep); - }, + /* .apply = */ llama_sampler_tail_free_apply, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx; - return llama_sampler_init_tail_free(ctx->z, ctx->min_keep); - }, - /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_tail_free *) smpl->ctx; - }, + /* .clone = */ llama_sampler_tail_free_clone, + /* .free = */ llama_sampler_tail_free_free, }; struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) { @@ -831,21 +757,86 @@ struct llama_sampler_typical { const size_t min_keep; }; +static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) { + return "typical"; +} + +static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_typical *) smpl->ctx; + + // Reference implementation: + // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr + if (ctx->p >= 1.0f) { + return; + } + + // Compute the softmax of logits and calculate entropy + llama_sampler_softmax_impl(cur_p); + + float entropy = 0.0f; + for (size_t i = 0; i < cur_p->size; ++i) { + entropy += -cur_p->data[i].p * logf(cur_p->data[i].p); + } + + // Compute the absolute difference between negative log probability and entropy for each candidate + std::vector shifted_scores; + for (size_t i = 0; i < cur_p->size; ++i) { + float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy); + shifted_scores.push_back(shifted_score); + } + + // Sort tokens based on the shifted_scores and their corresponding indices + std::vector indices(cur_p->size); + std::iota(indices.begin(), indices.end(), 0); + + std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) { + return shifted_scores[a] < shifted_scores[b]; + }); + + // Compute the cumulative probabilities + float cum_sum = 0.0f; + size_t last_idx = indices.size(); + + for (size_t i = 0; i < indices.size(); ++i) { + size_t idx = indices[i]; + cum_sum += cur_p->data[idx].p; + + // Check if the running sum is greater than typical or if we have kept at least min_keep tokens + if (cum_sum > ctx->p && i >= ctx->min_keep - 1) { + last_idx = i + 1; + break; + } + } + + // Resize the output vector to keep only the locally typical tokens + std::vector cur_p_new; + for (size_t i = 0; i < last_idx; ++i) { + size_t idx = indices[i]; + cur_p_new.push_back(cur_p->data[idx]); + } + + // Replace the data in cur_p with the cur_p_new data + std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data); + cur_p->size = cur_p_new.size(); + cur_p->sorted = false; +} + +static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_typical *) smpl->ctx; + return llama_sampler_init_typical(ctx->p, ctx->min_keep); +} + +static void llama_sampler_typical_free(struct llama_sampler * smpl) { + delete (llama_sampler_typical *) smpl->ctx; +} + static struct llama_sampler_i llama_sampler_typical_i = { - /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "typical"; }, + /* .name = */ llama_sampler_typical_name, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_typical *) smpl->ctx; - llama_sampler_typical_impl(cur_p, ctx->p, ctx->min_keep); - }, + /* .apply = */ llama_sampler_typical_apply, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_typical *) smpl->ctx; - return llama_sampler_init_typical(ctx->p, ctx->min_keep); - }, - /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_typical *) smpl->ctx; - }, + /* .clone = */ llama_sampler_typical_clone, + /* .free = */ llama_sampler_typical_free, }; struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { @@ -858,27 +849,39 @@ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { }; } -// temp +// temp + +struct llama_sampler_temp { + const float temp; +}; + +static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) { + return "temp"; +} + +static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_temp *) smpl->ctx; + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].logit /= ctx->temp; + } +} + +static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_temp *) smpl->ctx; + return llama_sampler_init_temp(ctx->temp); +} -struct llama_sampler_temp { - const float temp; -}; +static void llama_sampler_temp_free(struct llama_sampler * smpl) { + delete (llama_sampler_temp *) smpl->ctx; +} static struct llama_sampler_i llama_sampler_temp_i = { - /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "temp"; }, + /* .name = */ llama_sampler_temp_name, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_temp *) smpl->ctx; - llama_sampler_temp_impl(cur_p, ctx->temp); - }, + /* .apply = */ llama_sampler_temp_apply, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_temp *) smpl->ctx; - return llama_sampler_init_temp(ctx->temp); - }, - /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_temp *) smpl->ctx; - }, + /* .clone = */ llama_sampler_temp_clone, + /* .free = */ llama_sampler_temp_free, }; struct llama_sampler * llama_sampler_init_temp(float temp) { @@ -898,28 +901,100 @@ struct llama_sampler_temp_ext { const float exponent; }; -static struct llama_sampler_i llama_sampler_temp_ext_i = { - /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "temp-ext"; }, - /* .accept = */ nullptr, - /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx; - if (ctx->delta > 0) { - const float temp_min = std::max(0.0f, ctx->temp - ctx->delta); - const float temp_max = ctx->temp + ctx->delta; +static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) { + return "temp-ext"; +} - llama_sampler_entropy_impl(cur_p, temp_min, temp_max, ctx->exponent); - } else { - llama_sampler_temp_impl(cur_p, ctx->temp); +static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx; + if (ctx->delta > 0) { + const float min_temp = std::max(0.0f, ctx->temp - ctx->delta); + const float max_temp = ctx->temp + ctx->delta; + float exponent_val = ctx->exponent; + + // no need to do anything if there is only one (or zero) candidates + if (cur_p->size <= 1) { + return; } - }, + + // Calculate maximum possible entropy + float max_entropy = -logf(1.0f / cur_p->size); + + llama_sampler_softmax_impl(cur_p); + + // Calculate entropy of the softmax probabilities + float entropy = 0.0f; + for (size_t i = 0; i < cur_p->size; ++i) { + float prob = cur_p->data[i].p; + if (prob > 0.0f) { // Ensure no log(0) + entropy -= prob * logf(prob); + } + } + + // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above) + float normalized_entropy = entropy / max_entropy; + + // Map the normalized entropy to the desired temperature range using the power function + float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val); + + #ifdef DEBUG + LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp); + LLAMA_LOG_INFO("Entropy: %f\n", entropy); + LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy); + LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy); + LLAMA_LOG_INFO("Exponent: %f\n", exponent_val); + LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp); + #endif + + // Apply the dynamically calculated temperature scaling + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].logit /= dyn_temp; + } + + // Re-compute softmax probabilities after scaling logits with dynamic temperature + const double max_l_double = cur_p->data[0].logit; + + double cum_sum_double = 0.0; + for (size_t i = 0; i < cur_p->size; ++i) { + double p = exp(cur_p->data[i].logit - max_l_double); + cur_p->data[i].p = p; // Store the scaled probability + cum_sum_double += p; + } + + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities + } + + #ifdef DEBUG + // Print the updated top 25 probabilities after temperature scaling + LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n"); + for (size_t i = 0; i < 25 && i < cur_p->size; ++i) { + LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f); + } + #endif + } else { + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].logit /= ctx->temp; + } + } +} + +static struct llama_sampler * llama_sampler_temp_ext_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx; + return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent); +} + +static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) { + delete (llama_sampler_temp_ext *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_temp_ext_i = { + /* .name = */ llama_sampler_temp_ext_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_temp_ext_apply, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx; - return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent); - }, - /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_temp_ext *) smpl->ctx; - }, + /* .clone = */ llama_sampler_temp_ext_clone, + /* .free = */ llama_sampler_temp_ext_free, }; struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { @@ -952,65 +1027,77 @@ struct llama_sampler_mirostat { std::vector probs; }; -static struct llama_sampler_i llama_sampler_mirostat_i = { - /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat"; }, - /* .accept = */ nullptr, - /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - auto * ctx = (llama_sampler_mirostat *) smpl->ctx; +static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) { + return "mirostat"; +} - llama_sampler_softmax_impl(cur_p); +static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_mirostat *) smpl->ctx; - // Estimate s_hat using the most probable m tokens - float s_hat = 0.0; - float sum_ti_bi = 0.0; - float sum_ti_sq = 0.0; - for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) { - float t_i = logf(float(i + 2) / float(i + 1)); - float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p); - sum_ti_bi += t_i * b_i; - sum_ti_sq += t_i * t_i; - } - s_hat = sum_ti_bi / sum_ti_sq; + llama_sampler_softmax_impl(cur_p); - // Compute k from the estimated s_hat and target surprise value - float epsilon_hat = s_hat - 1; - float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat); + // Estimate s_hat using the most probable m tokens + float s_hat = 0.0; + float sum_ti_bi = 0.0; + float sum_ti_sq = 0.0; + for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) { + float t_i = logf(float(i + 2) / float(i + 1)); + float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p); + sum_ti_bi += t_i * b_i; + sum_ti_sq += t_i * t_i; + } + s_hat = sum_ti_bi / sum_ti_sq; - llama_sampler_top_k_impl(cur_p, std::max(int(k), 1)); - llama_sampler_softmax_impl(cur_p); + // Compute k from the estimated s_hat and target surprise value + float epsilon_hat = s_hat - 1; + float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat); - const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs); + llama_sampler_top_k_impl(cur_p, std::max(int(k), 1)); + llama_sampler_softmax_impl(cur_p); - cur_p->selected = idx; + const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs); - float observed_surprise = -log2f(cur_p->data[idx].p); - float e = observed_surprise - ctx->tau; + cur_p->selected = idx; - // Update mu using the learning rate and error - ctx->mu = ctx->mu - ctx->eta * e; - }, - /* .reset = */ [](struct llama_sampler * smpl) { - auto * ctx = (llama_sampler_mirostat *) smpl->ctx; - ctx->mu = 2.0f*ctx->tau; - ctx->rng = std::mt19937(ctx->seed); - }, - /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx; - auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m); + float observed_surprise = -log2f(cur_p->data[idx].p); + float e = observed_surprise - ctx->tau; - // copy the state - { - auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx; + // Update mu using the learning rate and error + ctx->mu = ctx->mu - ctx->eta * e; +} - result_ctx->mu = ctx->mu; - result_ctx->rng = ctx->rng; - } +static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx; + auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m); - return result; - }, - /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_mirostat *) smpl->ctx; - }, + // copy the state + { + auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx; + + result_ctx->mu = ctx->mu; + result_ctx->rng = ctx->rng; + } + + return result; +} + +static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_mirostat *) smpl->ctx; + ctx->mu = 2.0f*ctx->tau; + ctx->rng = std::mt19937(ctx->seed); +} + +static void llama_sampler_mirostat_free(struct llama_sampler * smpl) { + delete (llama_sampler_mirostat *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_mirostat_i = { + /* .name = */ llama_sampler_mirostat_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_apply, + /* .reset = */ llama_sampler_mirostat_reset, + /* .clone = */ llama_sampler_mirostat_clone, + /* .free = */ llama_sampler_mirostat_free, }; struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { @@ -1044,59 +1131,71 @@ struct llama_sampler_mirostat_v2 { std::vector probs; }; -static struct llama_sampler_i llama_sampler_mirostat_v2_i = { - /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat-v2"; }, - /* .accept = */ nullptr, - /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx; +static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) { + return "mirostat-v2"; +} - llama_sampler_softmax_impl(cur_p); +static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx; - // Truncate the words with surprise values greater than mu - cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) { - return -log2f(candidate.p) > ctx->mu; - })); + llama_sampler_softmax_impl(cur_p); - if (cur_p->size == 0) { - cur_p->size = 1; - } + // Truncate the words with surprise values greater than mu + cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) { + return -log2f(candidate.p) > ctx->mu; + })); - // Normalize the probabilities of the remaining words - llama_sampler_softmax_impl(cur_p); + if (cur_p->size == 0) { + cur_p->size = 1; + } - const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs); + // Normalize the probabilities of the remaining words + llama_sampler_softmax_impl(cur_p); - cur_p->selected = idx; + const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs); - float observed_surprise = -log2f(cur_p->data[idx].p); - float e = observed_surprise - ctx->tau; + cur_p->selected = idx; - // Update mu using the learning rate and error - ctx->mu = ctx->mu - ctx->eta * e; - }, - /* .reset = */ [](struct llama_sampler * smpl) { - auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx; - ctx->mu = 2.0f*ctx->tau; - ctx->rng = std::mt19937(ctx->seed); - }, - /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx; + float observed_surprise = -log2f(cur_p->data[idx].p); + float e = observed_surprise - ctx->tau; - auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta); + // Update mu using the learning rate and error + ctx->mu = ctx->mu - ctx->eta * e; +} - // copy the state - { - auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx; +static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx; + ctx->mu = 2.0f*ctx->tau; + ctx->rng = std::mt19937(ctx->seed); +} - result_ctx->mu = ctx->mu; - result_ctx->rng = ctx->rng; - } +static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx; - return result; - }, - /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_mirostat_v2 *) smpl->ctx; - }, + auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta); + + // copy the state + { + auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx; + + result_ctx->mu = ctx->mu; + result_ctx->rng = ctx->rng; + } + + return result; +} + +static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) { + delete (llama_sampler_mirostat_v2 *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_mirostat_v2_i = { + /* .name = */ llama_sampler_mirostat_v2_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_v2_apply, + /* .reset = */ llama_sampler_mirostat_v2_reset, + /* .clone = */ llama_sampler_mirostat_v2_clone, + /* .free = */ llama_sampler_mirostat_v2_free, }; struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { @@ -1124,59 +1223,73 @@ struct llama_sampler_grammar { struct llama_grammar * grammar; }; -static struct llama_sampler_i llama_sampler_grammar_i = { - /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "grammar"; }, - /* .accept = */ [](struct llama_sampler * smpl, llama_token token) { - const auto * ctx = (llama_sampler_grammar *) smpl->ctx; - if (ctx->grammar) { - llama_grammar_accept_impl(*ctx->grammar, token); - } - }, - /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_grammar *) smpl->ctx; - if (ctx->grammar) { - llama_sampler_grammar_impl(cur_p, *ctx->grammar); - } - }, - /* .reset = */ [](struct llama_sampler * smpl) { - auto * ctx = (llama_sampler_grammar *) smpl->ctx; - if (!ctx->grammar) { - return; - } +static const char * llama_sampler_grammar_name(const struct llama_sampler * /*smpl*/) { + return "grammar"; +} - auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str()); +static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_grammar *) smpl->ctx; + if (ctx->grammar) { + llama_grammar_accept_impl(*ctx->grammar, token); + } +} - llama_grammar_free_impl(ctx->grammar); - ctx->grammar = grammar_new; - }, - /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_grammar *) smpl->ctx; +static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_grammar *) smpl->ctx; + if (ctx->grammar) { + llama_grammar_apply_impl(*ctx->grammar, cur_p); + } +} - auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr); +static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_grammar *) smpl->ctx; + if (!ctx->grammar) { + return; + } - // copy the state - { - auto * result_ctx = (llama_sampler_grammar *) result->ctx; + auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str()); - if (ctx->grammar) { - result_ctx->grammar_str = ctx->grammar_str; - result_ctx->grammar_root = ctx->grammar_root; + llama_grammar_free_impl(ctx->grammar); + ctx->grammar = grammar_new; +} - result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar); - } - } +static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_grammar *) smpl->ctx; - return result; - }, - /* .free = */ [](struct llama_sampler * smpl) { - const auto * ctx = (llama_sampler_grammar *) smpl->ctx; + auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr); + + // copy the state + { + auto * result_ctx = (llama_sampler_grammar *) result->ctx; if (ctx->grammar) { - llama_grammar_free_impl(ctx->grammar); + result_ctx->grammar_str = ctx->grammar_str; + result_ctx->grammar_root = ctx->grammar_root; + + result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar); } + } - delete ctx; - }, + return result; +} + +static void llama_sampler_grammar_free(struct llama_sampler * smpl) { + const auto * ctx = (llama_sampler_grammar *) smpl->ctx; + + if (ctx->grammar) { + llama_grammar_free_impl(ctx->grammar); + } + + delete ctx; +} + +static struct llama_sampler_i llama_sampler_grammar_i = { + /* .name = */ llama_sampler_grammar_name, + /* .accept = */ llama_sampler_grammar_accept_impl, + /* .apply = */ llama_sampler_grammar_apply, + /* .reset = */ llama_sampler_grammar_reset, + /* .clone = */ llama_sampler_grammar_clone, + /* .free = */ llama_sampler_grammar_free, }; struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) { @@ -1222,106 +1335,144 @@ struct llama_sampler_penalties { ring_buffer prev; }; -static struct llama_sampler_i llama_sampler_penalties_i = { - /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "penalties"; }, - /* .accept = */ [](struct llama_sampler * smpl, llama_token token) { - auto * ctx = (llama_sampler_penalties *) smpl->ctx; - if (ctx->prev.size()) { - ctx->prev.push_back(token); - } - }, - /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - auto * ctx = (llama_sampler_penalties *) smpl->ctx; - - if (ctx->ignore_eos) { - assert(ctx->special_eos_id >= 0); - - // optimistically check if the candidates are not yet sorted/shuffled/truncated - if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) { - cur_p->data[ctx->special_eos_id].logit = -INFINITY; - } else { - // else, search for the special EOS token - for (size_t i = 0; i < cur_p->size; ++i) { - if (cur_p->data[i].id == ctx->special_eos_id) { - cur_p->data[i].logit = -INFINITY; - break; - } +static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) { + return "penalties"; +} + +static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_penalties *) smpl->ctx; + if (ctx->penalty_last_n == 0) { + return; + } + + ctx->prev.push_back(token); +} + +static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_penalties *) smpl->ctx; + + if (ctx->ignore_eos) { + assert(ctx->special_eos_id >= 0); + + // optimistically check if the candidates are not yet sorted/shuffled/truncated + if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) { + cur_p->data[ctx->special_eos_id].logit = -INFINITY; + } else { + // else, search for the special EOS token + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].id == ctx->special_eos_id) { + cur_p->data[i].logit = -INFINITY; + break; } } } + } - if ((ctx->penalty_last_n == 0) || - (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) { - return; - } + if ((ctx->penalty_last_n == 0) || + (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) { + return; + } - bool nl_found = false; - size_t nl_idx = 0; - float nl_logit = -INFINITY; - if (!ctx->penalize_nl) { - assert(ctx->linefeed_id >= 0); - - // optimistically check if the candidates are not yet sorted/shuffled/truncated - if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) { - nl_found = true; - nl_idx = ctx->linefeed_id; - nl_logit = cur_p->data[ctx->linefeed_id].logit; - } else { - // else, search for the linefeed token - for (size_t i = 0; i < cur_p->size; ++i) { - if (cur_p->data[i].id == ctx->linefeed_id) { - nl_found = true; - nl_idx = i; - nl_logit = cur_p->data[i].logit; - break; - } + bool nl_found = false; + size_t nl_idx = 0; + float nl_logit = -INFINITY; + if (!ctx->penalize_nl) { + assert(ctx->linefeed_id >= 0); + + // optimistically check if the candidates are not yet sorted/shuffled/truncated + if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) { + nl_found = true; + nl_idx = ctx->linefeed_id; + nl_logit = cur_p->data[ctx->linefeed_id].logit; + } else { + // else, search for the linefeed token + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].id == ctx->linefeed_id) { + nl_found = true; + nl_idx = i; + nl_logit = cur_p->data[i].logit; + break; } } } + } - // Create a frequency map to count occurrences of each token in last_tokens - // TODO: optimize this by maintaining the token count in the sampler context - llama_token_cnt token_count; - for (int i = 0; i < std::min(ctx->penalty_last_n, ctx->prev.size()); ++i) { - token_count[ctx->prev.rat(i)]++; - } + // Create a frequency map to count occurrences of each token in last_tokens + // TODO: optimize this by maintaining the token count in the sampler context + using llama_token_cnt = std::unordered_map; + llama_token_cnt token_count; - llama_sampler_penalties_impl(cur_p, token_count, ctx->penalty_repeat, ctx->penalty_freq, ctx->penalty_present); + for (int i = 0; i < std::min(ctx->penalty_last_n, ctx->prev.size()); ++i) { + token_count[ctx->prev.rat(i)]++; + } - if (!ctx->penalize_nl && nl_found) { - // restore the logit of the newline token if it was penalized - cur_p->data[nl_idx].logit = nl_logit; + // Apply frequency and presence penalties to the cur_p + for (size_t i = 0; i < cur_p->size; ++i) { + const auto token_iter = token_count.find(cur_p->data[i].id); + if (token_iter == token_count.end()) { + continue; } - }, - /* .reset = */ [](struct llama_sampler * smpl) { - auto * ctx = (llama_sampler_penalties *) smpl->ctx; - ctx->prev.clear(); - }, - /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_penalties *) smpl->ctx; - auto * result = llama_sampler_init_penalties( - ctx->n_vocab, - ctx->special_eos_id, - ctx->linefeed_id, - ctx->penalty_last_n, - ctx->penalty_repeat, - ctx->penalty_freq, - ctx->penalty_present, - ctx->penalize_nl, - ctx->ignore_eos); - - // copy the state - { - auto * result_ctx = (llama_sampler_penalties *) result->ctx; - - result_ctx->prev = ctx->prev; + + const int count = token_iter->second; + + // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. + // This is common fix for this problem, which is to multiply by the penalty instead of dividing. + if (cur_p->data[i].logit <= 0) { + cur_p->data[i].logit *= ctx->penalty_repeat; + } else { + cur_p->data[i].logit /= ctx->penalty_repeat; } - return result; - }, - /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_penalties *) smpl->ctx; - }, + cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present; + } + + cur_p->sorted = false; + + if (!ctx->penalize_nl && nl_found) { + // restore the logit of the newline token if it was penalized + cur_p->data[nl_idx].logit = nl_logit; + } +} + +static void llama_sampler_penalties_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_penalties *) smpl->ctx; + ctx->prev.clear(); +} + +static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_penalties *) smpl->ctx; + auto * result = llama_sampler_init_penalties( + ctx->n_vocab, + ctx->special_eos_id, + ctx->linefeed_id, + ctx->penalty_last_n, + ctx->penalty_repeat, + ctx->penalty_freq, + ctx->penalty_present, + ctx->penalize_nl, + ctx->ignore_eos); + + // copy the state + { + auto * result_ctx = (llama_sampler_penalties *) result->ctx; + + result_ctx->prev = ctx->prev; + } + + return result; +} + +static void llama_sampler_penalties_free(struct llama_sampler * smpl) { + delete (llama_sampler_penalties *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_penalties_i = { + /* .name = */ llama_sampler_penalties_name, + /* .accept = */ llama_sampler_penalties_accept, + /* .apply = */ llama_sampler_penalties_apply, + /* .reset = */ llama_sampler_penalties_reset, + /* .clone = */ llama_sampler_penalties_clone, + /* .free = */ llama_sampler_penalties_free, }; struct llama_sampler * llama_sampler_init_penalties( @@ -1335,11 +1486,11 @@ struct llama_sampler * llama_sampler_init_penalties( bool penalize_nl, bool ignore_eos) { if (linefeed_id == LLAMA_TOKEN_NULL) { - penalize_nl = false; + penalize_nl = true; } if (special_eos_id == LLAMA_TOKEN_NULL) { - ignore_eos = true; + ignore_eos = false; } return new llama_sampler { @@ -1369,41 +1520,50 @@ struct llama_sampler_logit_bias { std::vector to_search; }; -static struct llama_sampler_i llama_sampler_logit_bias_i = { - /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "logit-bias"; }, - /* .accept = */ nullptr, - /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - auto * ctx = (llama_sampler_logit_bias *) smpl->ctx; +static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) { + return "logit-bias"; +} - ctx->to_search.clear(); +static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_logit_bias *) smpl->ctx; - // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id) - for (const auto & lb : ctx->logit_bias) { - if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) { - cur_p->data[lb.token].logit += lb.bias; - } else { - ctx->to_search.push_back(lb); - } + ctx->to_search.clear(); + + // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id) + for (const auto & lb : ctx->logit_bias) { + if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) { + cur_p->data[lb.token].logit += lb.bias; + } else { + ctx->to_search.push_back(lb); } + } - // search for the remaining candidates that were not found in the previous step - for (size_t i = 0; i < cur_p->size; ++i) { - for (const auto & lb : ctx->to_search) { - if (cur_p->data[i].id == lb.token) { - cur_p->data[i].logit += lb.bias; - break; - } + // search for the remaining candidates that were not found in the previous step + for (size_t i = 0; i < cur_p->size; ++i) { + for (const auto & lb : ctx->to_search) { + if (cur_p->data[i].id == lb.token) { + cur_p->data[i].logit += lb.bias; + break; } } - }, + } +} +static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx; + return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data()); +} + +static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) { + delete (llama_sampler_logit_bias *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_logit_bias_i = { + /* .name = */ llama_sampler_logit_bias_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_logit_bias_apply, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx; - return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data()); - }, - /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_logit_bias *) smpl->ctx; - }, + /* .clone = */ llama_sampler_logit_bias_clone, + /* .free = */ llama_sampler_logit_bias_free, }; struct llama_sampler * llama_sampler_init_logit_bias( diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 137c0025ce0d8..d90b147130e4b 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -23,16 +23,6 @@ struct llama_sampler_chain { mutable int32_t n_sample; }; -using llama_token_cnt = std::unordered_map; - -// TODO: tmp exposed until test-sampling is fixed -void llama_sampler_penalties_impl( - llama_token_data_array * cur_p, - const llama_token_cnt & token_count, - float penalty_repeat, - float penalty_freq, - float penalty_present); - struct llama_sampler * llama_sampler_init_grammar_impl( const struct llama_vocab & vocab, const char * grammar_str, diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index cc4882d37579a..37400c179e9bd 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -148,15 +148,17 @@ static void test_penalties( cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); } - llama_token_cnt token_count; + llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; + + auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false); + for (size_t i = 0; i < last_tokens.size(); i++) { - token_count[last_tokens[i]]++; + llama_sampler_accept(sampler, last_tokens[i]); } - llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; APPLY(llama_sampler_init_softmax(), &cur_p); DUMP(&cur_p); - llama_sampler_penalties_impl(&cur_p, token_count, repeat_penalty, alpha_frequency, alpha_presence); // TODO: avoid + APPLY(sampler, &cur_p); APPLY(llama_sampler_init_softmax(), &cur_p); DUMP(&cur_p);