Token healing main #219

Merged · 8 commits · Jul 8, 2024
10 changes: 10 additions & 0 deletions common/common.cpp
@@ -1082,6 +1082,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
return true;
}
if (arg == "-th" || arg == "--token-healing") {
CHECK_ARG
std::string value(argv[i]);
invalid_param = !llama_token_healing_parse_params(value, sparams.token_healing);
return true;
}
if (arg == "--override-kv") {
CHECK_ARG
if (!string_parse_kv_override(argv[i], params.kv_overrides)) {
@@ -1484,6 +1490,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
"if suffix/prefix are specified, template will be disabled\n"
"only commonly used templates are accepted:\n"
"https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });

options.push_back({ "main", "-th, --token-healing {0,1,d1,d,r{N}}",
"Token healing type. (default: 0, disabled)\n"
"1: replace one token, d1: replace longest suffix with one token, d: replace longest suffix, r{N}: roll back N tokens" });
options.push_back({ "grammar" });
options.push_back({ "*", " --grammar GRAMMAR", "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() });
options.push_back({ "*", " --grammar-file FNAME", "file to read grammar from" });
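For reference, a rough sketch (not part of the diff) of how the accepted `--token-healing` values map onto the healing strategies via `llama_token_healing_parse_params`, which is introduced in `common/sampling.cpp` below:

```cpp
// Illustrative only: expected results of llama_token_healing_parse_params for each CLI value.
#include "sampling.h"
#include <cassert>

static void check_token_healing_cli_values() {
    llama_token_healing_params p;
    assert( llama_token_healing_parse_params("0",  p) && !p.enabled);                                  // disabled
    assert( llama_token_healing_parse_params("1",  p) && p.type == llama_token_healing_type::ROLLBACK_LAST);
    assert( llama_token_healing_parse_params("d1", p) && p.type == llama_token_healing_type::DYNAMIC_ONCE);
    assert( llama_token_healing_parse_params("d",  p) && p.type == llama_token_healing_type::DYNAMIC_MULTI);
    assert( llama_token_healing_parse_params("r3", p) && p.type == llama_token_healing_type::ROLLBACK_MULTI && p.n_rollback == 3);
    assert(!llama_token_healing_parse_params("r0", p));                                                // invalid N -> rejected
}
```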
216 changes: 213 additions & 3 deletions common/sampling.cpp
@@ -2,6 +2,181 @@
#include "sampling.h"
#include <random>

//
// Token healing (internal)
//

static bool startswith(const std::string & str, const std::string & prefix) {
return str.rfind(prefix, 0) != std::string::npos;
}

static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
std::string token = llama_token_to_piece(ctx_main, token_id);
if (startswith(token, prefix)) {
return true;
}
}
return false;
}

static std::vector<llama_token> token_healing_get_candidates(
const llama_context * ctx_main,
const std::string & prefix,
const bool include_partial_prefix) {
// Example: prefix=" world" -> " world", " worldwide", ...
// If `include_partial_prefix`, include also: " w", " wo", ...
std::vector<llama_token> candidates;
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
std::string token = llama_token_to_piece(ctx_main, token_id);
if (startswith(token, prefix) ||
(include_partial_prefix && startswith(prefix, token))) {
candidates.push_back(token_id);
}
}
return candidates;
}

static size_t get_max_token_length(const llama_context * ctx_main) {
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
size_t len = 0;
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
std::string token = llama_token_to_piece(ctx_main, token_id);
len = std::max(len, token.size());
}
return len;
}

static llama_token_healing_output llama_token_healing_get_prefix(
const llama_context * ctx_main,
const llama_token_healing_type th_type,
const std::vector<llama_token> & tokens,
int max_to_remove) {
if (tokens.size() <= 1) {
return {};
}

const int n_ctx = tokens.size();
max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
max_to_remove = max_to_remove < 0 ? n_ctx - 1 : std::min(max_to_remove, n_ctx - 1); // 1 token must remain

int removed = 0;
std::string prefix;

const llama_model * model = llama_get_model(ctx_main);
auto is_special_token = [&](const llama_token token_id) {
return llama_token_is_control(model, token_id)
|| llama_token_bos (model) == token_id
|| llama_token_eos (model) == token_id
|| llama_token_cls (model) == token_id
|| llama_token_sep (model) == token_id
|| llama_token_pad (model) == token_id
|| llama_token_prefix (model) == token_id
|| llama_token_middle (model) == token_id
|| llama_token_suffix (model) == token_id
|| llama_token_eot (model) == token_id;
};

if (th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI) {
// The number of bytes to roll back cannot exceed the length of the longest token.
const size_t n_longest_token = get_max_token_length(ctx_main);
size_t len = 0;
while (removed < max_to_remove) {
const llama_token next_token_id = tokens[n_ctx - removed - 1];
if (is_special_token(next_token_id)) {
break;
}
const size_t next_token_size = llama_token_to_piece(ctx_main, next_token_id).size();
if (len + next_token_size > n_longest_token) {
break;
}
len += next_token_size;
removed += 1;
}

while (removed > 0) {
prefix.clear();
for (int i = n_ctx - removed; i < n_ctx; i++) {
prefix += llama_token_to_piece(ctx_main, tokens[i]);
}
if (token_healing_prefix_exists(ctx_main, prefix)) {
break; // Stop on longest valid prefix
}
removed -= 1;
}
} else {
// Roll back tokens a fixed amount and stop early if a special token is encountered.
while (removed < max_to_remove) {
const llama_token next_token_id = tokens[n_ctx - removed - 1];
if (is_special_token(next_token_id)) {
break;
}
removed += 1;
}
for (int i = n_ctx - removed; i < n_ctx; i++) {
prefix += llama_token_to_piece(ctx_main, tokens[i]);
}
}
return {prefix, removed};
}

//
// Token healing (external)
//

llama_token_healing_output llama_token_healing_rollback(
const llama_context * ctx_main,
std::vector<llama_token> & tokens,
llama_token_healing_type th_type,
int max_to_remove) {
// NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back.
// It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt.
llama_token_healing_output out = llama_token_healing_get_prefix(ctx_main, th_type, tokens, max_to_remove);

// If constrained decoding would give back the original prompt, there is no need to modify the prompt.
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
th_type == llama_token_healing_type::DYNAMIC_MULTI;
const std::vector<llama_token> candidates = token_healing_get_candidates(ctx_main, out.prefix, is_multi_step);
LOG("token_healing: prefix = '%s' (%d tokens)\n", out.prefix.c_str(), out.n_tokens_removed);
if (out.n_tokens_removed == 1 && candidates.size() == 1) {
LOG("token_healing: nothing to heal\n");
return {};
}

// Finally, trim prompt tokens
tokens.resize(tokens.size() - out.n_tokens_removed);
return out;
}

void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
ctx_sampling->token_healing_prefix = prefix;
}

bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params) {
th_params.enabled = true;
th_params.n_rollback = -1;
/**/ if (params == "0" ) { th_params.enabled = false; }
else if (params == "1" ) { th_params.type = llama_token_healing_type::ROLLBACK_LAST; }
else if (params == "d1") { th_params.type = llama_token_healing_type::DYNAMIC_ONCE; }
else if (params == "d" ) { th_params.type = llama_token_healing_type::DYNAMIC_MULTI; }
else if (params[0] == 'r' ) {
th_params.type = llama_token_healing_type::ROLLBACK_MULTI;
th_params.n_rollback = std::stoi(params.substr(1));
if (th_params.n_rollback <= 0) {
return false;
}
} else {
return false;
}
return true;
}

//
// Sampling
//

struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
struct llama_sampling_context * result = new llama_sampling_context();

@@ -72,6 +247,8 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
ctx->grammar = grammar;
}

ctx->token_healing_prefix.clear();

std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
ctx->cur.clear();
ctx->n_valid = 0;
@@ -130,7 +307,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
}

std::string llama_sampling_order_print(const llama_sampling_params & params) {
std::string result = "CFG -> Penalties ";
std::string result = "(Token healing) -> CFG -> Penalties ";
if (params.mirostat == 0) {
for (auto sampler_type : params.samplers_sequence) {
const auto sampler_type_name = llama_sampling_type_to_str(sampler_type);
@@ -393,8 +570,26 @@ static llama_token_data_array llama_sampling_prepare_impl(

cur.resize(n_vocab);

for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
// Constrain tokens based on the remaining token healing prefix (if any)
const auto & th_prefix = ctx_sampling->token_healing_prefix;
if (params.token_healing.enabled && !th_prefix.empty()) {
const bool is_multi_step = params.token_healing.type == llama_token_healing_type::ROLLBACK_MULTI ||
params.token_healing.type == llama_token_healing_type::DYNAMIC_MULTI;
std::vector<llama_token> th_candidates = token_healing_get_candidates(ctx_main, th_prefix, is_multi_step);

LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
for (const llama_token token_id : th_candidates) {
LOG(" [%6d] '%s'\n", token_id, llama_token_to_piece(ctx_main, token_id).c_str());
}

// N.B. We could also set token constraints by setting rejected tokens' logits to -inf
for (const llama_token token_id : th_candidates) {
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
}
} else {
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
}
}

llama_token_data_array cur_p = { cur.data(), cur.size(), false };
@@ -457,4 +652,19 @@ void llama_sampling_accept(
if (ctx_sampling->grammar != NULL && apply_grammar) {
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
}

if (ctx_sampling->params.token_healing.enabled && apply_grammar) {
std::string & th_prefix = ctx_sampling->token_healing_prefix;
if (!th_prefix.empty()) {
const std::string new_token_piece = llama_token_to_piece(ctx_main, id);
if (new_token_piece.size() < th_prefix.size()) {
// Shift prefix constraint (for multi step token healing)
th_prefix = th_prefix.substr(new_token_piece.size());
} else {
// Prefix has been generated => no more constrained generation
th_prefix.clear();
LOG("token_healing: done\n");
}
}
}
}
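The prefix-shifting logic in `llama_sampling_accept` above is what drives multi-step healing. A small stand-alone sketch of the same rule (hypothetical token pieces, not taken from the PR):

```cpp
// Stand-alone illustration of the prefix-shift rule used in llama_sampling_accept:
// an accepted token piece either consumes part of the healing prefix (multi-step
// healing continues) or covers all of it (constrained generation ends).
#include <cassert>
#include <string>

static std::string shift_healing_prefix(const std::string & prefix, const std::string & piece) {
    if (piece.size() < prefix.size()) {
        return prefix.substr(piece.size()); // still constrained by the remaining bytes
    }
    return ""; // prefix fully generated -> constraint lifted
}

int main() {
    // hypothetical token pieces; the real ones depend on the model's vocabulary
    assert(shift_healing_prefix(" worl", " wo") == "rl");
    assert(shift_healing_prefix("rl", "rld")    == "");
    return 0;
}
```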
39 changes: 39 additions & 0 deletions common/sampling.h
@@ -19,6 +19,19 @@ enum class llama_sampler_type : char {
TEMPERATURE = 't'
};

enum class llama_token_healing_type : uint8_t {
ROLLBACK_LAST, // roll back last token with a single constrained decoding step
ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps
DYNAMIC_ONCE, // dynamic roll back, single constrained decoding step
DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps
};

struct llama_token_healing_params {
bool enabled = false;
llama_token_healing_type type = llama_token_healing_type::DYNAMIC_MULTI;
int n_rollback = -1; // number of tokens to roll back
};

// sampling parameters
typedef struct llama_sampling_params {
int32_t n_prev = 64; // number of previous tokens to remember
@@ -62,6 +75,8 @@ typedef struct llama_sampling_params {

std::vector<llama_token> penalty_prompt_tokens;
bool use_penalty_prompt_tokens = false;

llama_token_healing_params token_healing;
} llama_sampling_params;

// general sampler context
@@ -78,6 +93,8 @@ struct llama_sampling_context {
// internal
grammar_parser::parse_state parsed_grammar;

std::string token_healing_prefix; // remaining prefix to constrain sampling

// TODO: replace with ring-buffer
std::vector<llama_token> prev;
std::vector<llama_token_data> cur;
@@ -158,3 +175,25 @@ void llama_sampling_accept(
struct llama_context * ctx_main,
llama_token id,
bool apply_grammar);

//
// Token healing
//

struct llama_token_healing_output {
std::string prefix;
int n_tokens_removed;
};

// Roll back `tokens` for constrained generation according to the token healing strategy.
// Call `llama_token_healing_set_prefix` with the returned prefix before the first sampling.
llama_token_healing_output llama_token_healing_rollback(
const llama_context * ctx_main,
std::vector<llama_token> & tokens,
llama_token_healing_type th_type,
int max_to_remove = -1);

void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);

// Helper for parsing token healing params from a string.
bool llama_token_healing_parse_params(const std::string & params, llama_token_healing_params & th_params);
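The header comments above describe the intended calling convention. A minimal sketch of how a caller such as `examples/main` might wire this up, assuming the usual `common.h`/`sampling.h` helpers and the `gpt_params`/`sparams` layout from `common.h` (illustrative, not taken verbatim from this PR):

```cpp
// Sketch only: roll back the prompt, register the healing prefix, then sample as usual.
#include "common.h"
#include "sampling.h"

static llama_token sample_first_token_with_healing(
        llama_context * ctx, llama_sampling_context * ctx_sampling,
        const gpt_params & params, std::vector<llama_token> & prompt_tokens) {
    const auto & th = params.sparams.token_healing; // assumed field layout
    if (th.enabled) {
        // Trim the tail of the prompt according to the chosen strategy...
        const llama_token_healing_output out =
            llama_token_healing_rollback(ctx, prompt_tokens, th.type, th.n_rollback);
        // ...and constrain sampling until the removed suffix has been regenerated.
        llama_token_healing_set_prefix(ctx_sampling, out.prefix);
    }
    // (decode prompt_tokens with llama_decode as usual before sampling)
    // llama_sampling_sample keeps only tokens consistent with the healing prefix,
    // and llama_sampling_accept shrinks the prefix as pieces of it are generated.
    const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr);
    llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar */ true);
    return id;
}
```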
13 changes: 13 additions & 0 deletions examples/main/README.md
@@ -251,6 +251,19 @@ A more practical use case might be to prevent the generation of `\code{begin}` a

Example usage: `--logit-bias 29905-inf`

### Token healing

- `-th {0,1,d1,d,r{N}}, --token-healing {0,1,d1,d,r{N}}`: Set the token healing strategy (default: 0, 0 = disabled).

Token healing (a.k.a. token alignment) alleviates tokenization artifacts for text completion.

- `-th 1`: Roll back the last token and constrain the next token to start with the bytes of the chopped-off token [0, 2].
- `-th d1`: Roll back the longest suffix that can still be completed by a single vocabulary token and do one constrained decoding step [2].
- `-th d`: Like `d1` but allow multiple decoding steps until the removed suffix is generated.
- `-th r{N}`: Like `d` but roll back `N` tokens, where `-th r3` is recommended [1].

Sources: [0](https://github.com/guidance-ai/guidance/blob/main/notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb), [1](https://arxiv.org/abs/2403.08688), [2](https://arxiv.org/abs/2402.01035).
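
Example usage: `-th d`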

### RNG Seed

- `-s SEED, --seed SEED`: Set the random number generator (RNG) seed (default: -1, -1 = random seed).