diff --git a/common/common.cpp b/common/common.cpp
index ae11650b446a47..b04b82f7644364 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1306,6 +1306,28 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
         return true;
     }
+    if (arg == "-th" || arg == "--token-healing") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        sparams.token_healing_enabled = true;
+        auto & th_type       = sparams.token_healing_type;
+        auto & th_n_rollback = sparams.token_healing_n_rollback;
+        std::string value(argv[i]);
+        /**/ if (value    == "0" ) { sparams.token_healing_enabled = false; }
+        else if (value    == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
+        else if (value    == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
+        else if (value    == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
+        else if (value[0] == 'r' ) {
+            th_type = llama_token_healing_type::ROLLBACK_MULTI;
+            th_n_rollback = std::stoi(value.substr(1));
+            if (th_n_rollback <= 0) {
+                sparams.token_healing_enabled = false;
+            }
+        } else { invalid_param = true; }
+        return true;
+    }
     if (arg == "--override-kv") {
         if (++i >= argc) {
             invalid_param = true;
@@ -1503,6 +1525,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -j SCHEMA, --json-schema SCHEMA\n");
     printf("                        JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object.\n");
     printf("                        For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead\n");
+    printf("  -th {0,1,d1,d,r{N}}, --token-healing {0,1,d1,d,r{N}}\n");
+    printf("                        Token healing type. (default: 0, disabled)\n");
+    printf("                        1: replace one token, d1: replace longest suffix with one token, d: replace longest suffix, r{N}: roll back N tokens\n");
     printf("  --cfg-negative-prompt PROMPT\n");
     printf("                        negative prompt to use for guidance. (default: empty)\n");
     printf("  --cfg-negative-prompt-file FNAME\n");
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 7fc2e2158d5c4e..e98df73df0cac5 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -2,6 +2,109 @@
 #include "sampling.h"
 #include <random>
 
+//
+// Token healing (internal)
+//
+
+static bool startswith(const std::string & str, const std::string & prefix) {
+    return str.rfind(prefix, 0) != std::string::npos;
+}
+
+static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
+    const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
+        if (startswith(llama_token_to_piece(ctx_main, token_id), prefix)) {
+            return true;
+        }
+    }
+    return false;
+}
+
+static std::vector<llama_token> token_healing_find_prefix(
+        const llama_context * ctx_main,
+        const std::string & prefix,
+        const bool include_partial_prefix) {
+    // Example: prefix=" world" -> " world", " worldwide", ...
+    // If `include_partial_prefix`, include also: " w", " wo", ...
+    std::vector<llama_token> candidates;
+    const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
+        std::string token = llama_token_to_piece(ctx_main, token_id);
+        if (startswith(token, prefix) ||
+            (include_partial_prefix && startswith(prefix, token))) {
+            candidates.push_back(token_id);
+        }
+    }
+    return candidates;
+}
+
+//
+// Token healing (external)
+//
+
+std::string llama_token_healing_rollback(
+        const llama_context * ctx_main,
+        llama_token_healing_type th_type,
+        std::vector<llama_token> & tokens,
+        int max_to_remove,
+        int * n_removed) {
+    if (n_removed != nullptr) {
+        *n_removed = 0;
+    }
+    if (tokens.empty()) {
+        return "";
+    }
+
+    const llama_model * model = llama_get_model(ctx_main);
+    const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
+    const int n_ctx = tokens.size();
+    max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
+    max_to_remove = max_to_remove < 0 ? n_ctx : std::min(max_to_remove, n_ctx);
+    int removed = 0;
+    std::string prefix;
+    // Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt
+    // and stop early if a special token is encountered
+    while (removed < max_to_remove) {
+        const llama_token next_token_id = tokens[n_ctx - removed - 1];
+        if (llama_token_get_type(model, next_token_id) != LLAMA_TOKEN_TYPE_NORMAL) {
+            // Don't roll back e.g. <|endoftext|> (if parse_special=true in llama_tokenize)
+            break;
+        }
+        std::string new_prefix = llama_token_to_piece(ctx_main, next_token_id) + prefix;
+        if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) {
+            break;
+        }
+        removed += 1;
+        prefix = new_prefix;
+    }
+    if (removed == 0) { // E.g. if the last token is a special token
+        return "";
+    }
+    // If constrained decoding would give back the original prompt, there is no need to modify the context
+    const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
+                               th_type == llama_token_healing_type::DYNAMIC_MULTI;
+    const std::vector<llama_token> candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step);
+    LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), removed);
+    if (removed == 1 && candidates.size() == 1) {
+        LOG("token_healing: nothing to heal\n");
+        return "";
+    }
+    // Finalize outputs
+    if (n_removed != nullptr) {
+        *n_removed = removed;
+    }
+    tokens.resize(n_ctx - removed);
+    return prefix;
+}
+
+void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
+    ctx_sampling->token_healing_prefix = prefix;
+}
+
+//
+// Sampling
+//
+
 struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
     struct llama_sampling_context * result = new llama_sampling_context();
 
@@ -64,6 +167,8 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
             grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
     }
 
+    ctx->token_healing_prefix.clear();
+
     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
     ctx->cur.clear();
     ctx->n_valid = 0;
@@ -122,7 +227,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
 }
 
 std::string llama_sampling_order_print(const llama_sampling_params & params) {
-    std::string result = "CFG -> Penalties ";
+    std::string result = "(Token healing) -> CFG -> Penalties ";
     if (params.mirostat == 0) {
         for (auto sampler_type : params.samplers_sequence) {
             const auto sampler_type_name = sampler_type_to_name_string(sampler_type);
@@ -303,8 +408,27 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
     cur.clear();
 
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    // Constrain tokens based on the remaining token healing prefix (if any)
+    const auto & th_type = params.token_healing_type;
+    const auto & th_prefix = ctx_sampling->token_healing_prefix;
+    if (params.token_healing_enabled && !th_prefix.empty()) {
+        const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
+                                   th_type == llama_token_healing_type::DYNAMIC_MULTI;
+        std::vector<llama_token> th_candidates = token_healing_find_prefix(ctx_main, th_prefix, is_multi_step);
+
+        LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
+        for (const llama_token token_id : th_candidates) {
+            LOG(" [%6d] '%s'\n", token_id, llama_token_to_piece(ctx_main, token_id).c_str());
+        }
+
+        // N.B. We could also set token constraints by setting rejected tokens' logits to -inf
+        for (const llama_token token_id: th_candidates) {
+            cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+        }
+    } else {
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+        }
     }
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
@@ -367,4 +491,19 @@ void llama_sampling_accept(
     if (ctx_sampling->grammar != NULL && apply_grammar) {
         llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
     }
+
+    if (ctx_sampling->params.token_healing_enabled && apply_grammar) {
+        std::string & th_prefix = ctx_sampling->token_healing_prefix;
+        if (!th_prefix.empty()) {
+            const std::string new_token_piece = llama_token_to_piece(ctx_main, id);
+            if (new_token_piece.size() < th_prefix.size()) {
+                // Shift prefix constraint (for multi step token healing)
+                th_prefix = th_prefix.substr(new_token_piece.size());
+            } else {
+                // Prefix has been generated => no more constrained generation
+                th_prefix.clear();
+                LOG("token_healing: done\n");
+            }
+        }
+    }
 }
diff --git a/common/sampling.h b/common/sampling.h
index 655732ad17206f..005c865ca8428f 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -19,6 +19,13 @@ enum class llama_sampler_type : char {
     TEMPERATURE = 't'
 };
 
+enum class llama_token_healing_type : uint8_t {
+    ROLLBACK_LAST,  // roll back last token with a single constrained decoding step
+    ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps
+    DYNAMIC_ONCE,   // dynamic roll back, single constrained decoding step
+    DYNAMIC_MULTI   // dynamic roll back, multiple constrained decoding steps
+};
+
 // sampling parameters
 typedef struct llama_sampling_params {
     int32_t n_prev = 64; // number of previous tokens to remember
@@ -62,6 +69,10 @@ typedef struct llama_sampling_params {
 
     std::vector<llama_token> penalty_prompt_tokens;
     bool use_penalty_prompt_tokens = false;
+
+    llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST;
+    bool token_healing_enabled = false;
+    int token_healing_n_rollback = -1; // number of tokens to roll back
 } llama_sampling_params;
 
 // general sampler context
@@ -78,6 +89,8 @@ struct llama_sampling_context {
     // internal
     grammar_parser::parse_state parsed_grammar;
 
+    std::string token_healing_prefix;
+
     // TODO: replace with ring-buffer
     std::vector<llama_token>      prev;
     std::vector<llama_token_data> cur;
@@ -153,3 +166,18 @@ void llama_sampling_accept(
         struct llama_context * ctx_main,
         llama_token id,
         bool apply_grammar);
+
+//
+// Token healing
+//
+
+// Roll back `tokens` for constrained generation according to the token healing
+// strategy. Returns the prefix for constrained generation.
+std::string llama_token_healing_rollback(
+        const llama_context * ctx_main,
+        llama_token_healing_type th_type,
+        std::vector<llama_token> & tokens,
+        int max_to_remove = -1,
+        int * n_removed = nullptr);
+
+void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
diff --git a/examples/main/README.md b/examples/main/README.md
index ee930f4e79a0d8..db51b148a9d05b 100644
--- a/examples/main/README.md
+++ b/examples/main/README.md
@@ -259,6 +259,19 @@ A more practical use case might be to prevent the generation of `\code{begin}` a
 
 Example usage: `--logit-bias 29905-inf`
 
+### Token healing
+
+- `-th {0,1,d1,d,r{N}}, --token-healing {0,1,d1,d,r{N}}`: Set the token healing strategy (default: 0, 0 = disabled).
+
+Token healing (a.k.a. token alignment) alleviates tokenization artifacts for text completion.
+
+- `-th 1`: Roll back the last token and constrain the bytes of the next token to start with the chopped off last token [0, 2].
+- `-th d1`: Roll back multiple tokens until there doesn't exist a token which can cover the prompt's suffix and do a single constrained decoding step [2].
+- `-th d`: Like `d1` but allow multiple decoding steps until the removed suffix is generated.
+- `-th r{N}`: Like `d` but roll back `N` tokens, where `-th r3` is recommended [1].
+
+Sources: [0](https://github.com/guidance-ai/guidance/blob/main/notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb), [1](https://arxiv.org/abs/2403.08688), [2](https://arxiv.org/abs/2402.01035).
+
 ### RNG Seed
 
 - `-s SEED, --seed SEED`: Set the random number generator (RNG) seed (default: -1, -1 = random seed).
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 832b51ee086bec..8e8b8355a5c825 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -264,6 +264,17 @@ int main(int argc, char ** argv) {
     LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
     LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
 
+    if (sparams.token_healing_enabled && (params.instruct || params.chatml || params.conversation || !params.input_suffix.empty())) {
+        sparams.token_healing_enabled = false;
+        LOG("token_healing: disabled due to custom suffix/conversation mode");
+    }
+    std::string token_healing_prefix;
+    int token_healing_n_removed = 0;
+    if (!params.interactive_first && sparams.token_healing_enabled) {
+        token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
+                                                            sparams.token_healing_n_rollback, &token_healing_n_removed);
+    }
+
     // Should not run without any tokens
     if (embd_inp.empty()) {
         embd_inp.push_back(llama_token_bos(model));
@@ -283,7 +294,7 @@ int main(int argc, char ** argv) {
         std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
         LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
 
-        original_prompt_len = original_inp.size();
+        original_prompt_len = original_inp.size() - token_healing_n_removed;
         guidance_offset = (int)guidance_inp.size() - original_prompt_len;
         LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
         LOG("guidance_offset: %s", log_tostr(guidance_offset));
@@ -502,6 +513,7 @@ int main(int argc, char ** argv) {
     int n_consumed = 0;
    int n_session_consumed = 0;
     int n_past_guidance = 0;
+    int n_bytes_to_skip = 0; // to skip printing when generating token healing prefix
 
     std::vector<int> input_tokens;  g_input_tokens  = &input_tokens;
     std::vector<int> output_tokens; g_output_tokens = &output_tokens;
@@ -527,6 +539,7 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);
     }
+    llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
 
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
@@ -741,7 +754,16 @@ int main(int argc, char ** argv) {
         if (input_echo && display) {
             for (auto id : embd) {
                 const std::string token_str = llama_token_to_piece(ctx, id, !params.conversation);
-                printf("%s", token_str.c_str());
+
+                // Suppress printing while generating token healing prefix
+                if (n_bytes_to_skip > 0 && n_bytes_to_skip < (int)token_str.size()) {
+                    printf("%s", token_str.substr(n_bytes_to_skip).c_str());
+                    n_bytes_to_skip = 0;
+                } else if (n_bytes_to_skip > 0) {
+                    n_bytes_to_skip -= token_str.size();
+                } else {
printf("%s", token_str.c_str()); + } if (embd.size() > 1) { input_tokens.push_back(id); @@ -820,6 +842,7 @@ int main(int argc, char ** argv) { } } + token_healing_n_removed = 0; if (n_past > 0 && is_interacting) { LOG("waiting for user input\n"); @@ -903,13 +926,24 @@ int main(int argc, char ** argv) { embd_inp.insert(embd_inp.end(), cml_sfx.begin(), cml_sfx.end()); } + if (sparams.token_healing_enabled) { + // Limit token healing rollback to new tokens only (otherwise would need to shift everything) + const int n_new_tokens = embd_inp.size() - original_size; + const int max_to_remove = sparams.token_healing_n_rollback < 0 + ? n_new_tokens + : std::min(sparams.token_healing_n_rollback, n_new_tokens); + token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp, + max_to_remove, &token_healing_n_removed); + n_bytes_to_skip = token_healing_prefix.size(); + } + for (size_t i = original_size; i < embd_inp.size(); ++i) { const llama_token token = embd_inp[i]; output_tokens.push_back(token); output_ss << llama_token_to_piece(ctx, token); } - n_remain -= line_inp.size(); + n_remain -= line_inp.size() + token_healing_n_removed; LOG("n_remain: %d\n", n_remain); } else { LOG("empty line, passing control back\n"); @@ -921,6 +955,10 @@ int main(int argc, char ** argv) { if (n_past > 0) { if (is_interacting) { llama_sampling_reset(ctx_sampling); + if (token_healing_n_removed > 0) { + // Set new prefix after an interaction + llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix); + } } is_interacting = false; }