From 6a83ecb8d1845b1c19cf397787a550b9f9956073 Mon Sep 17 00:00:00 2001 From: Nexesenex <124105151+Nexesenex@users.noreply.github.com> Date: Wed, 23 Oct 2024 12:49:35 +0200 Subject: [PATCH] Revert " llama : adds llama-grammar memoization stacks (#4218) #9833" This reverts commit 4cbf5c392af62252a69e17143e8a81d771ca6f8a. --- examples/gbnf-validator/gbnf-validator.cpp | 11 +- src/llama-grammar.cpp | 136 +++------------------ src/llama-grammar.h | 23 +--- 3 files changed, 31 insertions(+), 139 deletions(-) diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp index 17a0e27c444e8..7493af9d3aec3 100644 --- a/examples/gbnf-validator/gbnf-validator.cpp +++ b/examples/gbnf-validator/gbnf-validator.cpp @@ -11,15 +11,19 @@ static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) { const auto cpts = unicode_cpts_from_utf8(input_str); - auto & stacks_cur = llama_grammar_get_stacks(grammar); + const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); + llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); size_t pos = 0; for (const auto & cpt : cpts) { - llama_grammar_accept(grammar, cpt); + const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy + + llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur); if (stacks_cur.empty()) { error_pos = pos; error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'"; + stacks_cur = stacks_prev; return false; } ++pos; @@ -78,8 +82,7 @@ int main(int argc, char** argv) { llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root"); if (grammar == nullptr) { - fprintf(stdout, "Failed to initialize llama_grammar\n"); - return 1; + throw std::runtime_error("Failed to initialize llama_grammar"); } // Read the input file std::string input_str; diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 53f49dcbf300b..74e9f64b393b2 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -682,101 +682,6 @@ static bool llama_grammar_match_partial_char( return !is_positive_char; } -// transforms a grammar pushdown stack into N possible stacks, all ending -// at a character range (terminal element) -// additionally memoizes the stack to its possible stacks by mapping -// < llama_grammar_stack, llama_grammar_stacks > - -static void llama_grammar_advance_stack_memo( - const llama_grammar_rules & rules, - const llama_grammar_stack & stack, - llama_grammar_stacks & new_stacks, - llama_grammar_stacks_cache & stacks_cache); - -static void llama_grammar_advance_stack_memo_impl( - const llama_grammar_rules & rules, - const llama_grammar_stack & stack, - llama_grammar_stacks & new_stacks, - llama_grammar_stacks_cache & stacks_cache) { - if (stack.empty()) { - if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { - new_stacks.emplace_back(stack); - } - return; - } - - const llama_grammar_element * pos = stack.back(); - - switch (pos->type) { - case LLAMA_GRETYPE_RULE_REF: { - const size_t rule_id = static_cast(pos->value); - const llama_grammar_element * subpos = rules[rule_id].data(); - do { - // init new stack without the top (pos) - llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); - if (!llama_grammar_is_end_of_sequence(pos + 1)) { - // if this rule ref is followed by another element, add that to stack - new_stack.push_back(pos + 1); - } - if (!llama_grammar_is_end_of_sequence(subpos)) { - // if alternate is nonempty, add to stack - new_stack.push_back(subpos); - } - llama_grammar_advance_stack_memo(rules, new_stack, new_stacks, stacks_cache); - while (!llama_grammar_is_end_of_sequence(subpos)) { - // scan to end of alternate def - subpos++; - } - if (subpos->type == LLAMA_GRETYPE_ALT) { - // there's another alternate def of this rule to process - subpos++; - } else { - break; - } - } while (true); - break; - } - case LLAMA_GRETYPE_CHAR: - case LLAMA_GRETYPE_CHAR_NOT: - case LLAMA_GRETYPE_CHAR_ANY: - if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { - // only add the stack if it's not a duplicate of one we already have - new_stacks.emplace_back(stack); - } - break; - default: - // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range - // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on - // those - GGML_ABORT("fatal error"); - } -} - -static void llama_grammar_advance_stack_memo( - const llama_grammar_rules & rules, - const llama_grammar_stack & stack, - llama_grammar_stacks & new_stacks, - llama_grammar_stacks_cache & stacks_cache) { - - llama_grammar_stacks advanced_stacks; - // Look if stack is already in memory - auto it = stacks_cache.find(stack); - if (it != stacks_cache.end()) { - advanced_stacks = it->second; - } else { - // Advance stacks with memoization - llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks, stacks_cache); - stacks_cache.insert(make_pair(stack, advanced_stacks)); - } - // Add the advanced stacks to new_stacks avoiding duplicates - for (const auto & new_stack : advanced_stacks) { - if (std::find(new_stacks.begin(), new_stacks.end(), new_stack) == new_stacks.end()) { - new_stacks.emplace_back(new_stack); - } - } - -} - // transforms a grammar pushdown stack into N possible stacks, all ending // at a character range (terminal element) static void llama_grammar_advance_stack( @@ -917,11 +822,15 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) return grammar->stacks; } -void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) { - llama_grammar_stacks stacks_new; - stacks_new.reserve(grammar->stacks.size()); +void llama_grammar_accept( + const llama_grammar_rules & rules, + const llama_grammar_stacks & stacks, + const uint32_t chr, + llama_grammar_stacks & stacks_new) { + stacks_new.clear(); + stacks_new.reserve(stacks.size()); - for (const auto & stack : grammar->stacks) { + for (const auto & stack : stacks) { if (stack.empty()) { continue; } @@ -935,11 +844,9 @@ void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) { if (!llama_grammar_is_end_of_sequence(pos)) { new_stack.push_back(pos); } - llama_grammar_advance_stack_memo(grammar->rules, new_stack, stacks_new, grammar->stacks_cache); + llama_grammar_advance_stack(rules, new_stack, stacks_new); } } - - grammar->stacks = std::move(stacks_new); } llama_grammar_candidates llama_grammar_reject_candidates_for_stack( @@ -1031,7 +938,6 @@ struct llama_grammar * llama_grammar_init_impl( // loop over alternates of start rule to build initial stacks llama_grammar_stacks stacks; - llama_grammar_stacks_cache stacks_cache; pos = vec_rules[start_rule_index].data(); do { llama_grammar_stack stack; @@ -1039,7 +945,7 @@ struct llama_grammar * llama_grammar_init_impl( // if alternate is nonempty, add to stack stack.push_back(pos); } - llama_grammar_advance_stack_memo(vec_rules, stack, stacks, stacks_cache); + llama_grammar_advance_stack(vec_rules, stack, stacks); while (!llama_grammar_is_end_of_sequence(pos)) { // scan to end of alternate def pos++; @@ -1055,7 +961,7 @@ struct llama_grammar * llama_grammar_init_impl( // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), std::move(stacks_cache), {}, }; + return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; } struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { @@ -1110,7 +1016,6 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, // loop over alternates of start rule to build initial stacks llama_grammar_stacks stacks; - llama_grammar_stacks_cache stacks_cache; pos = vec_rules[start_rule_index].data(); do { llama_grammar_stack stack; @@ -1118,7 +1023,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, // if alternate is nonempty, add to stack stack.push_back(pos); } - llama_grammar_advance_stack_memo(vec_rules, stack, stacks, stacks_cache); + llama_grammar_advance_stack(vec_rules, stack, stacks); while (!llama_grammar_is_end_of_sequence(pos)) { // scan to end of alternate def pos++; @@ -1134,7 +1039,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), std::move(stacks_cache), {}, }; + return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; } void llama_grammar_free_impl(struct llama_grammar * grammar) { @@ -1146,13 +1051,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) { } struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) { - llama_grammar * result = new llama_grammar { - grammar.vocab, - grammar.rules, - grammar.stacks, - grammar.stacks_cache, - grammar.partial_utf8, - }; + llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, }; // redirect elements in stacks to point to new rules for (size_t is = 0; is < result->stacks.size(); is++) { @@ -1160,7 +1059,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) { for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) { if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) { - result->stacks[is][ie] = &result->rules[ir0][ir1]; + result->stacks[is][ie] = &result->rules[ir0][ir1]; } } } @@ -1227,8 +1126,11 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto & code_points = decoded.first; + llama_grammar_stacks stacks_new; + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { - llama_grammar_accept(&grammar, *it); + llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new); + grammar.stacks = std::move(stacks_new); } grammar.partial_utf8 = decoded.second; diff --git a/src/llama-grammar.h b/src/llama-grammar.h index 3f13fee4ff6f0..f529ce351e416 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -3,7 +3,6 @@ #include "llama-impl.h" #include -#include struct llama_vocab; @@ -59,19 +58,6 @@ using llama_grammar_rules = std::vector; using llama_grammar_stacks = std::vector; using llama_grammar_candidates = std::vector; -struct VectorPointerHash { - size_t operator()(const llama_grammar_stack & v) const { - size_t seed = v.size(); - for (const auto* ptr : v) { - seed ^= std::hash()(ptr) + 0x9e3779b9 + (seed << 6) + (seed >> 2); - } - return seed; - } -}; - -using llama_grammar_stacks_cache = std::unordered_map; - -// TODO: remove, needed for tests atm const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); @@ -79,7 +65,11 @@ const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar // be positioned at a character range (see `llama_grammar_advance_stack`), and // produces the N possible stacks if the given char is accepted at those // positions -void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr); +void llama_grammar_accept( + const llama_grammar_rules & rules, + const llama_grammar_stacks & stacks, + uint32_t chr, + llama_grammar_stacks & stacks_new); std::vector llama_grammar_reject_candidates_for_stack( const llama_grammar_rules & rules, @@ -123,9 +113,6 @@ struct llama_grammar { const llama_grammar_rules rules; // TODO: shared ptr llama_grammar_stacks stacks; - // cache N possible stacks from a stack - llama_grammar_stacks_cache stacks_cache; - // buffer for partially generated UTF-8 sequence from accepted tokens llama_partial_utf8 partial_utf8; };