From bc48b216c80c18a76e25f761b9b0d6e4b03552e0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 7 Sep 2024 12:22:27 +0300 Subject: [PATCH] style : rearrange code + add comments and TODOs ggml-ci --- common/sampling.cpp | 71 ++++++++++++++++++++++---------------------- common/sampling.h | 41 +++++++++++++++++++------ include/llama.h | 49 +++++++++++++++++++++++++++--- src/llama-sampling.h | 2 ++ 4 files changed, 115 insertions(+), 48 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index a4baf9db600841..5f27d5006044f3 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -136,17 +136,6 @@ std::string gpt_sampler_params::print() const { return std::string(result); } -std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) { - std::string result = "\tlogits "; - - for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) { - const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i); - result += std::string("-> ") + llama_sampler_name(smpl) + " "; - } - - return result; -} - struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) { llama_sampler_chain_params lparams = llama_sampler_chain_default_params(); @@ -232,17 +221,6 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl) { } } -struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) { - return new gpt_sampler { - /* .params = */ gsmpl->params, - /* .grmr = */ llama_sampler_clone(gsmpl->grmr), - /* .chain = */ llama_sampler_clone(gsmpl->chain), - /* .prev = */ gsmpl->prev, - /* .cur = */ gsmpl->cur, - /* .cur_p = */ gsmpl->cur_p, - }; -} - void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar) { if (accept_grammar) { llama_sampler_accept(gsmpl->grmr, token); @@ -259,12 +237,15 @@ void gpt_sampler_reset(struct gpt_sampler * gsmpl) { llama_sampler_reset(gsmpl->chain); } -llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) { - return &gsmpl->cur_p; -} - -llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) { - return gsmpl->prev.rat(0); +struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) { + return new gpt_sampler { + /* .params = */ gsmpl->params, + /* .grmr = */ llama_sampler_clone(gsmpl->grmr), + /* .chain = */ llama_sampler_clone(gsmpl->chain), + /* .prev = */ gsmpl->prev, + /* .cur = */ gsmpl->cur, + /* .cur_p = */ gsmpl->cur_p, + }; } void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl) { @@ -279,12 +260,11 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * } llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { - auto & grmr = gsmpl->grmr; - auto & chain = gsmpl->chain; - gsmpl->set_logits(ctx, idx); - auto & cur_p = gsmpl->cur_p; + auto & grmr = gsmpl->grmr; + auto & chain = gsmpl->chain; + auto & cur_p = gsmpl->cur_p; // initialized by set_logits if (grammar_first) { llama_sampler_apply(grmr, &cur_p); @@ -307,24 +287,45 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context llama_sampler_apply(grmr, &single_token_data_array); - // check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY const bool is_valid = single_token_data_array.data[0].logit != -INFINITY; if (is_valid) { return id; } } - // if the token is not valid, sample again, first apply the grammar samplers and then sample + // resampling: + // if the token is not valid, sample 
again, but first apply the grammar sampler and then the sampling chain gsmpl->set_logits(ctx, idx); llama_sampler_apply(grmr, &cur_p); llama_sampler_apply(chain, &cur_p); - GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration"); + GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration"); return cur_p.data[cur_p.selected].id; } +// helpers + +llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) { + return &gsmpl->cur_p; +} + +llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) { + return gsmpl->prev.rat(0); +} + +std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) { + std::string result = "\tlogits "; + + for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) { + const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i); + result += std::string("-> ") + llama_sampler_name(smpl) + " "; + } + + return result; +} + std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) { n = std::min(n, (int) gsmpl->prev.size()); diff --git a/common/sampling.h b/common/sampling.h index fa691cda234999..654e0c513904d8 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -61,24 +61,41 @@ struct gpt_sampler_params { // // - grammar support // - custom sampler logic based on the parameters +// - history of the last accepted tokens +// - performance metrics +// +// This goal is to have a common implementation of the sampling logic shared across the examples. +// For example, depending on the temperature, the sampling chain can be very simple (greedy) or more +// complex (top-k, top-p, etc). +// +// Another example is related to the grammar. In general, the grammar constraints applied on the full +// vocabulary can be very taxing. To improve performance, the grammar can be applied only to the sampled +// token in order to verify if it fits the grammar. And only if the token doesn't fit the grammar, the +// grammar constraints are applied to the full vocabulary and the token is resampled. +// +// The gpt_sampler also maintains a container with the last accepted tokens. In the future, this can +// be moved into the core llama library. +// +// For convenience, the gpt_sampler also maintains a container with the current candidate tokens. +// This can be used to access the probabilities of the rest of the non-sampled tokens. 
// // TODO: measure grammar performance // + struct gpt_sampler; +// llama_sampler API overloads + struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params); void gpt_sampler_free(struct gpt_sampler * gsmpl); -struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl); - -void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar); -void gpt_sampler_reset (struct gpt_sampler * gsmpl); - -llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl); - -llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl); +// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar +void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar); +void gpt_sampler_reset (struct gpt_sampler * gsmpl); +struct gpt_sampler * gpt_sampler_clone (struct gpt_sampler * gsmpl); +// arguments can be nullptr to skip printing void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl); // extended sampling implementation: @@ -89,12 +106,18 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * // - if not: resample by first applying the grammar constraints and then sampling again (slower path) // // if grammar_first is true, the grammar is applied before the samplers (slower) -// useful in cases where all the resulting candidates must fit the grammar +// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar // llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false); // helpers +// access the internal list of current candidate tokens +llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl); + +// get the last accepted token +llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl); + // print the sampler chain into a string std::string gpt_sampler_print(const struct gpt_sampler * gsmpl); diff --git a/include/llama.h b/include/llama.h index c49371ae8cf77d..dd914a634e1bf6 100644 --- a/include/llama.h +++ b/include/llama.h @@ -204,6 +204,7 @@ extern "C" { LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs }; + // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979) typedef struct llama_token_data { llama_token id; // token id float logit; // log-odds of the token @@ -214,7 +215,7 @@ extern "C" { // TODO: consider SoA llama_token_data * data; size_t size; - int64_t selected; + int64_t selected; // this is the index in the data array (i.e. not the token id) bool sorted; } llama_token_data_array; @@ -977,9 +978,38 @@ extern "C" { // // Sampling API // - // In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU). + // Sample usage: + // + // // prepare the sampling chain at the start + // auto sparams = llama_sampler_chain_default_params(); + // + // llama_sampler * smpl = llama_sampler_chain_init(sparams); + // + // llama_sampler_chain_add(smpl, llama_sampler_init_top_k(50)); + // llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1)); + // llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.8)); + // llama_sampler_chain_add(smpl, llama_sampler_init_dist (seed)); + // + // ... + // + // // decoding loop: + // while (...) { + // ... 
+ // + // llama_decode(ctx, batch); + // + // // sample from the logits of the last token in the batch + // const llama_token id = llama_sampler_sample(smpl, ctx, -1); + // + // ... + // } + // + // llama_sampler_free(smpl); + // + // + // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU). + // TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab // - // TODO: in the future, the entire API that uses llama_model should start using llama_vocab typedef void * llama_sampler_context_t; @@ -1001,6 +1031,7 @@ extern "C" { llama_sampler_context_t ctx; }; + // mirror of llama_sampler_i: LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl); LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); @@ -1009,7 +1040,8 @@ extern "C" { // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add) LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); - // llama_sampler_chain is a type of llama_sampler that can contain multiple llama_samplers + // llama_sampler_chain + // a type of llama_sampler that can chain multiple samplers one after another LLAMA_API struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params); @@ -1087,6 +1119,15 @@ extern "C" { int32_t n_logit_bias, const llama_logit_bias * logit_bias); + // Shorthand for: + // + // const auto * logits = llama_get_logits_ith(ctx, idx); + // llama_token_data_array cur_p = { ... init from logits ... }; + // llama_sampler_apply(smpl, &cur_p); + // return cur_p.data[cur_p.selected].id; + // + // At this point, this is mostly a convenience function. + // LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx); // TODO: extend in the future diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 05bb294a10d2fa..ddc84a39006666 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -1,5 +1,7 @@ #pragma once +// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ? + #include "llama-grammar.h" #include
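
---

Not part of the patch: a minimal, self-contained sketch of how the sample-usage comment added to include/llama.h above could look as real code. The helper name sample_next_token is hypothetical, the caller is assumed to have already created and filled ctx and batch, and error handling / end-of-generation checks are omitted. In a real decoding loop the chain would normally be built once up front and reused across iterations, as the comment shows.

#include "llama.h"

// build a sampling chain (top-k -> top-p -> temperature -> dist), then sample from
// the logits of the last token of the given batch, mirroring the comment in llama.h
static llama_token sample_next_token(llama_context * ctx, llama_batch batch, uint32_t seed) {
    auto sparams = llama_sampler_chain_default_params();

    llama_sampler * smpl = llama_sampler_chain_init(sparams);

    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(50));
    llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9f, 1));
    llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.8f));
    llama_sampler_chain_add(smpl, llama_sampler_init_dist (seed));

    llama_decode(ctx, batch); // error handling omitted in this sketch

    // sample from the logits of the last token in the batch
    const llama_token id = llama_sampler_sample(smpl, ctx, -1);

    llama_sampler_free(smpl);

    return id;
}

The common-layer wrapper declared in common/sampling.h would be driven analogously; assuming gsmpl was created earlier with gpt_sampler_init(model, params):

    // sample with lazy grammar checking (grammar_first = false), then accept the token
    // into both the sampling chain and the grammar
    const llama_token id = gpt_sampler_sample(gsmpl, ctx, /*idx =*/ -1, /*grammar_first =*/ false);
    gpt_sampler_accept(gsmpl, id, /*accept_grammar =*/ true);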