
Commit

main : add token healing
mare5x committed May 22, 2024
1 parent b18532a commit 8c44086
Showing 5 changed files with 249 additions and 6 deletions.
25 changes: 25 additions & 0 deletions common/common.cpp
@@ -1306,6 +1306,28 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
return true;
}
if (arg == "-th" || arg == "--token-healing") {
if (++i >= argc) {
invalid_param = true;
return true;
}
sparams.token_healing_enabled = true;
auto & th_type = sparams.token_healing_type;
auto & th_n_rollback = sparams.token_healing_n_rollback;
std::string value(argv[i]);
/**/ if (value == "0" ) { sparams.token_healing_enabled = false; }
else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
else if (value[0] == 'r' ) {
th_type = llama_token_healing_type::ROLLBACK_MULTI;
th_n_rollback = std::stoi(value.substr(1));
if (th_n_rollback <= 0) {
sparams.token_healing_enabled = false;
}
} else { invalid_param = true; }
return true;
}
if (arg == "--override-kv") {
if (++i >= argc) {
invalid_param = true;
@@ -1503,6 +1525,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -j SCHEMA, --json-schema SCHEMA\n");
printf(" JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object.\n");
printf(" For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead\n");
printf(" -th {0,1,d1,d,r{N}}, --token-healing {0,1,d1,d,r{N}}\n");
printf(" Token healing type. (default: 0, disabled)\n");
printf(" 1: replace one token, d1: replace longest suffix with one token, d: replace longest suffix, r{N}: roll back N tokens\n");
printf(" --cfg-negative-prompt PROMPT\n");
printf(" negative prompt to use for guidance. (default: empty)\n");
printf(" --cfg-negative-prompt-file FNAME\n");
145 changes: 142 additions & 3 deletions common/sampling.cpp
@@ -2,6 +2,109 @@
#include "sampling.h"
#include <random>

//
// Token healing (internal)
//

static bool startswith(const std::string & str, const std::string & prefix) {
return str.rfind(prefix, 0) != std::string::npos;
}

static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
if (startswith(llama_token_to_piece(ctx_main, token_id), prefix)) {
return true;
}
}
return false;
}

static std::vector<llama_token> token_healing_find_prefix(
const llama_context * ctx_main,
const std::string & prefix,
const bool include_partial_prefix) {
// Example: prefix=" world" -> " world", " worldwide", ...
// If `include_partial_prefix`, include also: " w", " wo", ...
std::vector<llama_token> candidates;
const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
std::string token = llama_token_to_piece(ctx_main, token_id);
if (startswith(token, prefix) ||
(include_partial_prefix && startswith(prefix, token))) {
candidates.push_back(token_id);
}
}
return candidates;
}

//
// Token healing (external)
//

std::string llama_token_healing_rollback(
const llama_context * ctx_main,
llama_token_healing_type th_type,
std::vector<llama_token> & tokens,
int max_to_remove,
int * n_removed) {
if (n_removed != nullptr) {
*n_removed = 0;
}
if (tokens.empty()) {
return "";
}

const llama_model * model = llama_get_model(ctx_main);
const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
const int n_ctx = tokens.size();
max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
max_to_remove = max_to_remove < 0 ? n_ctx : std::min(max_to_remove, n_ctx);
int removed = 0;
std::string prefix;
// Roll back tokens by a fixed amount or, in dynamic mode, until no single token can cover the removed prompt suffix;
// stop early if a special token is encountered
while (removed < max_to_remove) {
const llama_token next_token_id = tokens[n_ctx - removed - 1];
if (llama_token_get_type(model, next_token_id) != LLAMA_TOKEN_TYPE_NORMAL) {
// Don't roll back e.g. <|endoftext|> (if parse_special=true in llama_tokenize)
break;
}
std::string new_prefix = llama_token_to_piece(ctx_main, next_token_id) + prefix;
if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) {
break;
}
removed += 1;
prefix = new_prefix;
}
if (removed == 0) { // E.g. if the last token is a special token
return "";
}
// If constrained decoding would give back the original prompt, there is no need to modify the context
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
th_type == llama_token_healing_type::DYNAMIC_MULTI;
const std::vector<llama_token> candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step);
LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), removed);
if (removed == 1 && candidates.size() == 1) {
LOG("token_healing: nothing to heal\n");
return "";
}
// Finalize outputs
if (n_removed != nullptr) {
*n_removed = removed;
}
tokens.resize(n_ctx - removed);
return prefix;
}

void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
ctx_sampling->token_healing_prefix = prefix;
}

//
// Sampling
//

struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
struct llama_sampling_context * result = new llama_sampling_context();

@@ -64,6 +167,8 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
}

ctx->token_healing_prefix.clear();

std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
ctx->cur.clear();
ctx->n_valid = 0;
@@ -122,7 +227,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
}

std::string llama_sampling_order_print(const llama_sampling_params & params) {
std::string result = "CFG -> Penalties ";
std::string result = "(Token healing) -> CFG -> Penalties ";
if (params.mirostat == 0) {
for (auto sampler_type : params.samplers_sequence) {
const auto sampler_type_name = sampler_type_to_name_string(sampler_type);
@@ -303,8 +408,27 @@ static llama_token_data_array llama_sampling_prepare_impl(

cur.clear();

for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
// Constrain tokens based on the remaining token healing prefix (if any)
const auto & th_type = params.token_healing_type;
const auto & th_prefix = ctx_sampling->token_healing_prefix;
if (params.token_healing_enabled && !th_prefix.empty()) {
const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
th_type == llama_token_healing_type::DYNAMIC_MULTI;
std::vector<llama_token> th_candidates = token_healing_find_prefix(ctx_main, th_prefix, is_multi_step);

LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
for (const llama_token token_id : th_candidates) {
LOG(" [%6d] '%s'\n", token_id, llama_token_to_piece(ctx_main, token_id).c_str());
}

// N.B. We could also set token constraints by setting rejected tokens' logits to -inf
for (const llama_token token_id: th_candidates) {
cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
}
} else {
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
}
}

llama_token_data_array cur_p = { cur.data(), cur.size(), false };
@@ -367,4 +491,19 @@ void llama_sampling_accept(
if (ctx_sampling->grammar != NULL && apply_grammar) {
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
}

if (ctx_sampling->params.token_healing_enabled && apply_grammar) {
std::string & th_prefix = ctx_sampling->token_healing_prefix;
if (!th_prefix.empty()) {
const std::string new_token_piece = llama_token_to_piece(ctx_main, id);
if (new_token_piece.size() < th_prefix.size()) {
// Shift prefix constraint (for multi step token healing)
th_prefix = th_prefix.substr(new_token_piece.size());
} else {
// Prefix has been generated => no more constrained generation
th_prefix.clear();
LOG("token_healing: done\n");
}
}
}
}
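
As an aside, the `N.B.` comment in `llama_sampling_prepare_impl` above mentions that the same constraint could be applied by masking logits instead of shrinking the candidate list. A hypothetical sketch of that variant (not what this commit does), dropped in at the same spot in the function:

```cpp
// Hypothetical alternative (requires <algorithm> and <cmath>): keep the full
// vocabulary and push rejected tokens' logits to -inf instead of omitting them.
// th_candidates is sorted because token_healing_find_prefix scans ids in order.
for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
    const bool allowed = std::binary_search(th_candidates.begin(), th_candidates.end(), token_id);
    cur.emplace_back(llama_token_data{ token_id, allowed ? logits[token_id] : -INFINITY, 0.0f });
}
```
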
28 changes: 28 additions & 0 deletions common/sampling.h
@@ -19,6 +19,13 @@ enum class llama_sampler_type : char {
TEMPERATURE = 't'
};

enum class llama_token_healing_type : uint8_t {
ROLLBACK_LAST, // roll back last token with a single constrained decoding step
ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps
DYNAMIC_ONCE, // dynamic roll back, single constrained decoding step
DYNAMIC_MULTI // dynamic roll back, multiple constrained decoding steps
};

// sampling parameters
typedef struct llama_sampling_params {
int32_t n_prev = 64; // number of previous tokens to remember
@@ -62,6 +69,10 @@ typedef struct llama_sampling_params {

std::vector<llama_token> penalty_prompt_tokens;
bool use_penalty_prompt_tokens = false;

llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST;
bool token_healing_enabled = false;
int token_healing_n_rollback = -1; // number of tokens to roll back
} llama_sampling_params;

// general sampler context
@@ -78,6 +89,8 @@ struct llama_sampling_context {
// internal
grammar_parser::parse_state parsed_grammar;

std::string token_healing_prefix;

// TODO: replace with ring-buffer
std::vector<llama_token> prev;
std::vector<llama_token_data> cur;
@@ -153,3 +166,18 @@ void llama_sampling_accept(
struct llama_context * ctx_main,
llama_token id,
bool apply_grammar);

//
// Token healing
//

// Roll back `tokens` for constrained generation according to the token healing
// strategy. Returns the prefix for constrained generation.
std::string llama_token_healing_rollback(
const llama_context * ctx_main,
llama_token_healing_type th_type,
std::vector<llama_token> & tokens,
int max_to_remove = -1,
int * n_removed = nullptr);

void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
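
For reference, here is a minimal usage sketch of the new token healing API (not part of this commit's diff), mirroring how `examples/main/main.cpp` wires it up; a loaded `ctx` and the usual `gpt_params`/`sparams` setup are assumed:

```cpp
#include "common.h"
#include "sampling.h"

// Sketch only: roll back the tokenized prompt, then seed the sampler with the
// returned prefix so subsequent sampling is constrained to complete it.
static void token_healing_example(llama_context * ctx, const gpt_params & params) {
    const llama_sampling_params & sparams = params.sparams;

    std::vector<llama_token> embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);

    int n_removed = 0;
    std::string th_prefix = llama_token_healing_rollback(
        ctx, sparams.token_healing_type, embd_inp,
        sparams.token_healing_n_rollback, &n_removed); // -1 = no fixed rollback limit

    llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
    llama_token_healing_set_prefix(ctx_sampling, th_prefix);

    // ...decode embd_inp and sample as usual: llama_sampling_prepare() restricts
    // the candidates to tokens consistent with th_prefix, and llama_sampling_accept()
    // shrinks the prefix as constrained tokens are accepted.
}
```
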
13 changes: 13 additions & 0 deletions examples/main/README.md
@@ -259,6 +259,19 @@ A more practical use case might be to prevent the generation of `\code{begin}` a

Example usage: `--logit-bias 29905-inf`

### Token healing

- `-th {0,1,d1,d,r{N}}, --token-healing {0,1,d1,d,r{N}}`: Set the token healing strategy (default: 0, 0 = disabled).

Token healing (a.k.a. token alignment) alleviates tokenization artifacts for text completion.

- `-th 1`: Roll back the last token and constrain the next token to start with the bytes of the removed token [0, 2].
- `-th d1`: Roll back multiple tokens, as long as a single token can still cover the prompt's suffix, then do one constrained decoding step [2].
- `-th d`: Like `d1`, but allow multiple constrained decoding steps until the removed suffix has been regenerated.
- `-th r{N}`: Like `d`, but roll back `N` tokens; `-th r3` is recommended [1].
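
For example, if the prompt ends mid-word, e.g. `The sky is bl`, the tokenizer may end on a piece such as ` bl`, a token boundary the model will rarely have seen in training. With `-th 1` that last token is rolled back and the next token is constrained to pieces starting with ` bl` (for instance ` blue`), healing the boundary instead of continuing it awkwardly.

Example usage: `-th 1`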

Sources: [0](https://github.com/guidance-ai/guidance/blob/main/notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb), [1](https://arxiv.org/abs/2403.08688), [2](https://arxiv.org/abs/2402.01035).

### RNG Seed

- `-s SEED, --seed SEED`: Set the random number generator (RNG) seed (default: -1, -1 = random seed).
44 changes: 41 additions & 3 deletions examples/main/main.cpp
@@ -264,6 +264,17 @@ int main(int argc, char ** argv) {
LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());

if (sparams.token_healing_enabled && (params.instruct || params.chatml || params.conversation || !params.input_suffix.empty())) {
sparams.token_healing_enabled = false;
LOG("token_healing: disabled due to custom suffix/conversation mode");
}
std::string token_healing_prefix;
int token_healing_n_removed = 0;
if (!params.interactive_first && sparams.token_healing_enabled) {
token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
sparams.token_healing_n_rollback, &token_healing_n_removed);
}

// Should not run without any tokens
if (embd_inp.empty()) {
embd_inp.push_back(llama_token_bos(model));
@@ -283,7 +294,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());

original_prompt_len = original_inp.size();
original_prompt_len = original_inp.size() - token_healing_n_removed;
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
LOG("guidance_offset: %s", log_tostr(guidance_offset));
@@ -502,6 +513,7 @@ int main(int argc, char ** argv) {
int n_consumed = 0;
int n_session_consumed = 0;
int n_past_guidance = 0;
int n_bytes_to_skip = 0; // to skip printing when generating token healing prefix

std::vector<int> input_tokens; g_input_tokens = &input_tokens;
std::vector<int> output_tokens; g_output_tokens = &output_tokens;
@@ -527,6 +539,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
exit(1);
}
llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);

while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
// predict
@@ -741,7 +754,16 @@ int main(int argc, char ** argv) {
if (input_echo && display) {
for (auto id : embd) {
const std::string token_str = llama_token_to_piece(ctx, id, !params.conversation);
printf("%s", token_str.c_str());

// Suppress printing while generating token healing prefix
if (n_bytes_to_skip > 0 && n_bytes_to_skip < (int)token_str.size()) {
printf("%s", token_str.substr(n_bytes_to_skip).c_str());
n_bytes_to_skip = 0;
} else if (n_bytes_to_skip > 0) {
n_bytes_to_skip -= token_str.size();
} else {
printf("%s", token_str.c_str());
}

if (embd.size() > 1) {
input_tokens.push_back(id);
@@ -820,6 +842,7 @@ int main(int argc, char ** argv) {
}
}

token_healing_n_removed = 0;
if (n_past > 0 && is_interacting) {
LOG("waiting for user input\n");

@@ -903,13 +926,24 @@ int main(int argc, char ** argv) {
embd_inp.insert(embd_inp.end(), cml_sfx.begin(), cml_sfx.end());
}

if (sparams.token_healing_enabled) {
// Limit token healing rollback to new tokens only (otherwise would need to shift everything)
const int n_new_tokens = embd_inp.size() - original_size;
const int max_to_remove = sparams.token_healing_n_rollback < 0
? n_new_tokens
: std::min(sparams.token_healing_n_rollback, n_new_tokens);
token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
max_to_remove, &token_healing_n_removed);
n_bytes_to_skip = token_healing_prefix.size();
}

for (size_t i = original_size; i < embd_inp.size(); ++i) {
const llama_token token = embd_inp[i];
output_tokens.push_back(token);
output_ss << llama_token_to_piece(ctx, token);
}

n_remain -= line_inp.size();
n_remain -= line_inp.size() + token_healing_n_removed;
LOG("n_remain: %d\n", n_remain);
} else {
LOG("empty line, passing control back\n");
@@ -921,6 +955,10 @@ int main(int argc, char ** argv) {
if (n_past > 0) {
if (is_interacting) {
llama_sampling_reset(ctx_sampling);
if (token_healing_n_removed > 0) {
// Set new prefix after an interaction
llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
}
}
is_interacting = false;
}
