From 1238001be882615d3a1949f460c51278ea8ce52c Mon Sep 17 00:00:00 2001
From: zhenweijin
Date: Tue, 6 Aug 2024 18:01:51 +0800
Subject: [PATCH] Use ring buffer to store prev in sampling

---
 common/sampling.cpp        |  10 ++--
 common/sampling.h          | 104 ++++++++++++++++++++++++++++++++++++-
 examples/infill/infill.cpp |   2 +-
 examples/main/main.cpp     |   2 +-
 4 files changed, 110 insertions(+), 8 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index 079e405168dff2..ae214020d7f4b2 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -37,7 +37,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
         result->grammar = grammar;
     }
 
-    result->prev.resize(params.n_prev);
+    result->prev = ring_buffer<llama_token>(params.n_prev);
 
     result->n_valid = 0;
 
@@ -72,7 +72,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
         ctx->grammar = grammar;
     }
 
-    std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
+    ctx->prev.clear();
     ctx->cur.clear();
     ctx->n_valid = 0;
 }
@@ -400,7 +400,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
 
     // apply penalties
-    const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
+    const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev.to_vector();
     const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
     if (penalty_tokens_used_size) {
         const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
@@ -451,7 +451,9 @@ void llama_sampling_accept(
         struct llama_context * ctx_main,
         llama_token id,
         bool apply_grammar) {
-    ctx_sampling->prev.erase(ctx_sampling->prev.begin());
+    if (!ctx_sampling->prev.empty()) {
+        ctx_sampling->prev.pop_front();
+    }
     ctx_sampling->prev.push_back(id);
 
     if (ctx_sampling->grammar != NULL && apply_grammar) {
diff --git a/common/sampling.h b/common/sampling.h
index eeaa53b8bcd008..7bbb4bf43d4474 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -8,6 +8,7 @@
 #include 
 #include 
 #include 
+#include 
 
 // sampler types
 enum class llama_sampler_type : char {
@@ -64,6 +65,106 @@ typedef struct llama_sampling_params {
     bool use_penalty_prompt_tokens = false;
 } llama_sampling_params;
 
+// the ring buffer works similarly to std::deque, but with a fixed capacity
+template<typename T>
+struct ring_buffer {
+    ring_buffer() {}
+    ring_buffer(size_t cap) : capacity(cap), data(cap) {}
+
+    T & front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[first];
+    }
+
+    const T & front() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[first];
+    }
+
+    T & back() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[pos];
+    }
+
+    const T & back() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[pos];
+    }
+
+    void push_back(const T & value) {
+        if (sz == capacity) {
+            // advance the start when buffer is full
+            first = (first + 1) % capacity;
+        } else {
+            sz++;
+        }
+        data[pos] = value;
+        pos = (pos + 1) % capacity;
+    }
+
+    T pop_front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        T value = data[first];
+        first = (first + 1) % capacity;
+        sz--;
+        return value;
+    }
+
+    T & operator[](size_t i) {
+        if (i >= sz) {
+            throw std::runtime_error("ring buffer: index out of bounds");
+        }
+        return data[(first + i) % capacity];
+    }
+
+    const T & operator[](size_t i) const {
+        if (i >= sz) {
+            throw std::runtime_error("ring buffer: index out of bounds");
+        }
+        return data[(first + i) % capacity];
+    }
+
+    std::vector<T> to_vector() const {
+        std::vector<T> result;
+        result.reserve(sz);
+        for (size_t i = 0; i < sz; i++) {
+            result.push_back(data[(first + i) % capacity]);
+        }
+        return result;
+    }
+
+    void clear() {
+        // only reset the bookkeeping of the buffer; the underlying storage is reused
+        sz = 0;
+        first = 0;
+        pos = 0;
+    }
+
+    bool empty() const {
+        return sz == 0;
+    }
+
+    size_t size() const {
+        return sz;
+    }
+
+    size_t capacity = 0;
+    size_t sz = 0;
+    size_t first = 0;
+    size_t pos = 0;
+    std::vector<T> data;
+};
+
 // general sampler context
 // TODO: move to llama.h
 struct llama_sampling_context {
@@ -78,8 +179,7 @@ struct llama_sampling_context {
     // internal
     grammar_parser::parse_state parsed_grammar;
 
-    // TODO: replace with ring-buffer
-    std::vector<llama_token> prev;
+    ring_buffer<llama_token> prev;
     std::vector<llama_token_data> cur;
 
     size_t n_valid; // Number of correct top tokens with correct probabilities.
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index 92d630b15fdf1b..85654ea9593260 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -425,7 +425,7 @@ int main(int argc, char ** argv) {
 
             llama_sampling_accept(ctx_sampling, ctx, id, true);
 
-            LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
+            LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());
 
             embd.push_back(id);
 
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 6e0635a66cd067..1984f9400a1dee 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -736,7 +736,7 @@ int main(int argc, char ** argv) {
 
             llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
 
-            LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
+            LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());
 
             embd.push_back(id);
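
Note for reviewers: the snippet below is a standalone sketch, not part of the patch. It exercises the semantics of the ring_buffer template added above (overwrite-oldest when push_back exceeds the capacity, pop_front, to_vector); the include path and the main() harness are illustrative assumptions only.

#include <cstdio>

#include "sampling.h" // assumed include path for the ring_buffer template above

int main() {
    ring_buffer<int> buf(3); // fixed capacity of 3, analogous to prev sized by params.n_prev

    // pushing past the capacity silently drops the oldest entries
    for (int i = 1; i <= 5; i++) {
        buf.push_back(i);
    }

    // to_vector() returns the live contents oldest-to-newest: 3 4 5
    for (int v : buf.to_vector()) {
        printf("%d ", v);
    }
    printf("\n");

    // pop_front() removes and returns the oldest element (3), leaving two elements
    int popped = buf.pop_front();
    printf("popped %d, size now %zu\n", popped, buf.size());

    return 0;
}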