
Add ring buffer to store prev tokens in sampling #8890

Merged

Conversation

@kylo5aby (Contributor) commented Aug 6, 2024

@@ -64,6 +65,105 @@ typedef struct llama_sampling_params {
     bool use_penalty_prompt_tokens = false;
 } llama_sampling_params;

+template<typename T>
+struct ring_buffer {
@compilade (Collaborator) commented Aug 7, 2024:
Small question: how does this differ from std::queue or std::deque?

@kylo5aby (Contributor, Author) replied:

Here I want to use a fixed-capacity buffer to avoid resize or copy overhead, because the number of previous tokens kept for sampling is already known.
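For illustration, here is a minimal sketch of such a fixed-capacity ring buffer (illustrative only; the actual implementation in this PR may differ in details):

#include <cstddef>
#include <vector>

template<typename T>
struct ring_buffer {
    // storage is allocated once at construction, so push_back()
    // never reallocates or copies the whole buffer
    ring_buffer(size_t cap) : capacity(cap), data(cap) {}

    void push_back(const T & value) {
        if (sz == capacity) {
            // buffer is full: overwrite the oldest element
            first = (first + 1) % capacity;
        } else {
            sz++;
        }
        data[pos] = value;
        pos = (pos + 1) % capacity;
    }

    // copy out in logical (oldest-first) order
    std::vector<T> to_vector() const {
        std::vector<T> result;
        result.reserve(sz);
        for (size_t i = 0; i < sz; i++) {
            result.push_back(data[(first + i) % capacity]);
        }
        return result;
    }

    size_t capacity = 0;
    size_t first    = 0; // index of the oldest element
    size_t pos      = 0; // index of the next write
    size_t sz       = 0; // number of stored elements

    std::vector<T> data;
};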

@compilade (Collaborator) replied:

Right, std::deque can't reserve() like std::vector. This seems like a valid reason.

Might be worth it to write a (small) comment near ring_buffer to explain this.
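Such a comment might read, for example:

// ring_buffer: fixed-capacity buffer for the previous sampling tokens.
// Used instead of std::deque because the required capacity is known up
// front and std::deque cannot reserve() its storage, so this avoids
// reallocation and copy overhead when pushing tokens.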

@kylo5aby (Contributor, Author) replied:

Resolved. Thanks for the feedback!

@mofosyne added the label Review Complexity : Low (trivial changes to code that most beginner devs, or those who want a break, can tackle; e.g. a UI fix) on Aug 8, 2024
@ggerganov changed the base branch from master to gg/llama-refactor-sampling on August 12, 2024 at 08:18
@ggerganov (Owner) left a comment:

I changed the base branch to gg/llama-refactor-sampling, since it's better to merge this change together with the sampling refactoring.

@@ -425,7 +425,7 @@ int main(int argc, char ** argv) {

     llama_sampling_accept(ctx_sampling, ctx, id, true);

-    LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
+    LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());
@ggerganov (Owner) commented:

Let's remove these logs completely for now - will bring them back after the logger is reimplemented:

Suggested change:
-    LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());
+    //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());

@@ -736,7 +736,7 @@ int main(int argc, char ** argv) {

     llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);

-    LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
+    LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());
@ggerganov (Owner) commented:

Suggested change:
-    LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());
+    //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());

@@ -400,7 +400,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };

     // apply penalties
-    const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
+    const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev.to_vector();
@ggerganov (Owner) commented:

Should think of a way to avoid the to_vector() due to performance considerations

@kylo5aby (Contributor, Author) replied:

> Should think of a way to avoid the to_vector() due to performance considerations

I think one way to avoid the vector copy is to pass the penalty_prompt_tokens vector and a start index to llama_sample_repetition_penalties, and then traverse penalty_last_n elements of the vector inside it, which avoids the copy. For example:

void llama_sample_repetition_penalties(
                   struct llama_context * ctx,
                  llama_token_data_array * candidates,
             // const llama_token * last_tokens,
        const std::vector<llama_token> & penalty_tokens,
        // start_index: .size() - penalty_tokens_used_size,
        // or (prev.first + .size() - penalty_tokens_used_size) % prev.capacity if backed by a ring buffer
                                  size_t   start_index,
                                  size_t   penalty_last_n,
                                   float   penalty_repeat,
                                   float   penalty_freq,
                                   float   penalty_present);
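The body could then index into the tokens with modular arithmetic instead of copying them first. A rough sketch (the penalty arithmetic itself is elided, and the names follow the hypothetical signature above):

for (size_t i = 0; i < penalty_last_n; i++) {
    // the modulo makes the walk wrap around when penalty_tokens is the
    // underlying storage of a ring buffer
    const llama_token token = penalty_tokens[(start_index + i) % penalty_tokens.size()];
    // ... count occurrences of `token` and apply the repeat/freq/present penalties ...
}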

What do you think?

@ggerganov (Owner) replied:

After the sampling refactoring, the common/sampling.h/.cpp stuff will be moved to llama-sampling.cpp and the API call will become simply:

void llama_sampling_repetition_penalties(
           struct llama_sampling * ctx,
          llama_token_data_array * candidates);

All the penalty-related information (together with the ring buffer of previous tokens) will be inside the llama_sampling object, and we can handle it there. So for now, we can just resolve the conflict and merge, and later I'll avoid the to_vector().
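A hedged sketch of what that post-refactor shape could look like (all field names are hypothetical and reuse the ring_buffer sketch above; the point is only that the penalty pass can read the previous tokens in place, with no to_vector() copy):

struct llama_sampling {
    ring_buffer<llama_token> prev; // previous tokens, fixed capacity

    int32_t penalty_last_n;
    float   penalty_repeat;
    float   penalty_freq;
    float   penalty_present;
    // ...
};

void llama_sampling_repetition_penalties(
           struct llama_sampling * ctx,
          llama_token_data_array * candidates) {
    // clamp to however many tokens are actually stored
    const size_t n = ctx->prev.sz < (size_t) ctx->penalty_last_n
                         ? ctx->prev.sz
                         : (size_t) ctx->penalty_last_n;

    for (size_t i = 0; i < n; i++) {
        // read the last n tokens directly out of the ring buffer:
        // logical index (sz - n + i) maps to physical index
        // (first + sz - n + i) % capacity
        const llama_token token =
            ctx->prev.data[(ctx->prev.first + ctx->prev.sz - n + i) % ctx->prev.capacity];
        // ... count occurrences of `token` and apply the penalties to candidates ...
    }
}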

@ggerganov force-pushed the gg/llama-refactor-sampling branch from 8603eb2 to c5734f1 on August 12, 2024 at 12:43
@ggerganov merged commit 5763d8e into ggerganov:gg/llama-refactor-sampling on Aug 13, 2024
51 checks passed
Labels: examples, Review Complexity : Low