From 5adba1cf76892e3dba1185cc6dfe39dc412f9385 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 17 Aug 2024 10:40:52 +0300 Subject: [PATCH] grammar : clean-up ggml-ci --- examples/gbnf-validator/gbnf-validator.cpp | 10 ++--- src/llama-grammar.cpp | 43 ++++++++++++---------- src/llama-grammar.h | 14 +++---- src/llama-sampling.cpp | 4 +- tests/test-grammar-integration.cpp | 10 +---- tests/test-llama-grammar.cpp | 10 ++--- 6 files changed, 43 insertions(+), 48 deletions(-) diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp index 9467fe4e2898e..f4868a55741bd 100644 --- a/examples/gbnf-validator/gbnf-validator.cpp +++ b/examples/gbnf-validator/gbnf-validator.cpp @@ -1,6 +1,5 @@ #include "ggml.h" #include "llama.h" -#include "llama-vocab.h" // TMP #include "llama-grammar.h" #include "unicode.h" @@ -11,7 +10,7 @@ #include #include -static bool llama_sample_grammar_string(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) { +static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) { auto decoded = decode_utf8(input_str, {}); const auto & code_points = decoded.first; @@ -22,7 +21,7 @@ static bool llama_sample_grammar_string(struct llama_grammar * grammar, const st for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy - llama_grammar_accept(rules, prev_stacks, *it, cur_stacks); + cur_stacks = llama_grammar_accept(rules, prev_stacks, *it); if (cur_stacks.empty()) { error_pos = pos; @@ -84,8 +83,7 @@ int main(int argc, char** argv) { grammar_str = buffer.str(); } - llama_vocab vocab; // TMP - llama_grammar * grammar = llama_grammar_init_impl(vocab, grammar_str.c_str(), "root"); + llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root"); if (grammar == nullptr) { throw std::runtime_error("Failed to initialize llama_grammar"); } @@ -102,7 +100,7 @@ int main(int argc, char** argv) { // Validate the input string against the grammar size_t error_pos; std::string error_msg; - bool is_valid = llama_sample_grammar_string(grammar, input_str, error_pos, error_msg); + bool is_valid = llama_grammar_validate(grammar, input_str, error_pos, error_msg); if (is_valid) { fprintf(stdout, "Input string is valid according to the grammar.\n"); diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 5c60246d90ca2..687aef1c7f8f4 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -754,10 +754,10 @@ static bool llama_grammar_detect_left_recursion( } struct llama_grammar * llama_grammar_init_impl( - const struct llama_vocab & vocab, + const struct llama_vocab * vocab, const llama_grammar_element ** rules, - size_t n_rules, - size_t start_rule_index) { + size_t n_rules, + size_t start_rule_index) { const llama_grammar_element * pos; // copy rule definitions into vectors @@ -808,10 +808,10 @@ struct llama_grammar * llama_grammar_init_impl( // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar{ vocab, std::move(vec_rules), std::move(stacks), {}, 0, 0, 0 }; + return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, 0, 0, 0 }; } -struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) { +struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { llama_grammar_parser parser; // if there is a grammar, parse it @@ -886,7 +886,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab & vocab, // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar{ vocab, std::move(vec_rules), std::move(stacks), {}, 0, 0, 0 }; + return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, 0, 0, 0 }; } void llama_grammar_free_impl(struct llama_grammar * grammar) { @@ -894,7 +894,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) { } struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & grammar) { - llama_grammar * result = new llama_grammar{ grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, 0, 0, 0 }; + llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, 0, 0, 0 }; // redirect elements in stacks to point to new rules for (size_t is = 0; is < result->stacks.size(); is++) { @@ -913,6 +913,8 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & gram } void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * candidates) { + GGML_ASSERT(grammar.vocab != nullptr); + bool allow_eog = false; for (const auto & stack : grammar.stacks) { if (stack.empty()) { @@ -929,9 +931,9 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ for (size_t i = 0; i < candidates->size; ++i) { const llama_token id = candidates->data[i].id; - const std::string & piece = grammar.vocab.cache_token_to_piece.at(id); + const std::string & piece = grammar.vocab->cache_token_to_piece.at(id); - if (llama_token_is_eog_impl(grammar.vocab, id)) { + if (llama_token_is_eog_impl(*grammar.vocab, id)) { if (!allow_eog) { candidates->data[i].logit = -INFINITY; } @@ -950,7 +952,9 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ } void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) { - if (llama_token_is_eog_impl(grammar.vocab, token)) { + GGML_ASSERT(grammar.vocab != nullptr); + + if (llama_token_is_eog_impl(*grammar.vocab, token)) { for (const auto & stack : grammar.stacks) { if (stack.empty()) { return; @@ -959,16 +963,15 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token GGML_ABORT("fatal error"); } - const std::string & piece = grammar.vocab.cache_token_to_piece.at(token); + const std::string & piece = grammar.vocab->cache_token_to_piece.at(token); // Note terminating 0 in decoded string const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto & code_points = decoded.first; - llama_grammar_stacks tmp_new_stacks; for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { - llama_grammar_accept(grammar.rules, grammar.stacks, *it, tmp_new_stacks); - grammar.stacks = tmp_new_stacks; + llama_grammar_stacks new_stacks = llama_grammar_accept(grammar.rules, grammar.stacks, *it); + grammar.stacks = std::move(new_stacks); } grammar.partial_utf8 = decoded.second; @@ -1045,12 +1048,12 @@ std::pair, llama_partial_utf8> decode_utf8( return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain }); } -void llama_grammar_accept( +llama_grammar_stacks llama_grammar_accept( const llama_grammar_rules & rules, const llama_grammar_stacks & stacks, - const uint32_t chr, - llama_grammar_stacks & new_stacks) { - new_stacks.clear(); + const uint32_t chr) { + llama_grammar_stacks result; + result.reserve(stacks.size()); for (const auto & stack : stacks) { if (stack.empty()) { @@ -1066,9 +1069,11 @@ void llama_grammar_accept( if (!llama_grammar_is_end_of_sequence(pos)) { new_stack.push_back(pos); } - llama_grammar_advance_stack(rules, new_stack, new_stacks); + llama_grammar_advance_stack(rules, new_stack, result); } } + + return result; } llama_grammar_candidates llama_grammar_reject_candidates_for_stack( diff --git a/src/llama-grammar.h b/src/llama-grammar.h index 2fbb954088fc5..aa4868681fce5 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -71,11 +71,10 @@ std::pair, llama_partial_utf8> decode_utf8( // be positioned at a character range (see `llama_grammar_advance_stack`), and // produces the N possible stacks if the given char is accepted at those // positions -void llama_grammar_accept( +llama_grammar_stacks llama_grammar_accept( const llama_grammar_rules & rules, const llama_grammar_stacks & stacks, - const uint32_t chr, - llama_grammar_stacks & new_stacks); + uint32_t chr); std::vector llama_grammar_reject_candidates_for_stack( const llama_grammar_rules & rules, @@ -113,7 +112,8 @@ struct llama_grammar_parser { }; struct llama_grammar { - const llama_vocab & vocab; + // note: allow null vocab for testing (not great) + const llama_vocab * vocab; const llama_grammar_rules rules; // TODO: shared ptr llama_grammar_stacks stacks; @@ -131,14 +131,14 @@ struct llama_grammar { // internal API // -// TODO: temporary until the tests are fixed +// note: needed for tests (not great) struct llama_grammar * llama_grammar_init_impl( - const struct llama_vocab & vocab, + const struct llama_vocab * vocab, const llama_grammar_element ** rules, size_t n_rules, size_t start_rule_index); -struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root); +struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root); void llama_grammar_free_impl(struct llama_grammar * grammar); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 982319ba189cb..f0d93d89b4fa7 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -76,7 +76,7 @@ void llama_sampling_reset_impl(struct llama_sampling & smpl) { } if (!smpl.grammar_str.empty()) { - smpl.grammar = llama_grammar_init_impl(smpl.vocab, smpl.grammar_str.data(), smpl.grammar_root.data()); + smpl.grammar = llama_grammar_init_impl(&smpl.vocab, smpl.grammar_str.data(), smpl.grammar_root.data()); } smpl.prev.clear(); @@ -100,7 +100,7 @@ void llama_sampling_set_grammar_impl(struct llama_sampling & smpl, const char * smpl.grammar_str = grammar_str; smpl.grammar_root = grammar_root; - smpl.grammar = llama_grammar_init_impl(smpl.vocab, grammar_str, grammar_root); + smpl.grammar = llama_grammar_init_impl(&smpl.vocab, grammar_str, grammar_root); } else { smpl.grammar_str.clear(); smpl.grammar_root.clear(); diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index c5a4bc4954f74..b01791a923e37 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -2,11 +2,7 @@ #undef NDEBUG #endif -#include "ggml.h" -#include "llama.h" -#include "llama-vocab.h" // TMP #include "llama-grammar.h" -#include "unicode.h" #include "json-schema-to-grammar.h" #include @@ -15,10 +11,8 @@ using json = nlohmann::ordered_json; -llama_vocab vocab; // TMP - static llama_grammar * build_grammar(const std::string & grammar_str) { - return llama_grammar_init_impl(vocab, grammar_str.c_str(), "root"); + return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root"); } static bool test_build_grammar_fails(const std::string & grammar_str) { @@ -45,7 +39,7 @@ static bool match_string(const std::string & input, llama_grammar * grammar) { for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy - llama_grammar_accept(rules, prev_stacks, *it, cur_stacks); + cur_stacks = llama_grammar_accept(rules, prev_stacks, *it); if (cur_stacks.empty()) { // no stacks means that the grammar failed to match at this point diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp index 1f4a1f1bc3432..6f1374ca8ed58 100644 --- a/tests/test-llama-grammar.cpp +++ b/tests/test-llama-grammar.cpp @@ -3,7 +3,6 @@ #endif #include "llama.h" -#include "llama-vocab.h" // TMP #include "llama-grammar.h" #include @@ -117,8 +116,7 @@ int main() llama_grammar * grammar = NULL; std::vector grammar_rules(parsed_grammar.c_rules()); - llama_vocab vocab; // TMP - grammar = llama_grammar_init_impl(vocab, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); if (grammar == nullptr) { throw std::runtime_error("Failed to initialize llama_grammar"); @@ -175,13 +173,13 @@ int main() }}; auto index = 0; - for (auto stack : llama_grammar_get_stacks(grammar)) + for (const llama_grammar_stack & stack : llama_grammar_get_stacks(grammar)) { // compare stack to expected_stack for (uint32_t i = 0; i < stack.size(); i++) { - auto element = stack[i]; - auto expected_element = expected_stacks[index][i]; + const llama_grammar_element * element = stack[i]; + const llama_grammar_element & expected_element = expected_stacks[index][i]; // pretty print error message before asserting if (expected_element.type != element->type || expected_element.value != element->value)