From 25a46f2b9f29b96b71943d5d12b60fc9475488ed Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sat, 17 Aug 2024 14:18:17 +0300
Subject: [PATCH] sampling : remove n_valid from the state

ggml-ci
---
 common/common.cpp          | 12 +++++------
 common/sampling.cpp        | 12 ++++-------
 common/sampling.h          | 44 ++++++++++++++++++--------------------
 examples/server/server.cpp | 34 ++++++-----------------------
 include/llama.h            |  1 +
 5 files changed, 39 insertions(+), 64 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 6237770963ab3..7bf843eb4e854 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -249,16 +249,14 @@ void gpt_params_handle_model_default(gpt_params & params) {
 }
 
 bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
-    bool invalid_param = false;
-    std::string arg;
-    const std::string arg_prefix = "--";
-    auto & sparams = params.sparams;
-
     for (int i = 1; i < argc; i++) {
-        arg = argv[i];
+        const std::string arg_prefix = "--";
+
+        std::string arg = argv[i];
         if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
             std::replace(arg.begin(), arg.end(), '_', '-');
         }
+        bool invalid_param = false;
         if (!gpt_params_find_arg(argc, argv, arg, params, i, invalid_param)) {
             throw std::invalid_argument("error: unknown argument: " + arg);
         }
@@ -275,6 +273,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
 
     gpt_params_handle_hf_token(params);
 
+    auto & sparams = params.sparams;
+
     if (params.escape) {
         string_process_escapes(params.prompt);
         string_process_escapes(params.input_prefix);
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 24c317c7eacab..69ce4e9040341 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -40,8 +40,6 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_pa
         llama_sampling_set_logit_bias(result->smpl, params.logit_bias.size(), params.logit_bias.data());
     }
 
-    result->n_valid = 0;
-
     return result;
 }
 
@@ -55,7 +53,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
     llama_sampling_reset(ctx->smpl);
 
     ctx->cur.clear();
-    ctx->n_valid = 0;
+    ctx->org.clear();
 }
 
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
@@ -294,11 +292,11 @@ static llama_token llama_sampling_sample(
 
     llama_token id = 0;
 
-    if (temp < 0.0) {
+    if (temp < 0.0f || (temp == 0.0f && params.n_probs > 0)) {
         // greedy sampling, with probs
         llama_sampling_softmax(smpl, cur_p);
         id = cur_p->data[0].id;
-    } else if (temp == 0.0) {
+    } else if (temp == 0.0f) {
         // greedy sampling, no probs
         id = llama_sampling_sample_greedy(smpl, cur_p);
     } else {
@@ -325,8 +323,6 @@ static llama_token llama_sampling_sample(
         }
     }
 
-    ctx_sampling->n_valid = temp == 0.0f ? 0 : cur_p->size;
-
     return id;
 }
 
@@ -341,7 +337,7 @@ llama_token llama_sampling_sample(
         return llama_sampling_sample(ctx_sampling, &cur_p);
     }
 
-    // TODO: this lofic is confusing, try to figure out a better way to handle this
+    // TODO: this logic is confusing, try to figure out a better way to handle this
 
     // store the original candidates
     ctx_sampling->org = ctx_sampling->cur;
diff --git a/common/sampling.h b/common/sampling.h
index 5434540ecabb7..3a20754e14b7d 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -17,27 +17,27 @@ enum class llama_sampler_type : char {
 
 // sampling parameters
 typedef struct gpt_sampling_params {
-    uint32_t    seed                  = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
-    int32_t     n_prev                = 64;       // number of previous tokens to remember
-    int32_t     n_probs               = 0;        // if greater than 0, output the probabilities of top n_probs tokens.
-    int32_t     min_keep              = 0;        // 0 = disabled, otherwise samplers should return at least min_keep tokens
-    int32_t     top_k                 = 40;       // <= 0 to use vocab size
-    float       top_p                 = 0.95f;    // 1.0 = disabled
-    float       min_p                 = 0.05f;    // 0.0 = disabled
-    float       tfs_z                 = 1.00f;    // 1.0 = disabled
-    float       typical_p             = 1.00f;    // 1.0 = disabled
-    float       temp                  = 0.80f;    // <= 0.0 to sample greedily, 0.0 to not output probabilities
-    float       dynatemp_range        = 0.00f;    // 0.0 = disabled
-    float       dynatemp_exponent     = 1.00f;    // controls how entropy maps to temperature in dynamic temperature sampler
-    int32_t     penalty_last_n        = 64;       // last n tokens to penalize (0 = disable penalty, -1 = context size)
-    float       penalty_repeat        = 1.00f;    // 1.0 = disabled
-    float       penalty_freq          = 0.00f;    // 0.0 = disabled
-    float       penalty_present       = 0.00f;    // 0.0 = disabled
-    int32_t     mirostat              = 0;        // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-    float       mirostat_tau          = 5.00f;    // target entropy
-    float       mirostat_eta          = 0.10f;    // learning rate
-    bool        penalize_nl           = false;    // consider newlines as a repeatable token
-    bool        ignore_eos            = false;
+    uint32_t seed              = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
+    int32_t  n_prev            = 64;    // number of previous tokens to remember
+    int32_t  n_probs           = 0;     // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t  min_keep          = 0;     // 0 = disabled, otherwise samplers should return at least min_keep tokens
+    int32_t  top_k             = 40;    // <= 0 to use vocab size
+    float    top_p             = 0.95f; // 1.0 = disabled
+    float    min_p             = 0.05f; // 0.0 = disabled
+    float    tfs_z             = 1.00f; // 1.0 = disabled
+    float    typical_p         = 1.00f; // 1.0 = disabled
+    float    temp              = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
+    float    dynatemp_range    = 0.00f; // 0.0 = disabled
+    float    dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
+    int32_t  penalty_last_n    = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float    penalty_repeat    = 1.00f; // 1.0 = disabled
+    float    penalty_freq      = 0.00f; // 0.0 = disabled
+    float    penalty_present   = 0.00f; // 0.0 = disabled
+    int32_t  mirostat          = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float    mirostat_tau      = 5.00f; // target entropy
+    float    mirostat_eta      = 0.10f; // learning rate
+    bool     penalize_nl       = false; // consider newlines as a repeatable token
+    bool     ignore_eos        = false;
 
     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,
@@ -68,8 +68,6 @@ struct llama_sampling_context {
 
     std::vector<llama_token_data> cur;
     std::vector<llama_token_data> org;
-
-    size_t n_valid; // Number of correct top tokens with correct probabilities.
 };
 
 // Create a new sampling context instance.
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index e4eedd23fe637..83d5bf47454f5 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2350,35 +2350,15 @@ struct server_context {
                 metrics.on_prompt_eval(slot);
             }
 
-            llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
             result.tok = id;
 
-            const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
-            if (n_probs > 0) {
-                const size_t n_valid = slot.ctx_sampling->n_valid;
-
-                // Make sure at least n_probs top tokens are at the front of the vector:
-                // TODO: decide to how to handle this after the refactoring
-                //if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
-                //    llama_sampling_top_k(slot.ctx_sampling->smpl, &cur_p, n_probs, 0);
-                //}
-
-                if (slot.sparams.temp == 0.0f) {
-                    // With greedy sampling the probabilities have possibly not been calculated.
-                    for (size_t i = 0; i < n_probs; ++i) {
-                        result.probs.push_back({
-                            cur_p.data[i].id,
-                            i == 0 ? 1.0f : 0.0f
-                        });
-                    }
-                } else {
-                    for (size_t i = 0; i < n_probs; ++i) {
-                        result.probs.push_back({
-                            cur_p.data[i].id,
-                            i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
-                        });
-                    }
-                }
-            }
+            const llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
+
+            for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
+                result.probs.push_back({
+                    cur_p.data[i].id,
+                    i >= cur_p.size ? 0.0f : cur_p.data[i].p,
+                });
+            }
 
             if (!process_token(result, slot)) {
diff --git a/include/llama.h b/include/llama.h
index a019aa41caa17..dfc9178037cbd 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -210,6 +210,7 @@ extern "C" {
     } llama_token_data;
 
     typedef struct llama_token_data_array {
+        // TODO: consider SoA
         llama_token_data * data;
         size_t size;
         bool sorted;