Skip to content

Commit

Permalink
grammar : restore llama_grammar_accept signature
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
ggerganov committed Sep 6, 2024
1 parent 55a08ed commit f9762c6
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 26 deletions.
12 changes: 6 additions & 6 deletions examples/gbnf-validator/gbnf-validator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,24 @@ static bool llama_grammar_validate(struct llama_grammar * grammar, const std::st
const auto cpts = unicode_cpts_from_utf8(input_str);

const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);

size_t pos = 0;
for (const auto & cpt : cpts) {
const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy

cur_stacks = llama_grammar_accept(rules, prev_stacks, cpt);
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);

if (cur_stacks.empty()) {
if (stacks_cur.empty()) {
error_pos = pos;
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
cur_stacks = prev_stacks;
stacks_cur = stacks_prev;
return false;
}
++pos;
}

for (const auto & stack : cur_stacks) {
for (const auto & stack : stacks_cur) {
if (stack.empty()) {
return true;
}
Expand Down
19 changes: 10 additions & 9 deletions src/llama-grammar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -822,12 +822,13 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
return grammar->stacks;
}

llama_grammar_stacks llama_grammar_accept(
void llama_grammar_accept(
const llama_grammar_rules & rules,
const llama_grammar_stacks & stacks,
const uint32_t chr) {
llama_grammar_stacks result;
result.reserve(stacks.size());
const uint32_t chr,
llama_grammar_stacks & stacks_new) {
stacks_new.clear();
stacks_new.reserve(stacks.size());

for (const auto & stack : stacks) {
if (stack.empty()) {
Expand All @@ -843,11 +844,9 @@ llama_grammar_stacks llama_grammar_accept(
if (!llama_grammar_is_end_of_sequence(pos)) {
new_stack.push_back(pos);
}
llama_grammar_advance_stack(rules, new_stack, result);
llama_grammar_advance_stack(rules, new_stack, stacks_new);
}
}

return result;
}

llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
Expand Down Expand Up @@ -1127,9 +1126,11 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
const auto & code_points = decoded.first;

llama_grammar_stacks stacks_new;

for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
llama_grammar_stacks new_stacks = llama_grammar_accept(grammar.rules, grammar.stacks, *it);
grammar.stacks = std::move(new_stacks);
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
grammar.stacks = std::move(stacks_new);
}

grammar.partial_utf8 = decoded.second;
Expand Down
5 changes: 3 additions & 2 deletions src/llama-grammar.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,11 @@ const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar
// be positioned at a character range (see `llama_grammar_advance_stack`), and
// produces the N possible stacks if the given char is accepted at those
// positions
llama_grammar_stacks llama_grammar_accept(
void llama_grammar_accept(
const llama_grammar_rules & rules,
const llama_grammar_stacks & stacks,
uint32_t chr);
uint32_t chr,
llama_grammar_stacks & stacks_new);

std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
const llama_grammar_rules & rules,
Expand Down
18 changes: 9 additions & 9 deletions tests/test-grammar-integration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,20 @@ static bool match_string(const std::string & input, llama_grammar * grammar) {
const auto cpts = unicode_cpts_from_utf8(input);

const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);

for (const auto & cpt : cpts) {
const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy

cur_stacks = llama_grammar_accept(rules, prev_stacks, cpt);
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);

if (cur_stacks.empty()) {
if (stacks_cur.empty()) {
// no stacks means that the grammar failed to match at this point
return false;
}
}

for (const auto & stack : cur_stacks) {
for (const auto & stack : stacks_cur) {
if (stack.empty()) {
// An empty stack means that the grammar has been completed
return true;
Expand All @@ -63,9 +63,9 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
auto * grammar = build_grammar(grammar_str);

// Save the original grammar stacks so that we can reset after every new string we want to test
const llama_grammar_stacks original_stacks = llama_grammar_get_stacks(grammar);
const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar);

llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);

fprintf(stderr, " 🔵 Valid strings:\n");

Expand Down Expand Up @@ -102,7 +102,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
assert(matched);

// Reset the grammar stacks
cur_stacks = original_stacks;
stacks_cur = stacks_org;
}

fprintf(stderr, " 🟠 Invalid strings:\n");
Expand All @@ -122,7 +122,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
assert(!matched);

// Reset the grammar stacks
cur_stacks = original_stacks;
stacks_cur = stacks_org;
}

// Clean up allocated memory
Expand Down

0 comments on commit f9762c6

Please sign in to comment.