Add ring buffer to store prev tokens in sampling #8890
Conversation
kylo5aby commented on Aug 6, 2024
- I have read the contributing guidelines
- Self-reported review complexity:
  - Low
  - Medium
  - High
```diff
@@ -64,6 +65,105 @@ typedef struct llama_sampling_params {
     bool use_penalty_prompt_tokens = false;
 } llama_sampling_params;
 
+template<typename T>
+struct ring_buffer {
```
Small question: how does this differ from `std::queue` or `std::deque`?
Here I want to use a fixed-capacity buffer to avoid resize or copy overhead, because the number of prev tokens used for sampling is already known.
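For reference, a minimal sketch of what such a fixed-capacity ring buffer could look like (the member and method names below are assumptions for illustration, not necessarily this PR's exact code):

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Sketch of a fixed-capacity ring buffer: storage is allocated once,
// so pushing new tokens never reallocates or copies existing elements.
template<typename T>
struct ring_buffer {
    explicit ring_buffer(size_t cap) : capacity(cap), data(cap) {}

    // Overwrites the oldest element once the buffer is full.
    void push_back(const T & value) {
        if (sz == capacity) {
            first = (first + 1) % capacity; // drop the oldest element
        } else {
            sz++;
        }
        data[pos] = value;
        pos = (pos + 1) % capacity;
    }

    // i-th element counting back from the newest (rat(0) == most recent).
    const T & rat(size_t i) const {
        if (i >= sz) throw std::out_of_range("ring buffer: index out of range");
        return data[(pos + capacity - 1 - i) % capacity];
    }

    size_t capacity = 0;
    size_t sz       = 0; // number of stored elements
    size_t first    = 0; // index of the oldest element
    size_t pos      = 0; // index where the next element is written
    std::vector<T> data;
};
```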
Right, `std::deque` can't `reserve()` like `std::vector`. This seems like a valid reason. Might be worth it to write a (small) comment near `ring_buffer` to explain this.
Resolved. Thanks for the feedback!
Force-pushed from 9a34948 to 1238001
I changed the base branch to `gg/llama-refactor-sampling`, since it's better to merge this change together with the sampling refactoring.
examples/infill/infill.cpp (Outdated)

```diff
@@ -425,7 +425,7 @@ int main(int argc, char ** argv) {
 
     llama_sampling_accept(ctx_sampling, ctx, id, true);
 
-    LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
+    LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());
```
Let's remove these logs completely for now - will bring them back after the logger is reimplemented.

Suggested change:

```diff
-    LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());
+    //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());
```
examples/main/main.cpp (Outdated)

```diff
@@ -736,7 +736,7 @@ int main(int argc, char ** argv) {
 
     llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
 
-    LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
+    LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());
```
Suggested change:

```diff
-    LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());
+    //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());
```
common/sampling.cpp (Outdated)

```diff
@@ -400,7 +400,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
 
     // apply penalties
-    const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
+    const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev.to_vector();
```
Should think of a way to avoid the `to_vector()` due to performance considerations.
> Should think of a way to avoid the `to_vector()` due to performance considerations.
I think one way to avoid the vector copy is to pass the `penalty_prompt_tokens` vector and a start index to `llama_sample_repetition_penalties`, and then traverse the `penalty_last_n` elements from the vector inside it, which avoids the copy. For example:

```cpp
void llama_sample_repetition_penalties(
        struct llama_context * ctx,
        llama_token_data_array * candidates,
        // const llama_token * last_tokens,
        const std::vector<llama_token> & penalty_tokens,
        // .size() - penalty_tokens_used_size, or
        // (prev.first + .size() - penalty_tokens_used_size) % prev.capacity if ring buffer
        size_t start_index,
        size_t penalty_last_n,
        float penalty_repeat,
        float penalty_freq,
        float penalty_present);
```

What do you think?
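To make the proposal concrete, here is a hypothetical call site for the linear-vector case (the names `cur_p`, `params`, and `penalty_tokens` follow the diff above; the rest is an assumption, not code from this PR):

```cpp
// Hypothetical call site: penalize only the most recent penalty_last_n
// tokens of penalty_tokens, passing a start index instead of copying.
const size_t used        = std::min(penalty_tokens.size(), (size_t) params.penalty_last_n);
const size_t start_index = penalty_tokens.size() - used;

llama_sample_repetition_penalties(ctx, &cur_p,
        penalty_tokens, start_index, used,
        params.penalty_repeat, params.penalty_freq, params.penalty_present);
```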
After the sampling refactoring, the `common/sampling.h/.cpp` stuff will be moved to `llama-sampling.cpp` and the API call will become simply:

```cpp
void llama_sampling_repetition_penalties(
        struct llama_sampling * ctx,
        llama_token_data_array * candidates);
```

All the penalty-related information (together with the ring buffer of previous tokens) will be inside the `llama_sampling` object, and we can handle it there. So for now, we can just resolve the conflict and merge, and later I'll avoid the `to_vector()`.
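For illustration, one way such an internal implementation could avoid the copy is to count the recent tokens straight out of the ring buffer. This is a sketch under the assumption that the buffer exposes something like the `rat()` accessor from the earlier sketch; none of these names are the final API:

```cpp
#include <algorithm>
#include <cstddef>
#include <unordered_map>

// Hypothetical helper: count the last penalty_last_n tokens directly from
// the ring buffer, so the penalty loop never needs a to_vector() copy.
// T would be llama_token in practice; ring_buffer is the sketch above.
template<typename T>
std::unordered_map<T, int> count_recent(const ring_buffer<T> & prev, size_t penalty_last_n) {
    std::unordered_map<T, int> token_count;
    const size_t n = std::min(penalty_last_n, prev.sz);
    for (size_t i = 0; i < n; ++i) {
        token_count[prev.rat(i)]++; // rat(0) is the most recent token
    }
    // The counts then feed the repetition/frequency/presence penalties.
    return token_count;
}
```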
Force-pushed from 8603eb2 to c5734f1
Force-pushed from 1238001 to 8830fa1
Force-pushed from 8830fa1 to 3b23ea7
Merged 5763d8e into ggerganov:gg/llama-refactor-sampling