From 267f13852f494e95a99b323b7530bd4a29393377 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 19 Aug 2024 17:29:24 +0300 Subject: [PATCH] grammar : hide decode_utf8 overload ggml-ci --- common/sampling.cpp | 18 +- examples/gbnf-validator/gbnf-validator.cpp | 13 +- src/llama-grammar.cpp | 310 ++++++++++----------- src/llama-grammar.h | 12 +- src/llama-sampling.cpp | 6 +- tests/test-grammar-integration.cpp | 11 +- 6 files changed, 179 insertions(+), 191 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 69ce4e9040341..6079109ffb0b7 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -157,25 +157,21 @@ std::vector llama_sampling_types_from_names(const std::vecto std::vector sampler_types; sampler_types.reserve(names.size()); - for (const auto & name : names) - { + + for (const auto & name : names) { auto sampler_item = sampler_canonical_name_map.find(name); - if (sampler_item != sampler_canonical_name_map.end()) - { + if (sampler_item != sampler_canonical_name_map.end()) { sampler_types.push_back(sampler_item->second); - } - else - { - if (allow_alt_names) - { + } else { + if (allow_alt_names) { sampler_item = sampler_alt_name_map.find(name); - if (sampler_item != sampler_alt_name_map.end()) - { + if (sampler_item != sampler_alt_name_map.end()) { sampler_types.push_back(sampler_item->second); } } } } + return sampler_types; } diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp index f4868a55741bd..f439c0c5648a8 100644 --- a/examples/gbnf-validator/gbnf-validator.cpp +++ b/examples/gbnf-validator/gbnf-validator.cpp @@ -1,7 +1,5 @@ -#include "ggml.h" -#include "llama.h" -#include "llama-grammar.h" #include "unicode.h" +#include "llama-grammar.h" #include #include @@ -11,21 +9,20 @@ #include static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) { - auto decoded = decode_utf8(input_str, {}); - const auto & code_points = decoded.first; + const auto cpts = unicode_cpts_from_utf8(input_str); const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar); size_t pos = 0; - for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + for (const auto & cpt : cpts) { const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy - cur_stacks = llama_grammar_accept(rules, prev_stacks, *it); + cur_stacks = llama_grammar_accept(rules, prev_stacks, cpt); if (cur_stacks.empty()) { error_pos = pos; - error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'"; + error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'"; cur_stacks = prev_stacks; return false; } diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 687aef1c7f8f4..f163fd51c354c 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -11,7 +11,6 @@ // // NOTE: assumes valid utf8 (but checks for overrun) -// TODO: deduplicate static std::pair decode_utf8(const char * src) { static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; uint8_t first_byte = static_cast(*src); @@ -27,6 +26,66 @@ static std::pair decode_utf8(const char * src) { return std::make_pair(value, pos); } +static std::pair, llama_partial_utf8> decode_utf8( + const std::string & src, + llama_partial_utf8 partial_start) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; + const char * pos = src.c_str(); + std::vector code_points; + + // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0. + code_points.reserve(src.size() + 1); + uint32_t value = partial_start.value; + int n_remain = partial_start.n_remain; + + // continue previous decode, if applicable + while (*pos != 0 && n_remain > 0) { + uint8_t next_byte = static_cast(*pos); + if ((next_byte >> 6) != 2) { + // invalid sequence, abort + code_points.push_back(0); + return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 }); + } + value = (value << 6) + (next_byte & 0x3F); + ++pos; + --n_remain; + } + + if (partial_start.n_remain > 0 && n_remain == 0) { + code_points.push_back(value); + } + + // decode any subsequent utf-8 sequences, which may end in an incomplete one + while (*pos != 0) { + uint8_t first_byte = static_cast(*pos); + uint8_t highbits = first_byte >> 4; + n_remain = lookup[highbits] - 1; + + if (n_remain < 0) { + // invalid sequence, abort + code_points.clear(); + code_points.push_back(0); + return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain }); + } + + uint8_t mask = (1 << (7 - n_remain)) - 1; + value = first_byte & mask; + + ++pos; + while (*pos != 0 && n_remain > 0) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + ++pos; + --n_remain; + } + if (n_remain == 0) { + code_points.push_back(value); + } + } + code_points.push_back(0); + + return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain }); +} + static bool is_digit_char(char c) { return '0' <= c && c <= '9'; } @@ -750,9 +809,103 @@ static bool llama_grammar_detect_left_recursion( (*rules_in_progress)[rule_index] = false; (*rules_visited)[rule_index] = true; + return false; } +const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) { + return grammar->rules; +} + +llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) { + return grammar->stacks; +} + +llama_grammar_stacks llama_grammar_accept( + const llama_grammar_rules & rules, + const llama_grammar_stacks & stacks, + const uint32_t chr) { + llama_grammar_stacks result; + result.reserve(stacks.size()); + + for (const auto & stack : stacks) { + if (stack.empty()) { + continue; + } + + auto match = llama_grammar_match_char(stack.back(), chr); + if (match.first) { + const llama_grammar_element * pos = match.second; + + // update top of stack to next element, if any + llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos)) { + new_stack.push_back(pos); + } + llama_grammar_advance_stack(rules, new_stack, result); + } + } + + return result; +} + +llama_grammar_candidates llama_grammar_reject_candidates_for_stack( + const llama_grammar_rules & rules, + const llama_grammar_stack & stack, + const llama_grammar_candidates & candidates) { + + llama_grammar_candidates rejects; + rejects.reserve(candidates.size()); + + if (stack.empty()) { + for (const auto & tok : candidates) { + if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) { + rejects.push_back(tok); + } + } + return rejects; + } + + const llama_grammar_element * stack_pos = stack.back(); + + llama_grammar_candidates next_candidates; + next_candidates.reserve(candidates.size()); + + for (const auto & tok : candidates) { + if (*tok.code_points == 0) { + // reached end of full codepoints in token, reject iff it ended in a partial sequence + // that cannot satisfy this position in grammar + if (tok.partial_utf8.n_remain != 0 && + !llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) { + rejects.push_back(tok); + } + } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) { + next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 }); + } else { + rejects.push_back(tok); + } + } + + const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second; + + // update top of stack to next element, if any + llama_grammar_stack stack_after(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(stack_pos_after)) { + stack_after.push_back(stack_pos_after); + } + llama_grammar_stacks next_stacks; + llama_grammar_advance_stack(rules, stack_after, next_stacks); + + auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates); + for (const auto & tok : next_rejects) { + rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 }); + } + + return rejects; +} + +//////////////////// + struct llama_grammar * llama_grammar_init_impl( const struct llama_vocab * vocab, const llama_grammar_element ** rules, @@ -893,7 +1046,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) { delete grammar; } -struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & grammar) { +struct llama_grammar * llama_grammar_cp_impl(const struct llama_grammar & grammar) { llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, 0, 0, 0 }; // redirect elements in stacks to point to new rules @@ -977,156 +1130,3 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token grammar.partial_utf8 = decoded.second; GGML_ASSERT(!grammar.stacks.empty()); } - -//////////////////// - -const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) { - return grammar->rules; -} - -llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) { - return grammar->stacks; -} - -std::pair, llama_partial_utf8> decode_utf8( - const std::string & src, - llama_partial_utf8 partial_start) { - static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; - const char * pos = src.c_str(); - std::vector code_points; - - // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0. - code_points.reserve(src.size() + 1); - uint32_t value = partial_start.value; - int n_remain = partial_start.n_remain; - - // continue previous decode, if applicable - while (*pos != 0 && n_remain > 0) { - uint8_t next_byte = static_cast(*pos); - if ((next_byte >> 6) != 2) { - // invalid sequence, abort - code_points.push_back(0); - return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 }); - } - value = (value << 6) + (next_byte & 0x3F); - ++pos; - --n_remain; - } - - if (partial_start.n_remain > 0 && n_remain == 0) { - code_points.push_back(value); - } - - // decode any subsequent utf-8 sequences, which may end in an incomplete one - while (*pos != 0) { - uint8_t first_byte = static_cast(*pos); - uint8_t highbits = first_byte >> 4; - n_remain = lookup[highbits] - 1; - - if (n_remain < 0) { - // invalid sequence, abort - code_points.clear(); - code_points.push_back(0); - return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain }); - } - - uint8_t mask = (1 << (7 - n_remain)) - 1; - value = first_byte & mask; - - ++pos; - while (*pos != 0 && n_remain > 0) { - value = (value << 6) + (static_cast(*pos) & 0x3F); - ++pos; - --n_remain; - } - if (n_remain == 0) { - code_points.push_back(value); - } - } - code_points.push_back(0); - - return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain }); -} - -llama_grammar_stacks llama_grammar_accept( - const llama_grammar_rules & rules, - const llama_grammar_stacks & stacks, - const uint32_t chr) { - llama_grammar_stacks result; - result.reserve(stacks.size()); - - for (const auto & stack : stacks) { - if (stack.empty()) { - continue; - } - - auto match = llama_grammar_match_char(stack.back(), chr); - if (match.first) { - const llama_grammar_element * pos = match.second; - - // update top of stack to next element, if any - llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); - if (!llama_grammar_is_end_of_sequence(pos)) { - new_stack.push_back(pos); - } - llama_grammar_advance_stack(rules, new_stack, result); - } - } - - return result; -} - -llama_grammar_candidates llama_grammar_reject_candidates_for_stack( - const llama_grammar_rules & rules, - const llama_grammar_stack & stack, - const llama_grammar_candidates & candidates) { - - llama_grammar_candidates rejects; - rejects.reserve(candidates.size()); - - if (stack.empty()) { - for (const auto & tok : candidates) { - if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) { - rejects.push_back(tok); - } - } - return rejects; - } - - const llama_grammar_element * stack_pos = stack.back(); - - llama_grammar_candidates next_candidates; - next_candidates.reserve(candidates.size()); - - for (const auto & tok : candidates) { - if (*tok.code_points == 0) { - // reached end of full codepoints in token, reject iff it ended in a partial sequence - // that cannot satisfy this position in grammar - if (tok.partial_utf8.n_remain != 0 && - !llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) { - rejects.push_back(tok); - } - } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) { - next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 }); - } else { - rejects.push_back(tok); - } - } - - const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second; - - // update top of stack to next element, if any - llama_grammar_stack stack_after(stack.begin(), stack.end() - 1); - if (!llama_grammar_is_end_of_sequence(stack_pos_after)) { - stack_after.push_back(stack_pos_after); - } - llama_grammar_stacks next_stacks; - llama_grammar_advance_stack(rules, stack_after, next_stacks); - - auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates); - for (const auto & tok : next_rejects) { - rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 }); - } - - return rejects; -} diff --git a/src/llama-grammar.h b/src/llama-grammar.h index aa4868681fce5..04555c29bb264 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -61,12 +61,6 @@ using llama_grammar_candidates = std::vector; const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); -// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as -// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`. -std::pair, llama_partial_utf8> decode_utf8( - const std::string & src, - llama_partial_utf8 partial_start); - // takes a set of possible pushdown stacks on a grammar, which are required to // be positioned at a character range (see `llama_grammar_advance_stack`), and // produces the N possible stacks if the given char is accepted at those @@ -77,8 +71,8 @@ llama_grammar_stacks llama_grammar_accept( uint32_t chr); std::vector llama_grammar_reject_candidates_for_stack( - const llama_grammar_rules & rules, - const llama_grammar_stack & stack, + const llama_grammar_rules & rules, + const llama_grammar_stack & stack, const llama_grammar_candidates & candidates); struct llama_grammar_parser { @@ -142,7 +136,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, void llama_grammar_free_impl(struct llama_grammar * grammar); -struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & grammar); +struct llama_grammar * llama_grammar_cp_impl(const struct llama_grammar & grammar); // TODO: move the API below as member functions of llama_grammar void llama_grammar_apply_impl( diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index f0d93d89b4fa7..17d0085b35c72 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -60,7 +60,7 @@ struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smp result->logit_bias = smpl.logit_bias; if (smpl.grammar) { - result->grammar = llama_grammar_copy_impl(*smpl.grammar); + result->grammar = llama_grammar_cp_impl(*smpl.grammar); } result->rng = smpl.rng; @@ -450,13 +450,15 @@ void llama_sampling_entropy_impl(llama_token_data_array * candidates, float min_ } // Re-compute softmax probabilities after scaling logits with dynamic temperature - double max_l_double = candidates->data[0].logit; + const double max_l_double = candidates->data[0].logit; + double cum_sum_double = 0.0; for (size_t i = 0; i < candidates->size; ++i) { double p = exp(candidates->data[i].logit - max_l_double); candidates->data[i].p = p; // Store the scaled probability cum_sum_double += p; } + for (size_t i = 0; i < candidates->size; ++i) { candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities } diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index b0af685eddece..788b02a6a5cd8 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -2,6 +2,7 @@ #undef NDEBUG #endif +#include "unicode.h" #include "llama-grammar.h" #include "json-schema-to-grammar.h" @@ -29,17 +30,15 @@ static bool test_build_grammar_fails(const std::string & grammar_str) { } static bool match_string(const std::string & input, llama_grammar * grammar) { - auto decoded = decode_utf8(input, {}); - - const auto & code_points = decoded.first; + const auto cpts = unicode_cpts_from_utf8(input); const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar); - for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + for (const auto & cpt : cpts) { const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy - cur_stacks = llama_grammar_accept(rules, prev_stacks, *it); + cur_stacks = llama_grammar_accept(rules, prev_stacks, cpt); if (cur_stacks.empty()) { // no stacks means that the grammar failed to match at this point @@ -61,7 +60,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str, fprintf(stderr, "⚫ Testing %s\n%s\n", test_desc.c_str(), grammar_str.c_str()); fflush(stderr); - auto grammar = build_grammar(grammar_str); + auto * grammar = build_grammar(grammar_str); // Save the original grammar stacks so that we can reset after every new string we want to test const llama_grammar_stacks original_stacks = llama_grammar_get_stacks(grammar);