diff --git a/common/sampling.cpp b/common/sampling.cpp
index e05cb754c0d5f..d369f6c4a63e7 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -40,7 +40,7 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_pa
         llama_sampling_set_logit_bias(result->smpl, params.logit_bias.size(), params.logit_bias.data());
     }
 
-    result->prev.resize(params.n_prev);
+    result->prev = ring_buffer<llama_token>(params.n_prev);
 
     result->n_valid = 0;
 
@@ -56,7 +56,7 @@ void llama_sampling_free(struct llama_sampling_context * ctx) {
 void llama_sampling_reset(llama_sampling_context * ctx) {
     llama_sampling_reset(ctx->smpl);
 
-    std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
+    ctx->prev.clear();
     ctx->cur.clear();
     ctx->n_valid = 0;
 }
@@ -384,7 +384,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
 
     // apply penalties
-    const auto & penalty_tokens = prev;
+    const auto & penalty_tokens = prev.to_vector();
     const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
     if (penalty_tokens_used_size) {
         const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
@@ -434,7 +434,9 @@ void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
         llama_token id,
         bool apply_grammar) {
-    ctx_sampling->prev.erase(ctx_sampling->prev.begin());
+    if (!ctx_sampling->prev.empty()) {
+        ctx_sampling->prev.pop_front();
+    }
     ctx_sampling->prev.push_back(id);
 
     if (apply_grammar) {
diff --git a/common/sampling.h b/common/sampling.h
index 74020d52b0171..3e36e90ec6208 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -4,6 +4,7 @@
 
 #include <string>
 #include <vector>
+#include <stdexcept>
 
 // sampler types
 enum class llama_sampler_type : char {
@@ -58,6 +59,106 @@ typedef struct gpt_sampling_params {
     std::vector<llama_logit_bias> logit_bias; // logit biases to apply
 } gpt_sampling_params;
 
+// the ring buffer works similarly to std::deque, but with a fixed capacity
+template<typename T>
+struct ring_buffer {
+    ring_buffer() {}
+    ring_buffer(size_t cap) : capacity(cap), data(cap) {}
+
+    T & front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[first];
+    }
+
+    const T & front() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[first];
+    }
+
+    T & back() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[pos];
+    }
+
+    const T & back() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[pos];
+    }
+
+    void push_back(const T & value) {
+        if (sz == capacity) {
+            // advance the start when the buffer is full
+            first = (first + 1) % capacity;
+        } else {
+            sz++;
+        }
+        data[pos] = value;
+        pos = (pos + 1) % capacity;
+    }
+
+    T pop_front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        T value = data[first];
+        first = (first + 1) % capacity;
+        sz--;
+        return value;
+    }
+
+    T & operator[](size_t i) {
+        if (i >= sz) {
+            throw std::runtime_error("ring buffer: index out of bounds");
+        }
+        return data[(first + i) % capacity];
+    }
+
+    const T & operator[](size_t i) const {
+        if (i >= sz) {
+            throw std::runtime_error("ring buffer: index out of bounds");
+        }
+        return data[(first + i) % capacity];
+    }
+
+    std::vector<T> to_vector() const {
+        std::vector<T> result;
+        result.reserve(sz);
+        for (size_t i = 0; i < sz; i++) {
+            result.push_back(data[(first + i) % capacity]);
+        }
+        return result;
+    }
+
+    void clear() {
+        // only reset the buffer state, the underlying storage is kept allocated
+        sz = 0;
+        first = 0;
+        pos = 0;
+    }
+
+    bool empty() const {
+        return sz == 0;
+    }
+
+    size_t size() const {
+        return sz;
+    }
+
+    size_t capacity = 0;
+    size_t sz = 0;
+    size_t first = 0;
+    size_t pos = 0;
+    std::vector<T> data;
+};
+
 // general sampler context
 // TODO: move to llama.h
 struct llama_sampling_context {
@@ -69,8 +170,7 @@ struct llama_sampling_context {
 
     llama_sampling * smpl;
 
-    // TODO: replace with ring-buffer
-    std::vector<llama_token> prev;
+    ring_buffer<llama_token> prev;
     std::vector<llama_token_data> cur;
 
     size_t n_valid; // Number of correct top tokens with correct probabilities.
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index b6bff3f9e78ad..7d6c84f99bfe1 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -421,7 +421,7 @@ int main(int argc, char ** argv) {
 
             llama_sampling_accept(ctx_sampling, id, true);
 
-            LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
+            // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());
 
             embd.push_back(id);
 
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 3619157cfb476..71bf0a5183385 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -767,7 +767,7 @@ int main(int argc, char ** argv) {
 
             llama_sampling_accept(ctx_sampling, id, /* apply_grammar= */ true);
 
-            LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
+            // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());
 
             embd.push_back(id);