Skip to content

Commit

Permalink
Use ring buffer to store prev in sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
kylo5aby committed Aug 12, 2024
1 parent c21a896 commit 1238001
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 8 deletions.
10 changes: 6 additions & 4 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
result->grammar = grammar;
}

result->prev.resize(params.n_prev);
result->prev = ring_buffer<llama_token>(params.n_prev);

result->n_valid = 0;

Expand Down Expand Up @@ -72,7 +72,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
ctx->grammar = grammar;
}

std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
ctx->prev.clear();
ctx->cur.clear();
ctx->n_valid = 0;
}
Expand Down Expand Up @@ -400,7 +400,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
llama_token_data_array cur_p = { cur.data(), cur.size(), false };

// apply penalties
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev.to_vector();
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
if (penalty_tokens_used_size) {
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
Expand Down Expand Up @@ -451,7 +451,9 @@ void llama_sampling_accept(
struct llama_context * ctx_main,
llama_token id,
bool apply_grammar) {
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
if (!ctx_sampling->prev.empty()) {
ctx_sampling->prev.pop_front();
}
ctx_sampling->prev.push_back(id);

if (ctx_sampling->grammar != NULL && apply_grammar) {
Expand Down
104 changes: 102 additions & 2 deletions common/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <string>
#include <unordered_map>
#include <vector>
#include <stdexcept>

// sampler types
enum class llama_sampler_type : char {
Expand Down Expand Up @@ -64,6 +65,106 @@ typedef struct llama_sampling_params {
bool use_penalty_prompt_tokens = false;
} llama_sampling_params;

// the ring buffer works similarly to std::deque, but with a fixed capacity:
// once full, push_back overwrites (drops) the oldest element.
// NOTE: indices: `first` = oldest element, `pos` = next slot to be written,
// so the newest element lives at (pos + capacity - 1) % capacity.
template<typename T>
struct ring_buffer {
    ring_buffer() {}
    ring_buffer(size_t cap) : capacity(cap), data(cap) {}

    // oldest element; throws if the buffer is empty
    T & front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    const T & front() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    // newest element; throws if the buffer is empty
    // `pos` points at the NEXT write slot, so the last written element is one
    // position behind it (with wrap-around).
    T & back() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[(pos + capacity - 1) % capacity];
    }

    const T & back() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[(pos + capacity - 1) % capacity];
    }

    // append a value; if the buffer is full, the oldest element is dropped
    void push_back(const T & value) {
        if (capacity == 0) {
            // `% capacity` below would be a division by zero (UB) on a
            // default-constructed buffer — fail loudly instead
            throw std::runtime_error("ring buffer: capacity is zero");
        }
        if (sz == capacity) {
            // advance the start when buffer is full
            first = (first + 1) % capacity;
        } else {
            sz++;
        }
        data[pos] = value;
        pos = (pos + 1) % capacity;
    }

    // remove and return the oldest element; throws if empty
    T pop_front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        T value = data[first];
        first = (first + 1) % capacity;
        sz--;
        return value;
    }

    // i-th element counting from the oldest (0 == front); bounds-checked
    T & operator[](size_t i) {
        if (i >= sz) {
            throw std::runtime_error("ring buffer: index out of bounds");
        }
        return data[(first + i) % capacity];
    }

    const T & operator[](size_t i) const {
        if (i >= sz) {
            throw std::runtime_error("ring buffer: index out of bounds");
        }
        return data[(first + i) % capacity];
    }

    // copy the contents, oldest-first, into a std::vector
    std::vector<T> to_vector() const {
        std::vector<T> result;
        result.reserve(sz);
        for (size_t i = 0; i < sz; i++) {
            result.push_back(data[(first + i) % capacity]);
        }
        return result;
    }

    void clear() {
        // here only reset the status of the buffer; the stored elements are
        // not destroyed, merely made inaccessible
        sz = 0;
        first = 0;
        pos = 0;
    }

    bool empty() const {
        return sz == 0;
    }

    size_t size() const {
        return sz;
    }

    size_t capacity = 0;          // fixed maximum number of elements
    size_t sz = 0;                // current number of elements
    size_t first = 0;             // index of the oldest element
    size_t pos = 0;               // index of the next slot to write
    std::vector<T> data;          // backing storage (size == capacity)
};

// general sampler context
// TODO: move to llama.h
struct llama_sampling_context {
Expand All @@ -78,8 +179,7 @@ struct llama_sampling_context {
// internal
grammar_parser::parse_state parsed_grammar;

// TODO: replace with ring-buffer
std::vector<llama_token> prev;
ring_buffer<llama_token> prev;
std::vector<llama_token_data> cur;
size_t n_valid; // Number of correct top tokens with correct probabilities.

Expand Down
2 changes: 1 addition & 1 deletion examples/infill/infill.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ int main(int argc, char ** argv) {

llama_sampling_accept(ctx_sampling, ctx, id, true);

LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());

embd.push_back(id);

Expand Down
2 changes: 1 addition & 1 deletion examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,7 @@ int main(int argc, char ** argv) {

llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);

LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());

embd.push_back(id);

Expand Down

0 comments on commit 1238001

Please sign in to comment.