Skip to content

Commit

Permalink
grammar : clean-up
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
ggerganov committed Aug 17, 2024
1 parent 42924ed commit 5adba1c
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 48 deletions.
10 changes: 4 additions & 6 deletions examples/gbnf-validator/gbnf-validator.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "ggml.h"
#include "llama.h"
#include "llama-vocab.h" // TMP
#include "llama-grammar.h"
#include "unicode.h"

Expand All @@ -11,7 +10,7 @@
#include <string>
#include <vector>

static bool llama_sample_grammar_string(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
auto decoded = decode_utf8(input_str, {});
const auto & code_points = decoded.first;

Expand All @@ -22,7 +21,7 @@ static bool llama_sample_grammar_string(struct llama_grammar * grammar, const st
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy

llama_grammar_accept(rules, prev_stacks, *it, cur_stacks);
cur_stacks = llama_grammar_accept(rules, prev_stacks, *it);

if (cur_stacks.empty()) {
error_pos = pos;
Expand Down Expand Up @@ -84,8 +83,7 @@ int main(int argc, char** argv) {
grammar_str = buffer.str();
}

llama_vocab vocab; // TMP
llama_grammar * grammar = llama_grammar_init_impl(vocab, grammar_str.c_str(), "root");
llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
if (grammar == nullptr) {
throw std::runtime_error("Failed to initialize llama_grammar");
}
Expand All @@ -102,7 +100,7 @@ int main(int argc, char** argv) {
// Validate the input string against the grammar
size_t error_pos;
std::string error_msg;
bool is_valid = llama_sample_grammar_string(grammar, input_str, error_pos, error_msg);
bool is_valid = llama_grammar_validate(grammar, input_str, error_pos, error_msg);

if (is_valid) {
fprintf(stdout, "Input string is valid according to the grammar.\n");
Expand Down
43 changes: 24 additions & 19 deletions src/llama-grammar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -754,10 +754,10 @@ static bool llama_grammar_detect_left_recursion(
}

struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab & vocab,
const struct llama_vocab * vocab,
const llama_grammar_element ** rules,
size_t n_rules,
size_t start_rule_index) {
size_t n_rules,
size_t start_rule_index) {
const llama_grammar_element * pos;

// copy rule definitions into vectors
Expand Down Expand Up @@ -808,10 +808,10 @@ struct llama_grammar * llama_grammar_init_impl(
// Important: vec_rules has to be moved here, not copied, because stacks contains
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
// then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar{ vocab, std::move(vec_rules), std::move(stacks), {}, 0, 0, 0 };
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, 0, 0, 0 };
}

struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
llama_grammar_parser parser;

// if there is a grammar, parse it
Expand Down Expand Up @@ -886,15 +886,15 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab & vocab,
// Important: vec_rules has to be moved here, not copied, because stacks contains
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
// then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar{ vocab, std::move(vec_rules), std::move(stacks), {}, 0, 0, 0 };
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, 0, 0, 0 };
}

// Destroys a grammar previously created by llama_grammar_init_impl.
// `delete` on a null pointer is a no-op, so passing nullptr is safe.
// Note: frees the grammar's owned rules/stacks via the destructor; the
// vocab it references is not owned and is left untouched.
void llama_grammar_free_impl(struct llama_grammar * grammar) {
    delete grammar;
}

struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & grammar) {
llama_grammar * result = new llama_grammar{ grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, 0, 0, 0 };
llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, 0, 0, 0 };

// redirect elements in stacks to point to new rules
for (size_t is = 0; is < result->stacks.size(); is++) {
Expand All @@ -913,6 +913,8 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & gram
}

void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * candidates) {
GGML_ASSERT(grammar.vocab != nullptr);

bool allow_eog = false;
for (const auto & stack : grammar.stacks) {
if (stack.empty()) {
Expand All @@ -929,9 +931,9 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_

for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id;
const std::string & piece = grammar.vocab.cache_token_to_piece.at(id);
const std::string & piece = grammar.vocab->cache_token_to_piece.at(id);

if (llama_token_is_eog_impl(grammar.vocab, id)) {
if (llama_token_is_eog_impl(*grammar.vocab, id)) {
if (!allow_eog) {
candidates->data[i].logit = -INFINITY;
}
Expand All @@ -950,7 +952,9 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
}

void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
if (llama_token_is_eog_impl(grammar.vocab, token)) {
GGML_ASSERT(grammar.vocab != nullptr);

if (llama_token_is_eog_impl(*grammar.vocab, token)) {
for (const auto & stack : grammar.stacks) {
if (stack.empty()) {
return;
Expand All @@ -959,16 +963,15 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
GGML_ABORT("fatal error");
}

const std::string & piece = grammar.vocab.cache_token_to_piece.at(token);
const std::string & piece = grammar.vocab->cache_token_to_piece.at(token);

// Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
const auto & code_points = decoded.first;

llama_grammar_stacks tmp_new_stacks;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
llama_grammar_accept(grammar.rules, grammar.stacks, *it, tmp_new_stacks);
grammar.stacks = tmp_new_stacks;
llama_grammar_stacks new_stacks = llama_grammar_accept(grammar.rules, grammar.stacks, *it);
grammar.stacks = std::move(new_stacks);
}

grammar.partial_utf8 = decoded.second;
Expand Down Expand Up @@ -1045,12 +1048,12 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
}

void llama_grammar_accept(
llama_grammar_stacks llama_grammar_accept(
const llama_grammar_rules & rules,
const llama_grammar_stacks & stacks,
const uint32_t chr,
llama_grammar_stacks & new_stacks) {
new_stacks.clear();
const uint32_t chr) {
llama_grammar_stacks result;
result.reserve(stacks.size());

for (const auto & stack : stacks) {
if (stack.empty()) {
Expand All @@ -1066,9 +1069,11 @@ void llama_grammar_accept(
if (!llama_grammar_is_end_of_sequence(pos)) {
new_stack.push_back(pos);
}
llama_grammar_advance_stack(rules, new_stack, new_stacks);
llama_grammar_advance_stack(rules, new_stack, result);
}
}

return result;
}

llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
Expand Down
14 changes: 7 additions & 7 deletions src/llama-grammar.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,10 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
// be positioned at a character range (see `llama_grammar_advance_stack`), and
// produces the N possible stacks if the given char is accepted at those
// positions
void llama_grammar_accept(
llama_grammar_stacks llama_grammar_accept(
const llama_grammar_rules & rules,
const llama_grammar_stacks & stacks,
const uint32_t chr,
llama_grammar_stacks & new_stacks);
uint32_t chr);

std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
const llama_grammar_rules & rules,
Expand Down Expand Up @@ -113,7 +112,8 @@ struct llama_grammar_parser {
};

struct llama_grammar {
const llama_vocab & vocab;
// note: allow null vocab for testing (not great)
const llama_vocab * vocab;

const llama_grammar_rules rules; // TODO: shared ptr
llama_grammar_stacks stacks;
Expand All @@ -131,14 +131,14 @@ struct llama_grammar {
// internal API
//

// TODO: temporary until the tests are fixed
// note: needed for tests (not great)
struct llama_grammar * llama_grammar_init_impl(
const struct llama_vocab & vocab,
const struct llama_vocab * vocab,
const llama_grammar_element ** rules,
size_t n_rules,
size_t start_rule_index);

struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root);
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root);

void llama_grammar_free_impl(struct llama_grammar * grammar);

Expand Down
4 changes: 2 additions & 2 deletions src/llama-sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ void llama_sampling_reset_impl(struct llama_sampling & smpl) {
}

if (!smpl.grammar_str.empty()) {
smpl.grammar = llama_grammar_init_impl(smpl.vocab, smpl.grammar_str.data(), smpl.grammar_root.data());
smpl.grammar = llama_grammar_init_impl(&smpl.vocab, smpl.grammar_str.data(), smpl.grammar_root.data());
}

smpl.prev.clear();
Expand All @@ -100,7 +100,7 @@ void llama_sampling_set_grammar_impl(struct llama_sampling & smpl, const char *
smpl.grammar_str = grammar_str;
smpl.grammar_root = grammar_root;

smpl.grammar = llama_grammar_init_impl(smpl.vocab, grammar_str, grammar_root);
smpl.grammar = llama_grammar_init_impl(&smpl.vocab, grammar_str, grammar_root);
} else {
smpl.grammar_str.clear();
smpl.grammar_root.clear();
Expand Down
10 changes: 2 additions & 8 deletions tests/test-grammar-integration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,7 @@
#undef NDEBUG
#endif

#include "ggml.h"
#include "llama.h"
#include "llama-vocab.h" // TMP
#include "llama-grammar.h"
#include "unicode.h"
#include "json-schema-to-grammar.h"

#include <cassert>
Expand All @@ -15,10 +11,8 @@

using json = nlohmann::ordered_json;

llama_vocab vocab; // TMP

static llama_grammar * build_grammar(const std::string & grammar_str) {
return llama_grammar_init_impl(vocab, grammar_str.c_str(), "root");
return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
}

static bool test_build_grammar_fails(const std::string & grammar_str) {
Expand All @@ -45,7 +39,7 @@ static bool match_string(const std::string & input, llama_grammar * grammar) {
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy

llama_grammar_accept(rules, prev_stacks, *it, cur_stacks);
cur_stacks = llama_grammar_accept(rules, prev_stacks, *it);

if (cur_stacks.empty()) {
// no stacks means that the grammar failed to match at this point
Expand Down
10 changes: 4 additions & 6 deletions tests/test-llama-grammar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#endif

#include "llama.h"
#include "llama-vocab.h" // TMP
#include "llama-grammar.h"

#include <cassert>
Expand Down Expand Up @@ -117,8 +116,7 @@ int main()
llama_grammar * grammar = NULL;
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());

llama_vocab vocab; // TMP
grammar = llama_grammar_init_impl(vocab, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
if (grammar == nullptr)
{
throw std::runtime_error("Failed to initialize llama_grammar");
Expand Down Expand Up @@ -175,13 +173,13 @@ int main()
}};

auto index = 0;
for (auto stack : llama_grammar_get_stacks(grammar))
for (const llama_grammar_stack & stack : llama_grammar_get_stacks(grammar))
{
// compare stack to expected_stack
for (uint32_t i = 0; i < stack.size(); i++)
{
auto element = stack[i];
auto expected_element = expected_stacks[index][i];
const llama_grammar_element * element = stack[i];
const llama_grammar_element & expected_element = expected_stacks[index][i];

// pretty print error message before asserting
if (expected_element.type != element->type || expected_element.value != element->value)
Expand Down

0 comments on commit 5adba1c

Please sign in to comment.