diff --git a/common/common.cpp b/common/common.cpp index 3a92d3797492f..354ef08e2a97f 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -404,6 +404,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } sparams.penalty_present = std::stof(argv[i]); + } else if (arg == "--penalty-threshold") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.penalty_threshold = std::stof(argv[i]); } else if (arg == "--dynatemp-range") { if (++i >= argc) { invalid_param = true; @@ -976,6 +982,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat); printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present); printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq); + printf(" --penalty-threshold N only apply penalties to tokens whose relative frequency in the penalty context is less than or equal to this value (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_threshold); printf(" --dynatemp-range N dynamic temperature range (default: %.1f, 0.0 = disabled)\n", (double)sparams.dynatemp_range); printf(" --dynatemp-exp N dynamic temperature exponent (default: %.1f)\n", (double)sparams.dynatemp_exponent); printf(" --mirostat N use Mirostat sampling.\n"); @@ -1717,6 +1724,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false"); fprintf(stream, "no_mul_mat_q: %s # default: false\n", !params.mul_mat_q ? "true" : "false"); fprintf(stream, "no_penalize_nl: %s # default: false\n", !sparams.penalize_nl ? 
"true" : "false"); + fprintf(stream, "penalty_threshold: %f # default: 1.0\n", sparams.penalty_threshold); fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type); fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride); fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present); diff --git a/common/sampling.cpp b/common/sampling.cpp index 53013138a9eb4..b16c904552197 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -163,15 +163,16 @@ static llama_token llama_sampling_sample_impl( const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); - const float temp = params.temp; - const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n; - const float penalty_repeat = params.penalty_repeat; - const float penalty_freq = params.penalty_freq; - const float penalty_present = params.penalty_present; - const int mirostat = params.mirostat; - const float mirostat_tau = params.mirostat_tau; - const float mirostat_eta = params.mirostat_eta; - const bool penalize_nl = params.penalize_nl; + const float temp = params.temp; + const int32_t penalty_last_n = params.penalty_last_n < 0 ? 
params.n_prev : params.penalty_last_n; + const float penalty_repeat = params.penalty_repeat; + const float penalty_freq = params.penalty_freq; + const float penalty_present = params.penalty_present; + const float penalty_threshold = params.penalty_threshold; + const int mirostat = params.mirostat; + const float mirostat_tau = params.mirostat_tau; + const float mirostat_eta = params.mirostat_eta; + const bool penalize_nl = params.penalize_nl; auto & prev = ctx_sampling->prev; auto & cur = ctx_sampling->cur; @@ -215,7 +216,7 @@ static llama_token llama_sampling_sample_impl( llama_sample_repetition_penalties(ctx_main, &cur_p, penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, - penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present); + penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present, penalty_threshold); if (!penalize_nl) { for (size_t idx = 0; idx < cur_p.size; idx++) { diff --git a/common/sampling.h b/common/sampling.h index e1279a8941ce0..14c0e8d4e5028 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -34,6 +34,7 @@ typedef struct llama_sampling_params { float penalty_repeat = 1.10f; // 1.0 = disabled float penalty_freq = 0.00f; // 0.0 = disabled float penalty_present = 0.00f; // 0.0 = disabled + float penalty_threshold = 1.00f; // 1.0 = disabled int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate diff --git a/examples/main/README.md b/examples/main/README.md index 7f84e42623274..d08cbbfabb0ea 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -182,15 +182,18 @@ Example usage: `--temp 0.5` - `--repeat-penalty N`: Control the repetition of token sequences in the generated text (default: 1.1). - `--repeat-last-n N`: Last n tokens to consider for penalizing repetition (default: 64, 0 = disabled, -1 = ctx-size). 
+- `--penalty-threshold N`: Only apply penalties to tokens whose relative frequency in the penalty context is less than or equal to this value (default: 1.0, 1.0 = disabled). - `--no-penalize-nl`: Disable penalization for newline tokens when applying the repeat penalty. The `repeat-penalty` option helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. The default value is 1.1. The `repeat-last-n` option controls the number of tokens in the history to consider for penalizing repetition. A larger value will look further back in the generated text to prevent repetitions, while a smaller value will only consider recent tokens. A value of 0 disables the penalty, and a value of -1 sets the number of tokens considered equal to the context size (`ctx-size`). +The `penalty-threshold` option disables penalties for very common tokens. This is designed to prevent penalizing tokens that are essential to the structure of the text, such as spaces and punctuation, very common words such as "the", names of participants in chats, brackets and tags in code, etc. For example, a value of 0.1 disables penalties for tokens that make up more than 10% of the tokens in the penalty context (the last `repeat-last-n` tokens). + Use the `--no-penalize-nl` option to disable newline penalization when applying the repeat penalty. This option is particularly useful for generating chat conversations, dialogues, code, poetry, or any text where newline tokens play a significant role in structure and formatting. Disabling newline penalization helps maintain the natural flow and intended formatting in these specific use cases. 
-Example usage: `--repeat-penalty 1.15 --repeat-last-n 128 --no-penalize-nl` +Example usage: `--repeat-penalty 1.15 --repeat-last-n 128 --penalty-threshold 0.1 --no-penalize-nl` ### Top-K Sampling diff --git a/llama.cpp b/llama.cpp index 6ac9caa957a05..ef162b956b0c1 100644 --- a/llama.cpp +++ b/llama.cpp @@ -9634,8 +9634,9 @@ void llama_sample_repetition_penalties( size_t penalty_last_n, float penalty_repeat, float penalty_freq, - float penalty_present) { - if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) { + float penalty_present, + float penalty_threshold) { + if (penalty_last_n == 0 || penalty_threshold == 0.0f || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) { return; } @@ -9656,6 +9657,10 @@ void llama_sample_repetition_penalties( const int count = token_iter->second; + if (float(count) / float(penalty_last_n) > penalty_threshold) { + continue; + } + // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. // This is common fix for this problem, which is to multiply by the penalty instead of dividing. if (candidates->data[i].logit <= 0) { diff --git a/llama.h b/llama.h index f4ec6ea6394a3..80409d0357bbd 100644 --- a/llama.h +++ b/llama.h @@ -720,6 +720,7 @@ extern "C" { /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. + /// @param penalty_threshold Only apply penalties to tokens whose relative frequency in the penalty context is less than or equal to this value. 
LLAMA_API void llama_sample_repetition_penalties( struct llama_context * ctx, llama_token_data_array * candidates, @@ -727,7 +728,8 @@ extern "C" { size_t penalty_last_n, float penalty_repeat, float penalty_freq, - float penalty_present); + float penalty_present, + float penalty_threshold); /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 /// @param logits Logits extracted from the original generation context. diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 6374958fee8e6..e32388d4196b1 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -123,7 +123,7 @@ static void test_typical(const std::vector & probs, const std::vector & probs, const std::vector & last_tokens, - const std::vector & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence + const std::vector & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence, float penalty_threshold ) { GGML_ASSERT(probs.size() == expected_probs.size()); @@ -138,7 +138,7 @@ static void test_repetition_penalties( llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_sample_softmax(nullptr, &candidates_p); DUMP(&candidates_p); - llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence); + llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, penalty_threshold); llama_sample_softmax(nullptr, &candidates_p); DUMP(&candidates_p); @@ -259,13 +259,17 @@ int main(void) { test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f); test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f); - test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 
0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f); - test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f); - test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f); + test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f, 1.0f); + test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 1.0f); + test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 1.0f); - test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f); - test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f); - test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f); + test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 50.0f, 0.0f, 0.0f, 0.5f); + test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 0.5f); + test_repetition_penalties({0.125f, 0.125f, 0.125f, 0.125f, 0.125f, 0.125f, 0.125f, 0.125f}, {0, 1, 2, 3, 4, 0, 0, 0, 0}, {0.25f, 0.25f, 0.25f, 0.25f, 0, 0, 0, 0}, 50.0f, 0.0f, 0.0f, 0.5f); + + test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f, 1.0f); + test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f, 1.0f); + test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f, 1.0f); 
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f); test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);