From 77998ae7805aa34f8de897d2d923318bf55402b0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 29 Aug 2024 15:56:19 +0300 Subject: [PATCH] sampling : option to use internal set of candidates ggml-ci --- common/sampling.cpp | 22 ++++++++-- common/sampling.h | 2 +- examples/batched.swift/Sources/main.swift | 23 +++------- examples/infill/infill.cpp | 2 +- examples/llava/llava-cli.cpp | 2 +- examples/llava/minicpmv-cli.cpp | 2 +- examples/main/main.cpp | 2 +- src/llama.cpp | 52 +++++++++++++++++++++++ 8 files changed, 82 insertions(+), 25 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 4ceb7918aa551..50a52735c5cf9 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -199,9 +199,25 @@ llama_token llama_sampling_sample( int idx) { llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx)); - auto * cur_p = llama_sampling_get_candidates(smpl); + // first, sample the token without any grammar constraints + auto id = llama_sampling_sample(smpl, nullptr); - llama_sampling_grammar(smpl, cur_p); + // create an array with a single token data element for the sampled id + llama_token_data single_token_data = {id, 1.0f, 0.0f}; + llama_token_data_array single_token_data_array = { &single_token_data, 1, false }; - return llama_sampling_sample(smpl, cur_p); + llama_sampling_grammar(smpl, &single_token_data_array); + + // check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY + const bool is_valid = single_token_data_array.data[0].logit != -INFINITY; + if (is_valid) { + return id; + } + + // if the token is not valid, sample again, after applying the grammar constraints + llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx)); + + llama_sampling_grammar(smpl, nullptr); + + return llama_sampling_sample(smpl, nullptr); } diff --git a/common/sampling.h b/common/sampling.h index 8a2e595ba2f97..6a9e7948be4c8 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -67,4 +67,4 @@ std::vector llama_sampling_types_from_chars(const std:: llama_token llama_sampling_sample( struct llama_sampling * smpl, struct llama_context * ctx, - int idx = -1); + int idx); diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 4595aa68b2fbf..81763217a91a8 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -136,28 +136,17 @@ while n_cur <= n_len { continue } - var n_vocab = llama_n_vocab(model) var logits = llama_get_logits_ith(context, i_batch[i]) - var candidates: [llama_token_data] = .init(repeating: llama_token_data(), count: Int(n_vocab)) + llama_sampling_set_logits(smpl, logits) - for token_id in 0 ..< n_vocab { - candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0)) - } - - var candidates_p: llama_token_data_array = .init( - data: &candidates, - size: candidates.count, - sorted: false - ) - - llama_sampling_top_k(smpl, &candidates_p) - llama_sampling_top_p(smpl, &candidates_p) - llama_sampling_temp (smpl, &candidates_p) + llama_sampling_top_k(smpl, nil) + llama_sampling_top_p(smpl, nil) + llama_sampling_temp (smpl, nil) - let new_token_id = llama_sampling_sample_dist(smpl, &candidates_p) + let new_token_id = llama_sampling_sample_dist(smpl, nil) - // const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p); + // const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nil); // is it an end of stream? -> mark the stream as finished if llama_token_is_eog(model, new_token_id) || n_cur == n_len { diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 2f661f00ebd69..371232421b71c 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -417,7 +417,7 @@ int main(int argc, char ** argv) { embd.clear(); if ((int) embd_inp.size() <= n_consumed && !is_interacting) { - const llama_token id = llama_sampling_sample(smpl, ctx); + const llama_token id = llama_sampling_sample(smpl, ctx, -1); llama_sampling_accept(smpl, id, true); diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 774b7b779d4cd..29f6c462aee8f 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -43,7 +43,7 @@ static bool eval_string(struct llama_context * ctx_llama, const char* str, int n static const char * sample(struct llama_sampling * smpl, struct llama_context * ctx_llama, int * n_past) { - const llama_token id = llama_sampling_sample(smpl, ctx_llama); + const llama_token id = llama_sampling_sample(smpl, ctx_llama, -1); llama_sampling_accept(smpl, id, true); static std::string ret; if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index c056cadc63ae2..0e164ab5c9d73 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -166,7 +166,7 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e static const char * sample(struct llama_sampling * smpl, struct llama_context * ctx_llama, int * n_past) { - const llama_token id = llama_sampling_sample(smpl, ctx_llama); + const llama_token id = llama_sampling_sample(smpl, ctx_llama, -1); llama_sampling_accept(smpl, id, true); static std::string ret; if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 834d2279834b7..ad76e3b7e0e69 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -650,7 +650,7 @@ int main(int argc, char ** argv) { LOG("saved session to %s\n", path_session.c_str()); } - const llama_token id = llama_sampling_sample(smpl, ctx); + const llama_token id = llama_sampling_sample(smpl, ctx, -1); llama_sampling_accept(smpl, id, /* apply_grammar= */ true); diff --git a/src/llama.cpp b/src/llama.cpp index 564aa6db41958..207cf412210f9 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20139,42 +20139,70 @@ llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling * s void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_sample_us); + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + llama_sampling_softmax_impl(candidates); } void llama_sampling_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_sample_us); + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + llama_sampling_top_k_impl(candidates, smpl->params.top_k, smpl->params.min_keep); } void llama_sampling_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_sample_us); + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + llama_sampling_top_p_impl(candidates, smpl->params.top_p, smpl->params.min_keep); } void llama_sampling_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_sample_us); + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + llama_sampling_min_p_impl(candidates, smpl->params.min_p, smpl->params.min_keep); } void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_sample_us); + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + llama_sampling_tail_free_impl(candidates, smpl->params.tfs_z, smpl->params.min_keep); } void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_sample_us); + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + llama_sampling_typical_impl(candidates, smpl->params.typ_p, smpl->params.min_keep); } void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_sample_us); + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + if (smpl->params.dynatemp_range > 0) { const float dynatemp_min = std::max(0.0f, smpl->params.temp - smpl->params.dynatemp_range); const float dynatemp_max = std::max(0.0f, smpl->params.temp + smpl->params.dynatemp_range); @@ -20188,6 +20216,10 @@ void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * void llama_sampling_grammar(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_grammar_us); + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + if (smpl->grammar) { llama_sampling_grammar_impl(candidates, *smpl->grammar); @@ -20200,6 +20232,10 @@ void llama_sampling_penalties( llama_token_data_array * candidates) { time_meas tm(smpl->t_sample_us); + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + const size_t penalty_last_n = std::min(smpl->params.penalty_last_n, smpl->prev.size()); const float penalty_repeat = smpl->params.penalty_repeat; @@ -20224,6 +20260,10 @@ void llama_sampling_penalties( llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_sample_us); + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + const auto type = smpl->params.mirostat; llama_token res; @@ -20254,6 +20294,10 @@ llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_t llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_sample_us); + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + auto res = llama_sampling_sample_greedy_impl(candidates); smpl->n_sample++; @@ -20264,6 +20308,10 @@ llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_tok llama_token llama_sampling_sample_dist(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_sample_us); + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + auto res = llama_sampling_sample_dist_impl(candidates, smpl->rng); smpl->n_sample++; @@ -20274,6 +20322,10 @@ llama_token llama_sampling_sample_dist(struct llama_sampling * smpl, llama_token llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_sample_us); + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + const auto & params = smpl->params; const float temp = params.temp;