Merge pull request #30 from wwoodsTM/test-dry-sampler
Working implementation of DRY with one key issue I could use help with
l3utterfly authored Aug 8, 2024
2 parents ed6b909 + a18fb2f commit 190898a
Showing 6 changed files with 208 additions and 54 deletions.
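
For orientation: DRY penalizes any candidate token that would extend a run of tokens already seen earlier in the context. Once the matched run reaches dry_allowed_length, the logit is reduced by dry_multiplier * dry_base^(match_length - dry_allowed_length), so the penalty grows exponentially with the length of the repeat. A minimal sketch of that rule, using illustrative names rather than the structures in this diff:

#include <cmath>

// Sketch of the DRY penalty curve applied in this PR (illustrative only).
// match_length is the length of the repeated run the candidate would extend.
static float dry_penalty(float dry_multiplier, float dry_base,
                         int dry_allowed_length, int match_length) {
    if (match_length < dry_allowed_length) {
        return 0.0f; // short repeats go unpenalized
    }
    const int excess = match_length - dry_allowed_length;
    return dry_multiplier * std::pow(dry_base, (float)excess);
}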
4 changes: 2 additions & 2 deletions common/sampling.cpp
@@ -433,10 +433,10 @@ static llama_token_data_array llama_sampling_prepare_impl(
{
const int penalty_tokens_used_size = std::min(penalty_tokens.size(), (size_t)dry_penalty_last_n);
if (penalty_tokens_used_size) {
llama_sample_dry(&cur_p,
llama_sample_dry(ctx_main, &cur_p,
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
params.dry_seq_breakers.data(), params.dry_seq_breakers.size());
params.dry_seq_breakers);
}
}

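The call site above scans only a tail window of the history: penalty_tokens_used_size is clamped to dry_penalty_last_n, and the data pointer is offset so the sampler sees just the last N tokens. A standalone sketch of the same windowing (names and the int token type are illustrative):

#include <algorithm>
#include <utility>
#include <vector>

// Illustrative tail-window selection mirroring the call site above:
// return a pointer/length pair covering at most the last `last_n` tokens.
static std::pair<const int *, size_t> tail_window(const std::vector<int> & tokens, size_t last_n) {
    const size_t used = std::min(tokens.size(), last_n);
    return { tokens.data() + tokens.size() - used, used };
}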
5 changes: 3 additions & 2 deletions common/sampling.h
@@ -46,6 +46,8 @@ typedef struct llama_sampling_params {
uint32_t dry_allowed_length = 2;
int32_t dry_penalty_last_n = -1; // DRY last n tokens to penalize (0 = disable penalty, -1 = context size)

std::vector<std::string> dry_seq_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY

std::vector<llama_sampler_type> samplers_sequence = {
llama_sampler_type::TOP_K,
llama_sampler_type::TFS_Z,
@@ -63,9 +65,8 @@ typedef struct llama_sampling_params {
float cfg_scale = 1.f; // how strong is guidance

std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens

std::vector<llama_token> penalty_prompt_tokens;
std::vector<llama_token> dry_seq_breakers; // sequence breakers for the DRY sampler

bool use_penalty_prompt_tokens = false;
} llama_sampling_params;

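The notable change here is the type of dry_seq_breakers: token IDs are replaced by strings with defaults, so breakers can be multi-token and independent of any particular tokenizer. A hedged usage sketch; the extra breaker and the parameter values are illustrative, and the dry_multiplier/dry_base fields are assumed from earlier commits in this DRY chain rather than shown in this diff:

llama_sampling_params sparams;
// defaults are {"\n", ":", "\"", "*"}; the chat-role marker is a hypothetical addition
sparams.dry_seq_breakers.push_back("USER:");
sparams.dry_allowed_length = 2;     // default from this header
sparams.dry_multiplier     = 0.8f;  // illustrative value; field assumed from earlier commits
sparams.dry_base           = 1.75f; // illustrative value; field assumed from earlier commits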
33 changes: 23 additions & 10 deletions include/llama.h
@@ -1085,16 +1085,17 @@ extern "C" {
float p,
size_t min_keep);

/// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
LLAMA_API void llama_sample_dry(
llama_token_data_array * candidates,
const llama_token * last_tokens,
size_t last_tokens_size,
float dry_base,
float dry_multiplier,
int dry_allowed_length,
const llama_token * dry_seq_breakers,
size_t dry_seq_breakers_size);
// /// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
// LLAMA_API void llama_sample_dry(
// struct llama_context * ctx,
// llama_token_data_array * candidates,
// const llama_token * last_tokens,
// size_t last_tokens_size,
// float dry_base,
// float dry_multiplier,
// int dry_allowed_length,
// const std::vector<std::string>
// & dry_seq_breakers);

/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
LLAMA_API void llama_sample_tail_free(
@@ -1246,6 +1247,18 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);

/// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
LLAMA_API void llama_sample_dry(
struct llama_context * ctx,
llama_token_data_array * candidates,
const llama_token * last_tokens,
size_t last_tokens_size,
float dry_base,
float dry_multiplier,
int dry_allowed_length,
const std::vector<std::string>
& dry_seq_breakers);

#endif // LLAMA_API_INTERNAL

#endif // LLAMA_H
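
The new declaration sits under LLAMA_API_INTERNAL rather than in the public C block, presumably because const std::vector<std::string> & cannot cross a C ABI. A hedged call sketch against that signature, assuming ctx and cur_p are already in scope as in common/sampling.cpp; the history and parameter values are illustrative:

// Illustrative call matching the internal declaration above.
std::vector<llama_token> history = { /* recent tokens, oldest first */ };
std::vector<std::string> breakers = {"\n", ":", "\"", "*"};

llama_sample_dry(ctx, &cur_p,
                 history.data(), history.size(),
                 /*dry_base           =*/ 1.75f,
                 /*dry_multiplier     =*/ 0.8f,
                 /*dry_allowed_length =*/ 2,
                 breakers);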
210 changes: 173 additions & 37 deletions src/llama-sampling.cpp
@@ -232,94 +232,230 @@ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
}
}

void llama_sample_dry_impl(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size) {
// skip dry sampler if we don't have a previous token
if (last_tokens_size < 1) return;
std::vector<llama_token> llama_tokenize(
const struct llama_context * ctx,
const std::string & text,
bool add_special,
bool parse_special) {
return llama_tokenize(llama_get_model(ctx), text, add_special, parse_special);
}

std::vector<llama_token> llama_tokenize(
const struct llama_model * model,
const std::string & text,
bool add_special,
bool parse_special) {
// upper limit for the number of tokens
int n_tokens = text.length() + 2 * add_special;
std::vector<llama_token> result(n_tokens);
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
}
return result;
}

std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
std::string text;
text.resize(std::max(text.capacity(), tokens.size()));
int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
if (n_chars < 0) {
text.resize(-n_chars);
n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
}

text.resize(n_chars);

// NOTE: the original tokenizer decodes bytes after collecting the pieces.
return text;
}

std::string llama_detokenize_single(llama_context * ctx, llama_token token, bool special) {
std::vector<llama_token> tokens = {token};
return llama_detokenize(ctx, tokens, special);
}

// get the last token
auto last_token = last_tokens[last_tokens_size - 1];
// Constants for preventing overflow
const float FLOAT_MAX_LOG = 88.7228391f;
const int MAX_CHAR_LEN = 40;
const int MAX_SEQ_LEN = 20;

// if last token is part of the sequence breakers, skip whole sampler
if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, last_token) != dry_seq_breakers + dry_seq_breakers_size) {

void llama_sample_dry_impl(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const std::vector<std::string> & dry_seq_breakers) {
if (last_tokens_size < 1) {
return;
}

// create an unordered map of "next tokens" <-> max match length
// Cache for token-to-string conversions
std::unordered_map<llama_token, std::string> token_to_string_cache;
// Store sequence breakers for more efficient lookup
std::unordered_multimap<std::string, std::vector<std::string>> restart_sequences;

auto detokenize_with_cache = [&](llama_token token) -> std::string {
auto it = token_to_string_cache.find(token);
if (it != token_to_string_cache.end()) {
return it->second;
}
std::string token_str = llama_detokenize_single(ctx, token, false);
token_to_string_cache[token] = token_str;
return token_str;
};

// Pre-process dry_seq_breakers
for (const auto& breaker : dry_seq_breakers) {
std::string breaker_trimmed = breaker.substr(0, MAX_CHAR_LEN);
std::vector<llama_token> tokens = llama_tokenize(ctx, breaker_trimmed, false, false);

if (!tokens.empty()) {
std::string head = detokenize_with_cache(tokens[0]);
std::vector<std::string> tail;

for (size_t i = 1; i < tokens.size() && i <= MAX_SEQ_LEN; ++i) {
tail.push_back(detokenize_with_cache(tokens[i]));
}
restart_sequences.emplace(head, tail);
}
}

// Find max repetition length considering restart sequences
int rep_limit = last_tokens_size;

for (size_t i = 0; i < last_tokens_size; ++i) {
size_t ix = last_tokens_size - 1 - i;
std::string token_str = detokenize_with_cache(last_tokens[ix]);

// Check if the token is a potential sequence breaker
auto its = restart_sequences.equal_range(token_str);
if (its.first == restart_sequences.end()) continue;

int longest_match = -1;
// Check all potential sequence breakers starting with this token
for (auto it = its.first; it != its.second; ++it) {
int seq_len = (int)it->second.size();
if (seq_len > longest_match && seq_len <= i) {
bool match = true;
// Check if the following tokens match the sequence breaker
for (size_t offset = 0; offset < seq_len; ++offset) {
if (it->second[offset] != detokenize_with_cache(last_tokens[ix + 1 + offset])) {
match = false;
break;
}
}
if (match) {
longest_match = seq_len;
}
}
}

if (longest_match >= 0) {
rep_limit = static_cast<int>(i) - longest_match;
break;
}
}

if (rep_limit <= dry_allowed_length) {
return;
}

// Store max match length for each token
std::unordered_map<llama_token, size_t> match_lengths;

// loop through each previous token (exclude the last token)
// Find repeated sequences
for (size_t i = 0; i < last_tokens_size - 1; ++i) {
// skip if the compare token is not the same as the last token
if (last_tokens[i] != last_token) {
if (last_tokens[i] != last_tokens[last_tokens_size - 1]) {
continue;
}

// get the next token (i + 1 is always less than last_tokens_size)
auto next_token = last_tokens[i + 1];
std::string next_token_str = detokenize_with_cache(next_token);

// if next token is part of the sequence breakers, skip
if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, next_token) != dry_seq_breakers + dry_seq_breakers_size) {
// Skip if next token is a sequence breaker
auto its = restart_sequences.equal_range(next_token_str);
if (its.first != restart_sequences.end()) {
continue;
}

// try to extend the match backwards (match length starts at 1 because last token is already matched)
size_t match_length = 1;

// loop through the previous tokens
// Extend match as far as possible
for (;; match_length++) {
// if we have reached the start of our last tokens, break
if (i < match_length) break;
if (i < match_length || match_length > rep_limit) {
break;
}

// compare token starts at our prev index, going backwards by match length
auto compare_token = last_tokens[i - match_length];
std::string compare_token_str = detokenize_with_cache(compare_token);

// head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself
auto head_token = last_tokens[last_tokens_size - 1 - match_length];
std::string head_token_str = detokenize_with_cache(head_token);

// break out of the match if any tokens don't match
if (compare_token != head_token) {
if (compare_token_str != head_token_str) {
break;
}

// if compare token is part of the sequence breakers, break out of the match
if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, compare_token) != dry_seq_breakers + dry_seq_breakers_size) {
// Check if we've hit a sequence breaker
its = restart_sequences.equal_range(compare_token_str);
if (its.first != restart_sequences.end()) {
break;
}
}

// Check if the next token exists in the map
// Update max match length for this token
auto it = match_lengths.find(next_token);

if (it == match_lengths.end()) {
// Key does not exist, insert the new value
match_lengths[next_token] = match_length;
} else {
// Key exists, update it with the max of the new value or the existing value
it->second = std::max(it->second, match_length);
}
}

// apply penalties
// Calculate max safe exponent
int max_exponent = 0;
if (dry_base > 1.000001f) {
max_exponent = static_cast<int>(FLOAT_MAX_LOG / log(dry_base));
}

#ifdef DEBUG
LLAMA_LOG_INFO("DRY Sampling parameters:\n");
LLAMA_LOG_INFO(" dry_base: %f\n", dry_base);
LLAMA_LOG_INFO(" dry_multiplier: %f\n", dry_multiplier);
LLAMA_LOG_INFO(" dry_allowed_length: %d\n", dry_allowed_length);
LLAMA_LOG_INFO(" max_exponent: %d\n", max_exponent);
LLAMA_LOG_INFO("DRY penalties [");
#endif

// Apply penalties
for (const auto& pair : match_lengths) {
auto next_token = pair.first;
auto match_length = pair.second;

// if the match length is greater than or equal to our allowed length in config, we apply penalities
if (match_length >= (size_t)dry_allowed_length) {

// find our next token in the candidates->data
if (match_length >= static_cast<size_t>(dry_allowed_length)) {
for (size_t i = 0; i < candidates->size; ++i) {
if (candidates->data[i].id == next_token) {
// calculate the penalty
float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length);

// apply the dry penalty
int repeat_exp = static_cast<int>(match_length - dry_allowed_length);
if (max_exponent > 0 && repeat_exp > max_exponent) {
repeat_exp = max_exponent;
}
float penalty = dry_multiplier * pow(dry_base, static_cast<float>(repeat_exp));
candidates->data[i].logit -= penalty;

#ifdef DEBUG
LLAMA_LOG_INFO(" Token %d: %s (Penalty: %.2f)", next_token, detokenize_with_cache(next_token).c_str(), penalty);
#endif
break;
}
}
}
}

#ifdef DEBUG
LLAMA_LOG_INFO("]\n");
#endif
}

void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
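One detail worth checking by hand: the FLOAT_MAX_LOG clamp above bounds the exponent so that pow(dry_base, exponent) stays finite in single precision (log(FLT_MAX) ≈ 88.72). A standalone worked check under the constants from this diff:

#include <cassert>
#include <cmath>

// Worked check of the exponent clamp above (standalone, illustrative).
int main() {
    const float FLOAT_MAX_LOG = 88.7228391f;  // ~log(FLT_MAX)
    const float dry_base      = 1.75f;
    // largest exponent for which dry_base^e still fits in a float
    const int max_exponent = (int)(FLOAT_MAX_LOG / std::log(dry_base)); // 158 for base 1.75
    assert(std::isfinite(std::pow(dry_base, (float)max_exponent)));
    // one step past the clamp overflows to +inf in float
    assert(std::isinf(std::pow(dry_base, (float)(max_exponent + 1))));
    return 0;
}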
6 changes: 5 additions & 1 deletion src/llama-sampling.h
@@ -28,7 +28,11 @@ void llama_sample_softmax_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
void llama_sample_top_k_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
void llama_sample_top_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
void llama_sample_min_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
void llama_sample_dry_impl (llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size);
std::vector<llama_token> llama_tokenize(const struct llama_context * ctx, const std::string & text, bool add_special, bool parse_special);
std::vector<llama_token> llama_tokenize(const struct llama_model * model, const std::string & text, bool add_special, bool parse_special);
std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special);
std::string llama_detokenize_single(llama_context * ctx, llama_token token, bool special);
void llama_sample_dry_impl (struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const std::vector<std::string> & dry_seq_breakers);
void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
void llama_sample_typical_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
void llama_sample_entropy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
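The tokenize/detokenize helpers declared above (and defined in llama-sampling.cpp) both use llama.cpp's two-pass size negotiation: the first call may return a negative value whose magnitude is the required buffer size, so the caller resizes and retries. A generic sketch of that idiom, where fill is a hypothetical callable standing in for llama_tokenize/llama_detokenize:

#include <vector>

// Generic two-pass size negotiation, as used by the helpers above.
// fill(ptr, cap) returns the count written, or a negative required size.
template <typename Fill>
static std::vector<int> two_pass(Fill fill, int guess) {
    std::vector<int> buf((size_t)guess);
    int n = fill(buf.data(), (int)buf.size());
    if (n < 0) {                 // too small: -n is the exact size needed
        buf.resize((size_t)-n);
        n = fill(buf.data(), (int)buf.size());
    }
    buf.resize((size_t)n);       // shrink to the actual count
    return buf;
}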
4 changes: 2 additions & 2 deletions src/llama.cpp
@@ -18948,8 +18948,8 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
}

void llama_sample_dry(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size) {
llama_sample_dry_impl(candidates, last_tokens, last_tokens_size, dry_base, dry_multiplier, dry_allowed_length, dry_seq_breakers, dry_seq_breakers_size);
void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const std::vector<std::string> & dry_seq_breakers) {
llama_sample_dry_impl(ctx, candidates, last_tokens, last_tokens_size, dry_base, dry_multiplier, dry_allowed_length, dry_seq_breakers);
}

void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
