From 8307e96fbb2b2b1039347aff6ac91c9e8b7d9aff Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Thu, 5 Sep 2024 18:07:47 +0300
Subject: [PATCH] sampling : improve mirostat implementation

ggml-ci
---
 common/sampling.cpp     |  22 +++----
 common/sampling.h       |   2 +-
 include/llama.h         |   2 +
 src/llama-sampling.cpp  | 133 ++++++++++++++++++++--------------------
 src/llama-sampling.h    |   2 +
 src/llama.cpp           |   8 +--
 tests/test-sampling.cpp |  14 ++---
 7 files changed, 95 insertions(+), 88 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index de7c9b1b973958..cf3ee98d4c7446 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -121,7 +121,7 @@ struct gpt_sampler {
             cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
         }
 
-        cur_p = { cur.data(), cur.size(), LLAMA_TOKEN_NULL, false };
+        cur_p = { cur.data(), cur.size(), -1, false };
     }
 };
 
@@ -202,17 +202,17 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
                         GGML_ASSERT(false && "unknown sampler type");
                 }
             }
+            llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
+            llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
         } else if (params.mirostat == 1) {
             llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-            llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(model, params.mirostat_tau, params.mirostat_eta));
+            llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(model, params.seed, params.mirostat_tau, params.mirostat_eta));
         } else if (params.mirostat == 2) {
            llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-            llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta));
+            llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
         } else {
             GGML_ASSERT(false && "unknown mirostat version");
         }
-        llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
-        llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
     } else {
         llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
         llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
@@ -246,8 +246,8 @@ struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
     };
 }
 
-void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar) {
-    if (apply_grammar) {
+void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar) {
+    if (accept_grammar) {
         llama_sampler_accept(gsmpl->grmr, token);
     }
 
@@ -293,9 +293,9 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
 
     llama_sampler_apply(chain, &cur_p);
 
-    const llama_token id = cur_p.data[cur_p.selected].id;
+    GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
 
-    GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - check your sampling configuration");
+    const llama_token id = cur_p.data[cur_p.selected].id;
 
     if (grammar_first) {
         return id;
@@ -304,7 +304,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
     // check if it the sampled token fits the grammar
     {
         llama_token_data       single_token_data       = { id, 1.0f, 0.0f };
-        llama_token_data_array single_token_data_array = { &single_token_data, 1, LLAMA_TOKEN_NULL, false };
+        llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
 
         llama_sampler_apply(grmr, &single_token_data_array);
 
@@ -324,7 +324,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
 
     llama_sampler_apply(chain, &cur_p);
 
-    GGML_ASSERT(cur_p.data[cur_p.selected].id != LLAMA_TOKEN_NULL && "null token in the sampling history - check your sampling configuration");
+    GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
 
     return cur_p.data[cur_p.selected].id;
 }
diff --git a/common/sampling.h b/common/sampling.h
index 5083f456f1f96b..d88038204c89f4 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -70,7 +70,7 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl);
 
 struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl);
 
-void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar);
+void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar);
 void gpt_sampler_reset (struct gpt_sampler * gsmpl);
 
 llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);
diff --git a/include/llama.h b/include/llama.h
index 766ce637fe2a73..46c4c3b95b14ab 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1066,6 +1066,7 @@ extern "C" {
     /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
     LLAMA_API struct llama_sampler * llama_sampler_init_mirostat(
              const struct llama_model * model,
+                            uint32_t   seed,
                                float   tau,
                                float   eta);
 
@@ -1075,6 +1076,7 @@ extern "C" {
     /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
     /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
     LLAMA_API struct llama_sampler * llama_sampler_init_mirostat_v2(
+                            uint32_t   seed,
                                float   tau,
                                float   eta);
 
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 8ff52dd2decd1c..0084fe0b7dc4c3 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -11,6 +11,17 @@
 #include <random>
 #include <unordered_map>
 
+static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng, std::vector<float> & probs) {
+    probs.resize(cur_p->size);
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        probs[i] = cur_p->data[i].p;
+    }
+
+    std::discrete_distribution<int> dist(probs.begin(), probs.end());
+
+    return dist(rng);
+}
+
 static void llama_log_softmax(float * array, size_t size) {
     float max_l = *std::max_element(array, array + size);
     float sum = 0.f;
@@ -456,6 +467,8 @@ struct llama_sampler_context_dist {
     const uint32_t seed;
 
     std::mt19937 rng;
+
+    std::vector<float> probs; // work array
 };
 
 static struct llama_sampler_i llama_sampler_dist_i = {
@@ -463,15 +476,7 @@ static struct llama_sampler_i llama_sampler_dist_i = {
     /* .accept = */ nullptr,
     /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
         auto * ctx = (llama_sampler_context_dist *) smpl->ctx;
-        std::vector<float> probs;
-        probs.reserve(cur_p->size);
-        for (size_t i = 0; i < cur_p->size; ++i) {
-            probs.push_back(cur_p->data[i].p);
-        }
-
-        std::discrete_distribution<int> dist(probs.begin(), probs.end());
-
-        cur_p->selected = dist(ctx->rng);
+        cur_p->selected = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
     },
     /* .reset = */ nullptr,
     /* .clone = */ [](const struct llama_sampler * smpl) {
@@ -489,6 +494,7 @@ struct llama_sampler * llama_sampler_init_dist_impl(uint32_t seed) {
         /* .ctx   = */ new llama_sampler_context_dist {
             /* .seed  = */ seed,
             /* .rng   = */ std::mt19937(seed),
+            /* .probs = */ {},
         },
     };
 }
@@ -761,6 +767,8 @@ struct llama_sampler * llama_sampler_init_temp_ext_impl(float temp, float delta,
 struct llama_sampler_context_mirostat {
     const struct llama_vocab * vocab;
 
+    const uint32_t seed;
+
     const float tau;
     const float eta;
 
@@ -768,28 +776,14 @@ struct llama_sampler_context_mirostat {
 
     float mu;
 
-    std::vector<llama_token_data> cur;
+    std::mt19937 rng;
+
+    std::vector<float> probs;
 };
 
 static struct llama_sampler_i llama_sampler_mirostat_i = {
     /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat"; },
-    /* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
-        auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx;
-
-        int32_t idx = -1;
-        for (size_t i = 0; i < ctx->cur.size(); ++i) {
-            if (ctx->cur[i].id == token) {
-                idx = i;
-                break;
-            }
-        }
-
-        float observed_surprise = -log2f(ctx->cur[idx].p);
-        float e = observed_surprise - ctx->tau;
-
-        // Update mu using the learning rate and error
-        ctx->mu = ctx->mu - ctx->eta * e;
-    },
+    /* .accept = */ nullptr,
     /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
         auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx;
 
@@ -812,36 +806,44 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
         float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->vocab->n_vocab, -epsilon_hat)), 1 / s_hat);
 
         llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
+        llama_sampler_softmax_impl(cur_p);
 
-        // remember the order to be able to compute the distance later when accepting the token
-        ctx->cur.resize(cur_p->size);
-        for (size_t i = 0; i < cur_p->size; ++i) {
-            ctx->cur[i] = cur_p->data[i];
-        }
+        const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
+
+        cur_p->selected = idx;
+
+        float observed_surprise = -log2f(cur_p->data[idx].p);
+        float e = observed_surprise - ctx->tau;
+
+        // Update mu using the learning rate and error
+        ctx->mu = ctx->mu - ctx->eta * e;
     },
     /* .reset = */ [](struct llama_sampler * smpl) {
         auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx;
         ctx->mu = 2.0f*ctx->tau;
+        ctx->rng = std::mt19937(ctx->seed);
     },
     /* .clone = */ [](const struct llama_sampler * smpl) {
         const auto * ctx = (const llama_sampler_context_mirostat *) smpl->ctx;
-        return llama_sampler_init_mirostat_impl(*ctx->vocab, ctx->tau, ctx->eta, ctx->m);
+        return llama_sampler_init_mirostat_impl(*ctx->vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
     },
     /* .free = */ [](struct llama_sampler * smpl) {
         delete (llama_sampler_context_mirostat *) smpl->ctx;
     },
 };
 
-struct llama_sampler * llama_sampler_init_mirostat_impl(const struct llama_vocab & vocab, float tau, float eta, int32_t m) {
+struct llama_sampler * llama_sampler_init_mirostat_impl(const struct llama_vocab & vocab, uint32_t seed, float tau, float eta, int32_t m) {
     return new llama_sampler {
         /* .iface = */ &llama_sampler_mirostat_i,
         /* .ctx   = */ new llama_sampler_context_mirostat {
             /* .vocab = */ &vocab,
+            /* .seed  = */ seed,
             /* .tau   = */ tau,
             /* .eta   = */ eta,
             /* .m     = */ m,
             /* .mu    = */ 2.0f*tau,
-            /* .cur   = */ {},
+            /* .rng   = */ std::mt19937(seed),
+            /* .probs = */ {},
         },
     };
 }
@@ -849,33 +851,21 @@ struct llama_sampler * llama_sampler_init_mirostat_impl(const struct llama_vocab
 // mirostat v2
 
 struct llama_sampler_context_mirostat_v2 {
+    const uint32_t seed;
+
     const float tau;
     const float eta;
 
     float mu;
 
-    std::vector<llama_token_data> cur;
+    std::mt19937 rng;
+
+    std::vector<float> probs;
 };
 
 static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
     /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat-v2"; },
-    /* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
-        auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx;
-
-        int32_t idx = -1;
-        for (size_t i = 0; i < ctx->cur.size(); ++i) {
-            if (ctx->cur[i].id == token) {
-                idx = i;
-                break;
-            }
-        }
-
-        float observed_surprise = -log2f(ctx->cur[idx].p);
-        float e = observed_surprise - ctx->tau;
-
-        // Update mu using the learning rate and error
-        ctx->mu = ctx->mu - ctx->eta * e;
-    },
+    /* .accept = */ nullptr,
    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
         auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx;
 
@@ -893,33 +883,40 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
         // Normalize the probabilities of the remaining words
         llama_sampler_softmax_impl(cur_p);
 
-        // remember the order to be able to compute the distance later when accepting the token
-        ctx->cur.resize(cur_p->size);
-        for (size_t i = 0; i < cur_p->size; ++i) {
-            ctx->cur[i] = cur_p->data[i];
-        }
+        const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
+
+        cur_p->selected = idx;
+
+        float observed_surprise = -log2f(cur_p->data[idx].p);
+        float e = observed_surprise - ctx->tau;
+
+        // Update mu using the learning rate and error
+        ctx->mu = ctx->mu - ctx->eta * e;
     },
     /* .reset = */ [](struct llama_sampler * smpl) {
         auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx;
         ctx->mu = 2.0f*ctx->tau;
+        ctx->rng = std::mt19937(ctx->seed);
     },
     /* .clone = */ [](const struct llama_sampler * smpl) {
         const auto * ctx = (const llama_sampler_context_mirostat_v2 *) smpl->ctx;
-        return llama_sampler_init_mirostat_v2_impl(ctx->tau, ctx->eta);
+        return llama_sampler_init_mirostat_v2_impl(ctx->seed, ctx->tau, ctx->eta);
     },
     /* .free = */ [](struct llama_sampler * smpl) {
         delete (llama_sampler_context_mirostat_v2 *) smpl->ctx;
     },
 };
 
-struct llama_sampler * llama_sampler_init_mirostat_v2_impl(float tau, float eta) {
+struct llama_sampler * llama_sampler_init_mirostat_v2_impl(uint32_t seed, float tau, float eta) {
     return new llama_sampler {
         /* .iface = */ &llama_sampler_mirostat_v2_i,
         /* .ctx   = */ new llama_sampler_context_mirostat_v2 {
-            /* .tau   = */ tau,
-            /* .eta   = */ eta,
-            /* .mu    = */ 2.0f*tau,
-            /* .cur   = */ {},
+            /* .seed  = */ seed,
+            /* .tau   = */ tau,
+            /* .eta   = */ eta,
+            /* .mu    = */ 2.0f*tau,
+            /* .rng   = */ std::mt19937(seed),
+            /* .probs = */ {},
         },
     };
 }
@@ -1154,9 +1151,15 @@ struct llama_sampler * llama_sampler_init_logit_bias_impl(
 
 static struct llama_sampler_i llama_sampler_chain_i = {
     /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; },
-    /* .accept = */ [](struct llama_sampler * smpl, llama_token /*token*/) {
+    /* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
         auto * chain = (llama_sampler_chain *) smpl->ctx;
 
+        time_meas tm(chain->t_sample_us, chain->params.no_timing);
+
+        for (auto * smpl : chain->samplers) {
+            llama_sampler_accept_impl(*smpl, token);
+        }
+
         chain->n_sample++;
     },
     /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index 3f14ec621f5c14..0088060c8d971f 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -58,11 +58,13 @@ struct llama_sampler * llama_sampler_init_temp_ext_impl (float t, float delta
 
 struct llama_sampler * llama_sampler_init_mirostat_impl(
         const struct llama_vocab & vocab,
+                        uint32_t   seed,
                            float   tau,
                            float   eta,
                          int32_t   m);
 
 struct llama_sampler * llama_sampler_init_mirostat_v2_impl(
+                        uint32_t   seed,
                            float   tau,
                            float   eta);
 
diff --git a/src/llama.cpp b/src/llama.cpp
index 17d3d24301a871..3de9a8e68585ce 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -20646,12 +20646,12 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
     return llama_sampler_init_temp_ext_impl(temp, delta, exponent);
 }
 
-struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, float tau, float eta) {
-    return llama_sampler_init_mirostat_impl(model->vocab, tau, eta, 100);
+struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, uint32_t seed, float tau, float eta) {
+    return llama_sampler_init_mirostat_impl(model->vocab, seed, tau, eta, 100);
 }
 
-struct llama_sampler * llama_sampler_init_mirostat_v2(float tau, float eta) {
-    return llama_sampler_init_mirostat_v2_impl(tau, eta);
+struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
+    return llama_sampler_init_mirostat_v2_impl(seed, tau, eta);
 }
 
 struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp
index adc1ff4e6da7de..cc4882d37579a6 100644
--- a/tests/test-sampling.cpp
+++ b/tests/test-sampling.cpp
@@ -35,7 +35,7 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
         cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }
 
-    llama_token_data_array cur_p = { cur.data(), cur.size(), LLAMA_TOKEN_NULL, false };
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
     DUMP(&cur_p);
     APPLY(llama_sampler_init_tail_free(z, 1), &cur_p);
     DUMP(&cur_p);
@@ -100,7 +100,7 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
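
---

Note for reviewers (not part of the patch): a minimal usage sketch of the API after this change. It assumes the llama_sampler chain entry points from the same sampling refactor (llama_sampler_chain_init, llama_sampler_chain_default_params, llama_sampler_chain_add, llama_sampler_apply, llama_sampler_accept); the helper names and the seed/tau/eta values below are illustrative only.

```cpp
#include "llama.h"

#include <vector>

// Build a mirostat-v2 chain, mirroring the params.mirostat == 2 branch of
// gpt_sampler_init() above. The seed is now passed to the sampler itself,
// which owns its RNG, so the chain no longer ends with
// llama_sampler_init_softmax() + llama_sampler_init_dist().
static struct llama_sampler * make_mirostat_v2_chain(uint32_t seed) {
    struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(chain, llama_sampler_init_mirostat_v2(seed, /*tau=*/5.0f, /*eta=*/0.1f));
    return chain;
}

// Sample one token from raw logits, following the same steps as
// gpt_sampler_sample() in common/sampling.cpp.
static llama_token sample_token(struct llama_sampler * chain, const float * logits, int n_vocab) {
    std::vector<llama_token_data> cur;
    cur.reserve(n_vocab);
    for (llama_token id = 0; id < n_vocab; ++id) {
        cur.emplace_back(llama_token_data{ id, logits[id], 0.0f });
    }

    // selected stays -1 until a sampler in the chain picks a token
    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };

    // mirostat now picks the token during apply (via llama_sample_dist)
    // and updates mu in the same step
    llama_sampler_apply(chain, &cur_p);

    const llama_token token = cur_p.data[cur_p.selected].id;

    // accept now forwards to every sampler in the chain; mirostat itself
    // no longer needs it (its .accept is nullptr after this patch)
    llama_sampler_accept(chain, token);

    return token;
}
```

One behavioral consequence visible in the diff: clone() recreates a mirostat sampler from its original seed/tau/eta rather than copying the current rng and mu state, and reset() reseeds the RNG, so sampling is reproducible for a given seed.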