Penalty threshold: A mechanism for improving repetition penalties #5561

Open · wants to merge 1 commit into base: master
8 changes: 8 additions & 0 deletions common/common.cpp
@@ -404,6 +404,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
sparams.penalty_present = std::stof(argv[i]);
} else if (arg == "--penalty-threshold") {
if (++i >= argc) {
invalid_param = true;
break;
}
sparams.penalty_threshold = std::stof(argv[i]);
} else if (arg == "--dynatemp-range") {
if (++i >= argc) {
invalid_param = true;
@@ -976,6 +982,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat);
printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present);
printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq);
printf(" --penalty-threshold N only apply penalties to tokens whose relative frequency in the penalty context is less than or equal to this value (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_threshold);
printf(" --dynatemp-range N dynamic temperature range (default: %.1f, 0.0 = disabled)\n", (double)sparams.dynatemp_range);
printf(" --dynatemp-exp N dynamic temperature exponent (default: %.1f)\n", (double)sparams.dynatemp_exponent);
printf(" --mirostat N use Mirostat sampling.\n");
@@ -1717,6 +1724,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false");
fprintf(stream, "no_mul_mat_q: %s # default: false\n", !params.mul_mat_q ? "true" : "false");
fprintf(stream, "no_penalize_nl: %s # default: false\n", !sparams.penalize_nl ? "true" : "false");
fprintf(stream, "penalty_threshold: %f # default: 1.0\n", sparams.penalty_threshold);
fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);
fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present);
21 changes: 11 additions & 10 deletions common/sampling.cpp
@@ -163,15 +163,16 @@ static llama_token llama_sampling_sample_impl(

const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));

const float temp = params.temp;
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
const float penalty_repeat = params.penalty_repeat;
const float penalty_freq = params.penalty_freq;
const float penalty_present = params.penalty_present;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;
const float temp = params.temp;
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
const float penalty_repeat = params.penalty_repeat;
const float penalty_freq = params.penalty_freq;
const float penalty_present = params.penalty_present;
const float penalty_threshold = params.penalty_threshold;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;

auto & prev = ctx_sampling->prev;
auto & cur = ctx_sampling->cur;
@@ -215,7 +216,7 @@ static llama_token llama_sampling_sample_impl(

llama_sample_repetition_penalties(ctx_main, &cur_p,
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present, penalty_threshold);

if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
1 change: 1 addition & 0 deletions common/sampling.h
@@ -34,6 +34,7 @@ typedef struct llama_sampling_params {
float penalty_repeat = 1.10f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
float penalty_threshold = 1.00f; // 1.0 = disabled
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
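For context, here is a minimal sketch (not part of the diff; the include path and concrete values are illustrative assumptions) of how a caller might set the new field alongside the existing sampling parameters:

```cpp
#include "sampling.h"  // common/sampling.h from this repository; path may vary by build setup

// Hypothetical configuration: penalize repeats, but exempt tokens that make up
// more than 10% of the penalty window (e.g. spaces, newlines, very common words).
llama_sampling_params make_sampling_params() {
    llama_sampling_params sparams;
    sparams.penalty_repeat    = 1.15f; // repeat penalty, as before
    sparams.penalty_threshold = 0.10f; // new field added by this PR
    return sparams;
}
```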
5 changes: 4 additions & 1 deletion examples/main/README.md
@@ -182,15 +182,18 @@ Example usage: `--temp 0.5`

- `--repeat-penalty N`: Control the repetition of token sequences in the generated text (default: 1.1).
- `--repeat-last-n N`: Last n tokens to consider for penalizing repetition (default: 64, 0 = disabled, -1 = ctx-size).
- `--penalty-threshold N`: Only apply penalties to tokens whose relative frequency in the penalty context is less than or equal to this value (default: 1.0, 1.0 = disabled).
- `--no-penalize-nl`: Disable penalization for newline tokens when applying the repeat penalty.

The `repeat-penalty` option helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. The default value is 1.1.

The `repeat-last-n` option controls the number of tokens in the history to consider for penalizing repetition. A larger value will look further back in the generated text to prevent repetitions, while a smaller value will only consider recent tokens. A value of 0 disables the penalty, and a value of -1 sets the number of tokens considered equal to the context size (`ctx-size`).

The `penalty-threshold` option disables penalties for very common tokens. This is designed to prevent penalizing tokens that are essential to the structure of the text, such as spaces and punctuation, very common words such as "the", names of participants in chats, brackets and tags in code, etc. For example, a value of 0.1 disables penalties for tokens that make up more than 10% of the tokens in the penalty context (the last `repeat-last-n` tokens).

Use the `--no-penalize-nl` option to disable newline penalization when applying the repeat penalty. This option is particularly useful for generating chat conversations, dialogues, code, poetry, or any text where newline tokens play a significant role in structure and formatting. Disabling newline penalization helps maintain the natural flow and intended formatting in these specific use cases.

Example usage: `--repeat-penalty 1.15 --repeat-last-n 128 --no-penalize-nl`
Example usage: `--repeat-penalty 1.15 --repeat-last-n 128 --penalty-threshold 0.1 --no-penalize-nl`

### Top-K Sampling

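The `penalty-threshold` behaviour described in the README text above boils down to a single relative-frequency test. A minimal sketch of that test (illustrative only; the actual logic lives in `llama_sample_repetition_penalties` below):

```cpp
#include <cstddef>

// Returns true if the token should still be penalized: its share of the
// last `penalty_last_n` tokens must not exceed `penalty_threshold`.
static bool should_penalize(int count, size_t penalty_last_n, float penalty_threshold) {
    // Example: count = 20, penalty_last_n = 64, threshold = 0.1f
    //   20/64 = 0.3125 > 0.1  ->  the token is too common, leave it unpenalized.
    return float(count) / float(penalty_last_n) <= penalty_threshold;
}
```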
9 changes: 7 additions & 2 deletions llama.cpp
@@ -9634,8 +9634,9 @@ void llama_sample_repetition_penalties(
size_t penalty_last_n,
float penalty_repeat,
float penalty_freq,
float penalty_present) {
if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
float penalty_present,
float penalty_threshold) {
if (penalty_last_n == 0 || penalty_threshold == 0.0f || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
return;
}

@@ -9656,6 +9657,10 @@

const int count = token_iter->second;

if (float(count) / float(penalty_last_n) > penalty_threshold) {
continue;
}

// The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
// This is common fix for this problem, which is to multiply by the penalty instead of dividing.
if (candidates->data[i].logit <= 0) {
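Putting the pieces together, a simplified sketch of the penalty loop after this change (not the verbatim implementation; types and names are shortened for illustration):

```cpp
#include <unordered_map>
#include <vector>

struct candidate { int id; float logit; };

void apply_penalties(std::vector<candidate> & cands,
                     const std::vector<int> & last_tokens,
                     float penalty_repeat, float penalty_freq,
                     float penalty_present, float penalty_threshold) {
    if (last_tokens.empty() || penalty_threshold == 0.0f ||
        (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
        return; // nothing to do, mirrors the early-out above
    }

    // Count how often each token appears in the penalty window.
    std::unordered_map<int, int> token_count;
    for (int tok : last_tokens) token_count[tok]++;

    for (auto & c : cands) {
        const auto it = token_count.find(c.id);
        if (it == token_count.end()) continue;
        const int count = it->second;

        // New in this PR: skip tokens above the relative-frequency threshold.
        if (float(count) / float(last_tokens.size()) > penalty_threshold) continue;

        // CTRL-style repeat penalty with the negative-logit fix:
        // multiply negative logits so unlikely tokens do not become more likely.
        if (c.logit <= 0) c.logit *= penalty_repeat;
        else              c.logit /= penalty_repeat;

        // OpenAI-style frequency and presence penalties.
        c.logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
    }
}
```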
4 changes: 3 additions & 1 deletion llama.h
@@ -720,14 +720,16 @@ extern "C" {

/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
/// @param penalty_threshold Only apply penalties to tokens whose relative frequency in the penalty context is less than or equal to this value.
LLAMA_API void llama_sample_repetition_penalties(
struct llama_context * ctx,
llama_token_data_array * candidates,
const llama_token * last_tokens,
size_t penalty_last_n,
float penalty_repeat,
float penalty_freq,
float penalty_present);
float penalty_present,
float penalty_threshold);

/// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
/// @param logits Logits extracted from the original generation context.
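A hedged example of how a call site might pass the new argument, assuming `ctx`, `candidates_p`, and `last_tokens` are already prepared as in `common/sampling.cpp`; the values are illustrative:

```cpp
// Only the trailing penalty_threshold argument is new in this PR.
llama_sample_repetition_penalties(ctx, &candidates_p,
        last_tokens.data(), last_tokens.size(),
        /* penalty_repeat    = */ 1.15f,
        /* penalty_freq      = */ 0.0f,
        /* penalty_present   = */ 0.0f,
        /* penalty_threshold = */ 0.1f);
```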
20 changes: 12 additions & 8 deletions tests/test-sampling.cpp
@@ -123,7 +123,7 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo

static void test_repetition_penalties(
const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence
const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence, float penalty_threshold
) {
GGML_ASSERT(probs.size() == expected_probs.size());

@@ -138,7 +138,7 @@
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
llama_sample_softmax(nullptr, &candidates_p);
DUMP(&candidates_p);
llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, penalty_threshold);
llama_sample_softmax(nullptr, &candidates_p);
DUMP(&candidates_p);

@@ -259,13 +259,17 @@ int main(void) {
test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);

test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f);
Author comment: I do not understand how these tests are intended to work, even before my changes. The only prior token here is 0, so I would expect the resulting probability vector to be {0, 0.25f, 0.25f, 0.25f, 0.25f}, that is, the probability at index 0 to be penalized. Please help me understand what is going on here so I can make sure the code actually works correctly.

test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f, 1.0f);
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 1.0f);
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 1.0f);

test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f);
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 50.0f, 0.0f, 0.0f, 0.5f);
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 0.5f);
test_repetition_penalties({0.125f, 0.125f, 0.125f, 0.125f, 0.125f, 0.125f, 0.125f, 0.125f}, {0, 1, 2, 3, 4, 0, 0, 0, 0}, {0.25f, 0.25f, 0.25f, 0.25f, 0, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 0.5f);

test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f, 1.0f);
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f, 1.0f);
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f, 1.0f);

test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
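For the new threshold test cases above (all with `penalty_threshold = 0.5f`), the relative frequencies work out as follows; a small compile-time check of the arithmetic, with the values taken from the test inputs:

```cpp
// history {0}: token 0 fills the whole window, so it is exempt and the
// distribution is unchanged.
static_assert(1.0f / 1.0f > 0.5f, "token 0 exempt");
// history {0, 1, 2}: each token is 1/3 of the window, so all three are penalized.
static_assert(1.0f / 3.0f <= 0.5f, "tokens 0-2 penalized");
// history {0,1,2,3,4,0,0,0,0}: token 0 is 5/9 of the window (exempt),
// tokens 1-4 are 1/9 each (penalized); unseen tokens 5-7 are never penalized.
static_assert(5.0f / 9.0f > 0.5f, "token 0 exempt");
static_assert(1.0f / 9.0f <= 0.5f, "tokens 1-4 penalized");
```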