diff --git a/common/common.cpp b/common/common.cpp
index 243b88abf1aab..b75cfdf952365 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1288,6 +1288,28 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
         return true;
     }
+    if (arg == "-th" || arg == "--token-healing") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        sparams.token_healing_enabled = true;
+        auto & th_type       = sparams.token_healing_type;
+        auto & th_n_rollback = sparams.token_healing_n_rollback;
+        std::string value(argv[i]);
+        /**/ if (value    == "0" ) { sparams.token_healing_enabled = false; }
+        else if (value    == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
+        else if (value    == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
+        else if (value    == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
+        else if (value[0] == 'r' ) {
+            th_type = llama_token_healing_type::ROLLBACK_MULTI;
+            th_n_rollback = std::stoi(value.substr(1));
+            if (th_n_rollback <= 0) {
+                sparams.token_healing_enabled = false;
+            }
+        } else { invalid_param = true; }
+        return true;
+    }
     if (arg == "--override-kv") {
         if (++i >= argc) {
             invalid_param = true;
@@ -1480,6 +1502,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -j SCHEMA, --json-schema SCHEMA\n");
     printf("                        JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object.\n");
     printf("                        For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead\n");
+    printf("  -th {0,1,d1,d,r{N}}, --token-healing {0,1,d1,d,r{N}}\n");
+    printf("                        Token healing type. (default: 0, disabled)\n");
+    printf("                        1: replace one token, d1: replace longest suffix with one token, d: replace longest suffix, r{N}: roll back N tokens\n");
     printf("  --cfg-negative-prompt PROMPT\n");
     printf("                        negative prompt to use for guidance. (default: empty)\n");
     printf("  --cfg-negative-prompt-file FNAME\n");
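With the flag parsed as above, invoking token healing through the existing `main` example would look something like the following (model path and prompt are illustrative, mirroring the README added later in this patch):

```bash
# dynamic rollback, single constrained decoding step
./main -m ./models/phi-2/ggml-model-q4_0.gguf -p "print('Hello, worl" -th d1

# roll back exactly 3 tokens, multiple constrained decoding steps
./main -m ./models/phi-2/ggml-model-q4_0.gguf -p "print('Hello, worl" --token-healing r3
```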
diff --git a/common/sampling.cpp b/common/sampling.cpp
index cc83600d9926e..7e7bf5ea1d144 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -2,6 +2,109 @@
 #include "sampling.h"
 #include <random>
 
+//
+// Token healing (internal)
+//
+
+static bool startswith(const std::string & str, const std::string & prefix) {
+    return str.rfind(prefix, 0) != std::string::npos;
+}
+
+static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
+    const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
+        if (startswith(llama_token_to_piece(ctx_main, token_id), prefix)) {
+            return true;
+        }
+    }
+    return false;
+}
+
+static std::vector<llama_token> token_healing_find_prefix(
+        const llama_context * ctx_main,
+        const std::string & prefix,
+        const bool include_partial_prefix) {
+    // Example: prefix=" world" -> " world", " worldwide", ...
+    // If `include_partial_prefix`, include also: " w", " wo", ...
+    std::vector<llama_token> candidates;
+    const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
+        std::string token = llama_token_to_piece(ctx_main, token_id);
+        if (startswith(token, prefix) ||
+            (include_partial_prefix && startswith(prefix, token))) {
+            candidates.push_back(token_id);
+        }
+    }
+    return candidates;
+}
+
+//
+// Token healing (external)
+//
+
+std::string llama_token_healing_prepare(
+        const llama_context * ctx_main,
+        llama_token_healing_type th_type,
+        std::vector<llama_token> & tokens,
+        int max_to_remove,
+        int * n_removed) {
+    if (n_removed != nullptr) {
+        *n_removed = 0;
+    }
+    if (tokens.empty()) {
+        return "";
+    }
+
+    const llama_model * model = llama_get_model(ctx_main);
+    const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
+    const int n_ctx = tokens.size();
+    max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
+    max_to_remove = max_to_remove < 0 ? n_ctx : std::min(max_to_remove, n_ctx);
+    int removed = 0;
+    std::string prefix;
+    // Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt,
+    // and stop early if a special token is encountered.
+    while (removed < max_to_remove) {
+        const llama_token next_token_id = tokens[n_ctx - removed - 1];
+        if (llama_token_get_type(model, next_token_id) != LLAMA_TOKEN_TYPE_NORMAL) {
+            // Don't roll back e.g. <|endoftext|> (if parse_special=true in llama_tokenize)
+            break;
+        }
+        std::string new_prefix = llama_token_to_piece(ctx_main, next_token_id) + prefix;
+        if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) {
+            break;
+        }
+        removed += 1;
+        prefix = new_prefix;
+    }
+    if (removed == 0) { // E.g. if the last token is a special token
+        return "";
+    }
+    // If constrained decoding would give back the original prompt, there is no need to modify the context
+    const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
+                               th_type == llama_token_healing_type::DYNAMIC_MULTI;
+    const std::vector<llama_token> candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step);
+    LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), removed);
+    if (removed == 1 && candidates.size() == 1) {
+        LOG("token_healing: nothing to heal\n");
+        return "";
+    }
+    // Finalize outputs
+    if (n_removed != nullptr) {
+        *n_removed = removed;
+    }
+    tokens.resize(n_ctx - removed);
+    return prefix;
+}
+
+void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
+    ctx_sampling->token_healing_prefix = prefix;
+}
+
+//
+// Sampling
+//
+
 struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
     struct llama_sampling_context * result = new llama_sampling_context();
 
@@ -62,6 +165,8 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
             grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
     }
 
+    ctx->token_healing_prefix.clear();
+
     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
     ctx->cur.clear();
 }
@@ -119,7 +224,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
 }
 
 std::string llama_sampling_order_print(const llama_sampling_params & params) {
-    std::string result = "CFG -> Penalties ";
+    std::string result = "(Token healing) -> CFG -> Penalties ";
     if (params.mirostat == 0) {
         for (auto sampler_type : params.samplers_sequence) {
             const auto sampler_type_name = sampler_type_to_name_string(sampler_type);
@@ -297,8 +402,27 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
     cur.clear();
 
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    // Constrain tokens based on the remaining token healing prefix (if any)
+    const auto & th_type   = params.token_healing_type;
+    const auto & th_prefix = ctx_sampling->token_healing_prefix;
+    if (params.token_healing_enabled && !th_prefix.empty()) {
+        const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
+                                   th_type == llama_token_healing_type::DYNAMIC_MULTI;
+        std::vector<llama_token> th_candidates = token_healing_find_prefix(ctx_main, th_prefix, is_multi_step);
+
+        LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
+        for (const llama_token token_id : th_candidates) {
+            LOG(" [%6d] '%s'\n", token_id, llama_token_to_piece(ctx_main, token_id).c_str());
+        }
+
+        // N.B. We could also set token constraints by setting rejected tokens' logits to -inf
+        for (const llama_token token_id : th_candidates) {
+            cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+        }
+    } else {
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+        }
     }
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
@@ -361,4 +485,19 @@ void llama_sampling_accept(
     if (ctx_sampling->grammar != NULL && apply_grammar) {
         llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
    }
+
+    if (ctx_sampling->params.token_healing_enabled && apply_grammar) {
+        std::string & th_prefix = ctx_sampling->token_healing_prefix;
+        if (!th_prefix.empty()) {
+            const std::string new_token_piece = llama_token_to_piece(ctx_main, id);
+            if (new_token_piece.size() < th_prefix.size()) {
+                // Shift prefix constraint (for multi step token healing)
+                th_prefix = th_prefix.substr(new_token_piece.size());
+            } else {
+                // Prefix has been generated => no more constrained generation
+                th_prefix.clear();
+                LOG("token_healing: done\n");
+            }
+        }
+    }
 }
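To make the bookkeeping in `llama_sampling_accept` concrete, here is a small self-contained sketch (plain C++, no llama.cpp dependency; the token pieces are made up) of how the remaining healing prefix shrinks as constrained tokens are accepted in the multi-step modes:

```cpp
#include <cstdio>
#include <string>
#include <vector>

int main() {
    // Remaining token-healing prefix, e.g. the rolled-back prompt suffix " worl".
    std::string th_prefix = " worl";
    // Hypothetical pieces produced by successive constrained sampling steps.
    const std::vector<std::string> accepted = { " wor", "ld!" };

    for (const std::string & piece : accepted) {
        if (piece.size() < th_prefix.size()) {
            // Piece covers only part of the prefix: shift the constraint (multi-step healing).
            th_prefix = th_prefix.substr(piece.size());
        } else {
            // Prefix fully covered: constrained generation ends here.
            th_prefix.clear();
        }
        printf("accepted '%s' -> remaining prefix '%s'\n", piece.c_str(), th_prefix.c_str());
    }
    return 0;
}
```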
diff --git a/common/sampling.h b/common/sampling.h
index cf7081e3674f1..2aa7bc2bdd8b1 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -19,6 +19,13 @@ enum class llama_sampler_type : char {
     TEMPERATURE = 't'
 };
 
+enum class llama_token_healing_type : uint8_t {
+    ROLLBACK_LAST,  // roll back last token with a single constrained decoding step
+    ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps
+    DYNAMIC_ONCE,   // dynamic roll back, single constrained decoding step
+    DYNAMIC_MULTI   // dynamic roll back, multiple constrained decoding steps
+};
+
 // sampling parameters
 typedef struct llama_sampling_params {
     int32_t n_prev = 64; // number of previous tokens to remember
@@ -62,6 +69,10 @@ typedef struct llama_sampling_params {
 
     std::vector<llama_token> penalty_prompt_tokens;
     bool use_penalty_prompt_tokens = false;
+
+    llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST;
+    bool token_healing_enabled = false;
+    int  token_healing_n_rollback = -1; // number of tokens to roll back
 } llama_sampling_params;
 
 // general sampler context
@@ -78,6 +89,8 @@ struct llama_sampling_context {
     // internal
     grammar_parser::parse_state parsed_grammar;
 
+    std::string token_healing_prefix;
+
     // TODO: replace with ring-buffer
     std::vector<llama_token>      prev;
     std::vector<llama_token_data> cur;
@@ -152,3 +165,18 @@ void llama_sampling_accept(
         struct llama_context * ctx_main,
         llama_token id,
         bool apply_grammar);
+
+//
+// Token healing
+//
+
+// Roll back `tokens` for constrained generation according to the token healing
+// strategy. Returns the prefix for constrained generation.
+std::string llama_token_healing_prepare(
+        const llama_context * ctx_main,
+        llama_token_healing_type th_type,
+        std::vector<llama_token> & tokens,
+        int max_to_remove = -1,
+        int * n_removed = nullptr);
+
+void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
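A minimal sketch of how a client is expected to drive the new common API, mirroring what `examples/main/main.cpp` does further down in this patch; it assumes an already loaded `ctx` and a `gpt_params params` with the token-healing sampling parameters set, and elides the decode loop and error handling:

```cpp
#include "common.h"
#include "sampling.h"
#include "llama.h"

static void generate_with_token_healing(llama_context * ctx, gpt_params & params) {
    llama_sampling_params & sparams = params.sparams;

    std::vector<llama_token> embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);

    // 1) Roll back the prompt and remember the removed suffix (the prefix to "heal").
    int n_removed = 0;
    const std::string th_prefix = llama_token_healing_prepare(
        ctx, sparams.token_healing_type, embd_inp, sparams.token_healing_n_rollback, &n_removed);

    // 2) Hand the prefix to the sampling context; sampling stays constrained until it is consumed.
    llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
    llama_token_healing_set_prefix(ctx_sampling, th_prefix);

    // 3) Decode the shortened prompt, then run the usual sampling loop:
    //    llama_sampling_sample() restricts candidates to tokens matching th_prefix,
    //    and llama_sampling_accept() shrinks the prefix as tokens are produced.
    //    ... llama_decode(...) / llama_sampling_sample(...) / llama_sampling_accept(...) ...

    llama_sampling_free(ctx_sampling);
}
```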
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index f421769cc2f0a..11cbd8f61dc07 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -38,6 +38,7 @@ else()
     add_subdirectory(retrieval)
     add_subdirectory(save-load-state)
     add_subdirectory(simple)
+    add_subdirectory(simple-token-healing)
     add_subdirectory(passkey)
     add_subdirectory(speculative)
     add_subdirectory(lookahead)
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 5c693657c8993..70834b01a8eca 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -264,6 +264,17 @@ int main(int argc, char ** argv) {
     LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
     LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
 
+    if (sparams.token_healing_enabled && (params.instruct || params.chatml || !params.input_suffix.empty())) {
+        sparams.token_healing_enabled = false;
+        LOG("token_healing: disabled due to custom suffix\n");
+    }
+    std::string token_healing_prefix;
+    int token_healing_n_removed = 0;
+    if (!params.interactive_first && sparams.token_healing_enabled) {
+        token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
+                                                           sparams.token_healing_n_rollback, &token_healing_n_removed);
+    }
+
     // Should not run without any tokens
     if (embd_inp.empty()) {
         embd_inp.push_back(llama_token_bos(model));
@@ -283,7 +294,7 @@ int main(int argc, char ** argv) {
         std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
         LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
 
-        original_prompt_len = original_inp.size();
+        original_prompt_len = original_inp.size() - token_healing_n_removed;
         guidance_offset = (int)guidance_inp.size() - original_prompt_len;
         LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
         LOG("guidance_offset:     %s", log_tostr(guidance_offset));
@@ -499,6 +510,7 @@ int main(int argc, char ** argv) {
     int n_consumed         = 0;
     int n_session_consumed = 0;
     int n_past_guidance    = 0;
+    int n_bytes_to_skip    = 0; // to skip printing when generating token healing prefix
 
     std::vector<int>   input_tokens;  g_input_tokens  = &input_tokens;
     std::vector<int>   output_tokens; g_output_tokens = &output_tokens;
@@ -520,6 +532,7 @@ int main(int argc, char ** argv) {
     }
 
     struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
+    llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
 
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
@@ -734,7 +747,16 @@ int main(int argc, char ** argv) {
         if (input_echo && display) {
             for (auto id : embd) {
                 const std::string token_str = llama_token_to_piece(ctx, id);
-                printf("%s", token_str.c_str());
+
+                // Suppress printing while generating the token healing prefix (only for interactive mode; kinda hacky...)
+                if (n_bytes_to_skip > 0 && n_bytes_to_skip < (int)token_str.size()) {
+                    printf("%s", token_str.substr(n_bytes_to_skip).c_str());
+                    n_bytes_to_skip = 0;
+                } else if (n_bytes_to_skip > 0) {
+                    n_bytes_to_skip -= token_str.size();
+                } else {
+                    printf("%s", token_str.c_str());
+                }
 
                 if (embd.size() > 1) {
                     input_tokens.push_back(id);
@@ -813,6 +835,7 @@ int main(int argc, char ** argv) {
             }
         }
 
+        token_healing_n_removed = 0;
         if (n_past > 0 && is_interacting) {
             LOG("waiting for user input\n");
 
@@ -896,13 +919,24 @@ int main(int argc, char ** argv) {
                     embd_inp.insert(embd_inp.end(), cml_sfx.begin(), cml_sfx.end());
                 }
 
+                if (sparams.token_healing_enabled) {
+                    // Limit token healing rollback to new tokens only (otherwise we would need to shift everything)
+                    const int n_new_tokens = embd_inp.size() - original_size;
+                    const int max_to_remove = sparams.token_healing_n_rollback < 0
+                                              ? n_new_tokens
+                                              : std::min(sparams.token_healing_n_rollback, n_new_tokens);
+                    token_healing_prefix = llama_token_healing_prepare(ctx, sparams.token_healing_type, embd_inp,
+                                                                       max_to_remove, &token_healing_n_removed);
+                    n_bytes_to_skip = token_healing_prefix.size();
+                }
+
                 for (size_t i = original_size; i < embd_inp.size(); ++i) {
                     const llama_token token = embd_inp[i];
                     output_tokens.push_back(token);
                     output_ss << llama_token_to_piece(ctx, token);
                 }
 
-                n_remain -= line_inp.size();
+                n_remain -= line_inp.size() + token_healing_n_removed;
                 LOG("n_remain: %d\n", n_remain);
             } else {
                 LOG("empty line, passing control back\n");
@@ -914,6 +948,10 @@ int main(int argc, char ** argv) {
         if (n_past > 0) {
             if (is_interacting) {
                 llama_sampling_reset(ctx_sampling);
+                if (token_healing_n_removed > 0) {
+                    // Set the new prefix after an interaction
+                    llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
+                }
             }
             is_interacting = false;
         }
diff --git a/examples/simple-token-healing/CMakeLists.txt b/examples/simple-token-healing/CMakeLists.txt
new file mode 100644
index 0000000000000..1d41611dd8e76
--- /dev/null
+++ b/examples/simple-token-healing/CMakeLists.txt
@@ -0,0 +1,11 @@
+set(TARGET simple-token-healing)
+add_executable(${TARGET} simple-token-healing.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
+
+set(TARGET simple-token-healing-1)
+add_executable(${TARGET} simple-token-healing-1.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
diff --git a/examples/simple-token-healing/README.md b/examples/simple-token-healing/README.md
new file mode 100644
index 0000000000000..533c118bd8c48
--- /dev/null
+++ b/examples/simple-token-healing/README.md
@@ -0,0 +1,105 @@
+# llama.cpp/example/simple-token-healing
+
+This example extends [simple](../simple/README.md) with token healing (a.k.a. token alignment).
+
+`usage: ./simple-token-healing MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|d1|d|r[N]]`
+
+## Examples
+
+`0`: Without token healing (same as running `./simple ...`):
+```bash
+./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hel" 0
+...
+main: n_len = 32, n_ctx = 2048, n_kv_req = 32
+
+print('Helping the customer')
+...
+```
+
+`1`: Roll back the last token and constrain the bytes of the next token to start with the chopped-off last token [0, 2]:
+```bash
+./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hel" 1
+...
+token_healing: prefix = 'Hel' (1 tokens)
+ [ 12621] 'Hel'
+ [ 15496] 'Hello'
+ [ 22087] 'Help'
+ [ 28254] 'Hell'
+ [ 47429] 'Helper'
+
+main: n_len = 32, n_ctx = 2048, n_kv_req = 32
+
+print('Hello, World!')
+...
+```
+
+`d1`: Roll back multiple tokens until there doesn't exist a token which can cover the prompt's suffix, then do a single constrained decoding step [2]:
+```bash
+./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" d1
+...
+token_healing: prefix = ' worl' (2 tokens)
+ [   995] ' world'
+ [  8688] ' worldwide'
+ [ 11621] ' worlds'
+ [ 29081] ' worldview'
+ [ 43249] ' worldly'
+
+main: n_len = 32, n_ctx = 2048, n_kv_req = 32
+
+print('Hello, world!')
+...
+```
+
+`d`: Roll back multiple tokens until there doesn't exist a token which can cover the prompt's suffix, but allow multiple decoding steps:
+```bash
+./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" d
+...
+token_healing: prefix = ' worl' (2 tokens)
+
+main: n_len = 32, n_ctx = 2048, n_kv_req = 32
+
+print('Hello,
+token_healing: prefix = ' worl'
+ [   220] ' '
+ [   266] ' w'
+ [   476] ' wor'
+ [   995] ' world'
+ [  8688] ' worldwide'
+ [ 11621] ' worlds'
+ [ 24486] ' wo'
+ [ 29081] ' worldview'
+ [ 43249] ' worldly'
+ world!')
+...
+```
+
+`r[N]`: Roll back `N` tokens and constrain the decoding to the bytes of those tokens (multiple decoding steps) [1].
+The paper [1] recommends `N=3`:
+```bash
+./simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hello, worl" r3
+...
+token_healing: prefix = ', worl' (3 tokens)
+
+main: n_len = 32, n_ctx = 2048, n_kv_req = 32
+
+print('Hello
+token_healing: prefix = ', worl'
+ [    11] ','
+,
+token_healing: prefix = ' worl'
+ [   220] ' '
+ [   266] ' w'
+ [   476] ' wor'
+ [   995] ' world'
+ [  8688] ' worldwide'
+ [ 11621] ' worlds'
+ [ 24486] ' wo'
+ [ 29081] ' worldview'
+ [ 43249] ' worldly'
+ world!')
+...
+```
+
+## Sources
+- [0] https://github.com/guidance-ai/guidance/blob/main/notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb
+- [1] https://arxiv.org/abs/2403.08688
+- [2] https://arxiv.org/abs/2402.01035
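With the CMake target added above, building and running the example from a typical llama.cpp CMake build should look roughly like this (the `build` directory name and `build/bin` output path are assumptions about the local setup):

```bash
cmake -B build
cmake --build build --target simple-token-healing simple-token-healing-1
./build/bin/simple-token-healing ./models/phi-2/ggml-model-q4_0.gguf "print('Hel" d
```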
diff --git a/examples/simple-token-healing/simple-token-healing-1.cpp b/examples/simple-token-healing/simple-token-healing-1.cpp
new file mode 100644
index 0000000000000..6febeb38f6ff6
--- /dev/null
+++ b/examples/simple-token-healing/simple-token-healing-1.cpp
@@ -0,0 +1,232 @@
+#include "common.h"
+#include "llama.h"
+
+#include <cmath>
+#include <cstdio>
+#include <string>
+#include <vector>
+
+static std::vector<llama_token> heal_last_token(const llama_context * ctx, const std::vector<llama_token> & tokens_list) {
+    const llama_token last_token_id = tokens_list.back();
+    const llama_model * model = llama_get_model(ctx);
+    const int32_t n_vocab = llama_n_vocab(model);
+
+    // Don't roll back e.g. <|endoftext|> (set parse_special=true in llama_tokenize)
+    if (llama_token_get_type(model, last_token_id) != LLAMA_TOKEN_TYPE_NORMAL) {
+        return {};
+    }
+
+    const std::string last_piece = llama_token_to_piece(ctx, last_token_id);
+    fprintf(stderr, "token_healing: prefix = '%s'\n", last_piece.c_str());
+
+    fprintf(stderr, "token_healing: candidates:\n");
+    fprintf(stderr, " [%6d] '%s'\n", last_token_id, last_piece.c_str());
+    std::vector<llama_token> candidates = { last_token_id };
+    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
+        if (token_id == last_token_id) {
+            continue;
+        }
+        std::string token_piece = llama_token_to_piece(ctx, token_id);
+        if (token_piece.rfind(last_piece, 0) != std::string::npos) {
+            candidates.push_back(token_id);
+            fprintf(stderr, " [%6d] '%s'\n", token_id, token_piece.c_str());
+        }
+    }
+    if (candidates.size() == 1) {
+        // No healing necessary if the last token is the only candidate.
+        return {};
+    }
+    return candidates;
+}
+
+int main(int argc, char ** argv) {
+    gpt_params params;
+
+    if (argc == 1 || argv[1][0] == '-') {
+        printf("usage: %s MODEL_PATH [PROMPT]\n", argv[0]);
+        return 1;
+    }
+
+    if (argc >= 2) {
+        params.model = argv[1];
+    }
+
+    if (argc >= 3) {
+        params.prompt = argv[2];
+    }
+
+    if (params.prompt.empty()) {
+        params.prompt = "Hello my name is";
+    }
+
+    // total length of the sequence including the prompt
+    const int n_len = 32;
+
+    // init LLM
+
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    // initialize the model
+
+    llama_model_params model_params = llama_model_default_params();
+
+    // model_params.n_gpu_layers = 99; // offload all layers to the GPU
+
+    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
+
+    if (model == NULL) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+        return 1;
+    }
+
+    // initialize the context
+
+    llama_context_params ctx_params = llama_context_default_params();
+
+    ctx_params.seed  = 1234;
+    ctx_params.n_ctx = 2048;
+    ctx_params.n_threads = params.n_threads;
+    ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+
+    llama_context * ctx = llama_new_context_with_model(model, ctx_params);
+
+    if (ctx == NULL) {
+        fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
+        return 1;
+    }
+
+    // tokenize the prompt
+
+    std::vector<llama_token> tokens_list;
+    tokens_list = ::llama_tokenize(ctx, params.prompt, true);
+
+    // Roll back the last token and constrain the tokens generated in the next step to match the removed last token.
+    std::vector<llama_token> token_healing_candidates = heal_last_token(ctx, tokens_list);
+    if (!token_healing_candidates.empty()) {
+        tokens_list.pop_back();
+    }
+    if (tokens_list.empty()) {
+        // If we remove the first token, llama_decode would crash with an empty sequence, so add bos.
+        tokens_list.emplace_back(llama_token_bos(model));
+    }
+
+    const int n_ctx    = llama_n_ctx(ctx);
+    const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
+
+    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_kv_req);
+
+    // make sure the KV cache is big enough to hold all the prompt and generated tokens
+    if (n_kv_req > n_ctx) {
+        LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
+        LOG_TEE("%s:        either reduce n_len or increase n_ctx\n", __func__);
+        return 1;
+    }
+
+    // print the prompt token-by-token
+
+    fprintf(stderr, "\n");
+
+    for (auto id : tokens_list) {
+        fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
+    }
+
+    fflush(stderr);
+
+    // create a llama_batch with size 512
+    // we use this object to submit token data for decoding
+
+    llama_batch batch = llama_batch_init(512, 0, 1);
+
+    // evaluate the initial prompt
+    for (size_t i = 0; i < tokens_list.size(); i++) {
+        llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
+    }
+
+    // llama_decode will output logits only for the last token of the prompt
+    batch.logits[batch.n_tokens - 1] = true;
+
+    if (llama_decode(ctx, batch) != 0) {
+        LOG_TEE("%s: llama_decode() failed\n", __func__);
+        return 1;
+    }
+
+    // main loop
+
+    int n_cur    = batch.n_tokens;
+    int n_decode = 0;
+
+    const auto t_main_start = ggml_time_us();
+
+    while (n_cur <= n_len) {
+        // sample the next token
+        {
+            auto   n_vocab = llama_n_vocab(model);
+            auto * logits  = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+
+            std::vector<llama_token_data> candidates;
+            candidates.reserve(n_vocab);
+
+            if (n_decode == 0 && !token_healing_candidates.empty()) {
+                for (const llama_token token_id : token_healing_candidates) {
+                    candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+                }
+            } else {
+                for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+                    candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+                }
+            }
+
+            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+            // sample the most likely token
+            const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+
+            // is it an end of generation?
+            if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
+                LOG_TEE("\n");
+
+                break;
+            }
+
+            LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
+            fflush(stdout);
+
+            // prepare the next batch
+            llama_batch_clear(batch);
+
+            // push this new token for next evaluation
+            llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
+
+            n_decode += 1;
+        }
+
+        n_cur += 1;
+
+        // evaluate the current batch with the transformer model
+        if (llama_decode(ctx, batch)) {
+            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
+            return 1;
+        }
+    }
+
+    LOG_TEE("\n");
+
+    const auto t_main_end = ggml_time_us();
+
+    LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
+            __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
+
+    llama_print_timings(ctx);
+
+    fprintf(stderr, "\n");
+
+    llama_batch_free(batch);
+
+    llama_free(ctx);
+    llama_free_model(model);
+
+    llama_backend_free();
+
+    return 0;
+}
diff --git a/examples/simple-token-healing/simple-token-healing.cpp b/examples/simple-token-healing/simple-token-healing.cpp
new file mode 100644
index 0000000000000..05091b9c33c62
--- /dev/null
+++ b/examples/simple-token-healing/simple-token-healing.cpp
@@ -0,0 +1,358 @@
+#include "common.h"
+#include "llama.h"
+
+#include <cmath>
+#include <cstdio>
+#include <string>
+#include <vector>
+
+#define TH_VERBOSE // print token healing candidates
+
+struct token_healing_context {
+    std::string prefix; // remaining prefix to generate (the input prompt's suffix)
+
+    std::vector<std::string> vocab; // map token id to token piece
+    // TODO consider using a prefix tree
+};
+
+static bool startswith(const std::string & str, const std::string & prefix) {
+    return str.rfind(prefix, 0) != std::string::npos;
+}
+
+static bool token_healing_prefix_exists(const token_healing_context * th_ctx, const std::string & prefix) {
+    for (const std::string & token : th_ctx->vocab) {
+        if (startswith(token, prefix)) {
+            return true;
+        }
+    }
+    return false;
+}
+
+static std::vector<llama_token> token_healing_find_prefix(
+        const token_healing_context * th_ctx,
+        const std::string & prefix,
+        const bool include_partial_prefix) {
+    // Example: prefix=" world" -> " world", " worldwide", ...
+    // If `include_partial_prefix`, include also: " w", " wo", ...
+    std::vector<llama_token> candidates;
+    const auto & vocab = th_ctx->vocab;
+    for (size_t token_id = 0; token_id < vocab.size(); ++token_id) {
+        if (startswith(vocab[token_id], prefix) ||
+            (include_partial_prefix && startswith(prefix, vocab[token_id]))) {
+            candidates.push_back((llama_token)token_id);
+        }
+    }
+    return candidates;
+}
+
+static token_healing_context * token_healing_init(const llama_context * ctx) {
+    auto * th_ctx = new token_healing_context;
+    const llama_model * model = llama_get_model(ctx);
+    const int32_t n_vocab = llama_n_vocab(model);
+    std::vector<std::string> & vocab = th_ctx->vocab;
+    vocab.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        vocab.emplace_back(llama_token_to_piece(ctx, token_id, true));
+    }
+    return th_ctx;
+}
+
+static void token_healing_free(token_healing_context * th_ctx) {
+    delete th_ctx;
+}
+
+static int token_healing_heal(
+        const llama_context * ctx,
+        std::vector<llama_token> & tokens_list,
+        const llama_token_healing_type th_type,
+        token_healing_context * th_ctx,
+        int n_rollback = 1) {
+    if (tokens_list.empty()) {
+        return 0;
+    }
+    const llama_model * model = llama_get_model(ctx);
+    const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
+    const int n_ctx = tokens_list.size();
+    const int max_to_remove = is_dynamic ? n_ctx : std::min(n_rollback, n_ctx);
+    int n_removed = 0;
+    std::string prefix;
+    // Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt,
+    // and stop early if a special token is encountered.
+    while (n_removed < max_to_remove) {
+        const llama_token next_token_id = tokens_list[n_ctx - n_removed - 1];
+        if (llama_token_get_type(model, next_token_id) != LLAMA_TOKEN_TYPE_NORMAL) {
+            // Don't roll back e.g. <|endoftext|> (if parse_special=true in llama_tokenize)
+            break;
+        }
+        std::string new_prefix = th_ctx->vocab[next_token_id] + prefix;
+        if (is_dynamic && !token_healing_prefix_exists(th_ctx, new_prefix)) {
+            break;
+        }
+        n_removed += 1;
+        prefix = new_prefix;
+    }
+    th_ctx->prefix = prefix;
+
+    if (n_removed == 0) {
+        return 0;
+    }
+    // If constrained decoding would give back the original prompt, there is no need to modify the context
+    const bool is_multi_decoding = th_type == llama_token_healing_type::DYNAMIC_MULTI || th_type == llama_token_healing_type::ROLLBACK_MULTI;
+    const std::vector<llama_token> candidates = token_healing_find_prefix(th_ctx, prefix, is_multi_decoding);
+    fprintf(stderr, "token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), n_removed);
+    if (n_removed == 1 && candidates.size() == 1) {
+        fprintf(stderr, "token_healing: nothing to heal\n");
+        return 0;
+    }
+#ifdef TH_VERBOSE
+    if (!is_multi_decoding) {
+        // Other healing types get printed during decoding
+        for (const llama_token token_id : candidates) {
+            fprintf(stderr, " [%6d] '%s'\n", token_id, th_ctx->vocab[token_id].c_str());
+        }
+    }
+#endif
+    tokens_list.resize(n_ctx - n_removed);
+    if (tokens_list.empty()) {
+        // If the first token was removed, llama_decode would crash with an empty sequence, so add bos.
+        tokens_list.emplace_back(llama_token_bos(model));
+    }
+    return n_removed;
+}
+
+int main(int argc, char ** argv) {
+    gpt_params params;
+
+    if (argc == 1 || argv[1][0] == '-') {
+        printf("usage: %s MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|d1|d|r[N]]\n", argv[0]);
+        return 1;
+    }
+
+    if (argc >= 2) {
+        params.model = argv[1];
+    }
+
+    if (argc >= 3) {
+        params.prompt = argv[2];
+    }
+
+    bool token_healing_enabled = true;
+    auto th_type = llama_token_healing_type::DYNAMIC_MULTI;
+    int  th_n_rollback = 1;
+    if (argc >= 4) {
+        std::string value(argv[3]);
+        /**/ if (value    == "0" ) { token_healing_enabled = false; }
+        else if (value    == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; th_n_rollback = 1; }
+        else if (value    == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
+        else if (value    == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
+        else if (value[0] == 'r' ) {
+            th_type = llama_token_healing_type::ROLLBACK_MULTI;
+            th_n_rollback = std::stoi(value.substr(1));
+            if (th_n_rollback <= 0) {
+                token_healing_enabled = false;
+            }
+        } else {
+            printf("usage: %s MODEL_PATH [PROMPT] [TOKEN_HEALING 0|1|d1|d|r[N]]\n", argv[0]);
+            return 1;
+        }
+    }
+
+    if (params.prompt.empty()) {
+        params.prompt = "Hello my name is";
+    }
+
+    // total length of the sequence including the prompt
+    const int n_len = 32;
+
+    // init LLM
+
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    // initialize the model
+
+    llama_model_params model_params = llama_model_default_params();
+
+    // model_params.n_gpu_layers = 99; // offload all layers to the GPU
+
+    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
+
+    if (model == NULL) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+        return 1;
+    }
+
+    // initialize the context
+
+    llama_context_params ctx_params = llama_context_default_params();
+
+    ctx_params.seed  = 1234;
+    ctx_params.n_ctx = 2048;
+    ctx_params.n_threads = params.n_threads;
+    ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+
+    llama_context * ctx = llama_new_context_with_model(model, ctx_params);
+
+    if (ctx == NULL) {
+        fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
+        return 1;
+    }
+
+    // tokenize the prompt
+
+    std::vector<llama_token> tokens_list;
+    tokens_list = ::llama_tokenize(ctx, params.prompt, true);
+
+    token_healing_context * th_ctx = nullptr;
+    if (token_healing_enabled) {
+        th_ctx = token_healing_init(ctx);
+        int th_n_tokens_removed = token_healing_heal(ctx, tokens_list, th_type, th_ctx, th_n_rollback);
+        if (th_n_tokens_removed == 0) {
+            token_healing_enabled = false;
+        }
+    }
+
+    const int n_ctx    = llama_n_ctx(ctx);
+    const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
+
+    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_kv_req);
+
+    // make sure the KV cache is big enough to hold all the prompt and generated tokens
+    if (n_kv_req > n_ctx) {
+        LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
+        LOG_TEE("%s:        either reduce n_len or increase n_ctx\n", __func__);
+        return 1;
+    }
+
+    // print the prompt token-by-token
+
+    fprintf(stderr, "\n");
+
+    for (auto id : tokens_list) {
+        fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
+    }
+
+    fflush(stderr);
+
+    // create a llama_batch with size 512
+    // we use this object to submit token data for decoding
+
+    llama_batch batch = llama_batch_init(512, 0, 1);
+
+    // evaluate the initial prompt
+    for (size_t i = 0; i < tokens_list.size(); i++) {
+        llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
+    }
+
+    // llama_decode will output logits only for the last token of the prompt
+    batch.logits[batch.n_tokens - 1] = true;
+
+    if (llama_decode(ctx, batch) != 0) {
+        LOG_TEE("%s: llama_decode() failed\n", __func__);
+        return 1;
+    }
+
+    // main loop
+
+    int n_cur    = batch.n_tokens;
+    int n_decode = 0;
+
+    const auto t_main_start = ggml_time_us();
+
+    while (n_cur <= n_len) {
+        // sample the next token
+        {
+            auto   n_vocab = llama_n_vocab(model);
+            auto * logits  = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+
+            std::vector<llama_token_data> candidates;
+            candidates.reserve(n_vocab);
+
+            if (token_healing_enabled) {
+                // Constrain tokens based on the remaining token healing prefix
+                // N.B. We could also set token constraints by setting rejected tokens' logits to -inf
+                std::vector<llama_token> th_candidates;
+                if (th_type == llama_token_healing_type::ROLLBACK_LAST || th_type == llama_token_healing_type::DYNAMIC_ONCE) {
+                    th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, false);
+                } else {
+                    th_candidates = token_healing_find_prefix(th_ctx, th_ctx->prefix, true);
+#ifdef TH_VERBOSE
+                    fprintf(stderr, "\ntoken_healing: prefix = '%s'\n", th_ctx->prefix.c_str());
+                    for (const llama_token token_id : th_candidates) {
+                        fprintf(stderr, " [%6d] '%s'\n", token_id, th_ctx->vocab[token_id].c_str());
+                    }
+#endif
+                }
+                for (const llama_token token_id : th_candidates) {
+                    candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+                }
+            } else {
+                for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+                    candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+                }
+            }
+
+            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+            // sample the most likely token
+            const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+
+            // is it an end of generation?
+            if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
+                LOG_TEE("\n");
+                break;
+            }
+
+            std::string new_token_piece = llama_token_to_piece(ctx, new_token_id);
+            LOG_TEE("%s", new_token_piece.c_str());
+            fflush(stdout);
+
+            if (token_healing_enabled) {
+                if (new_token_piece.size() < th_ctx->prefix.size()) {
+                    // Shift prefix constraint (for multi step token healing)
+                    th_ctx->prefix = th_ctx->prefix.substr(new_token_piece.size());
+                } else {
+                    th_ctx->prefix.clear();
+                    token_healing_enabled = false;
+                }
+            }
+
+            // prepare the next batch
+            llama_batch_clear(batch);
+
+            // push this new token for next evaluation
+            llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
+
+            n_decode += 1;
+        }
+
+        n_cur += 1;
+
+        // evaluate the current batch with the transformer model
+        if (llama_decode(ctx, batch)) {
+            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
+            return 1;
+        }
+    }
+
+    LOG_TEE("\n");
+
+    const auto t_main_end = ggml_time_us();
+
+    LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
+            __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
+
+    llama_print_timings(ctx);
+
+    fprintf(stderr, "\n");
+
+    token_healing_free(th_ctx);
+    llama_batch_free(batch);
+
+    llama_free(ctx);
+    llama_free_model(model);
+
+    llama_backend_free();
+
+    return 0;
+}