Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

grammars: x{min,max} repetition operator #6640

Merged
merged 36 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
0160469
grammars: x{min,max} repetition operator + tweak +/*/? to avoid dupli…
ochafik Apr 12, 2024
f2030e3
grammars: handle `x{n}` and fix `x{n,n}`
ochafik Apr 12, 2024
de0fd3f
grammars: document new repetition operators
ochafik Apr 12, 2024
9d9b5a3
grammars: nit
ochafik Apr 12, 2024
6b5518c
grammars: uniform use of int for min & max
ochafik Apr 12, 2024
0ceb69a
grammars: refactor parser test
ochafik Apr 12, 2024
8938a05
grammar: parsing tests w/ natural pretty print of updated expectations
ochafik Apr 12, 2024
0d7347f
grammars: much prettier print of expectations (+ TEST_GRAMMAR_PARSER_…
ochafik Apr 12, 2024
2e2df72
grammars: improve test pretty print again
ochafik Apr 12, 2024
ffe321d
grammars: pretty print rules and chars
ochafik Apr 12, 2024
a9351b8
grammars: fix copy rule skipping
ochafik Apr 12, 2024
9d8efa5
grammars: disallow `a{,}` (not allowed in regexps)
ochafik Apr 12, 2024
2d98ebf
Update common/grammar-parser.cpp
ochafik Apr 12, 2024
ec91342
grammars: fix copy rule skipping (again) & display of expectations
ochafik Apr 12, 2024
22faba6
grammars: more test cases
ochafik Apr 12, 2024
1fb7787
Merge remote-tracking branch 'origin/master' into grammar-reps
ochafik Apr 15, 2024
15585e0
grammars: update reps parsing to bring ? / * / + closer to before
ochafik Apr 19, 2024
93b754e
json: use new GBNF repetitions{m,n} syntax
ochafik Apr 19, 2024
2ecc2ae
grammars: update performance gotchas w/ repetition advice
ochafik Apr 20, 2024
a9a2983
Merge remote-tracking branch 'origin/master' into grammar-reps
ochafik Apr 21, 2024
d47f537
Update examples/json_schema_to_grammar.py
ochafik Apr 24, 2024
724f879
Update examples/server/public/json-schema-to-grammar.mjs
ochafik Apr 24, 2024
a61281f
grammars: comment on rule repetitions
ochafik Apr 24, 2024
d03c98e
grammars: ensure unambiguous number alternatives
ochafik Apr 24, 2024
21bac1e
grammar: nit typo switched error msgs
ochafik Apr 24, 2024
0c74ad3
grammar: nit numbering in comment
ochafik Apr 24, 2024
218f41f
json: update numeric rule to be unambiguous
ochafik Apr 24, 2024
2813835
Apply suggestions from code review
ochafik Apr 24, 2024
46fe648
Update examples/server/public/json-schema-to-grammar.mjs
ochafik Apr 24, 2024
eb7ccd8
json: fix integral-part
ochafik Apr 24, 2024
3c02508
Merge branch 'grammar-reps' of https://github.com/ochafik/llama.cpp i…
ochafik Apr 24, 2024
476c97d
Merge remote-tracking branch 'origin/master' into grammar-reps
ochafik Apr 30, 2024
990bf57
grammar: add repetition tests
ochafik Apr 30, 2024
d070aee
Merge remote-tracking branch 'origin/master' into grammar-reps
ochafik May 18, 2024
8266b7c
Merge remote-tracking branch 'origin/master' into grammar-reps
ochafik May 21, 2024
2b79d47
Merge remote-tracking branch 'origin/master' into grammar-reps
ochafik Jun 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 107 additions & 31 deletions common/grammar-parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,12 @@ namespace grammar_parser {
state.rules[rule_id] = rule;
}

static bool is_digit_char(char c) {
return '0' <= c && c <= '9';
}

static bool is_word_char(char c) {
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
}

static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
Expand Down Expand Up @@ -99,6 +103,17 @@ namespace grammar_parser {
return pos;
}

static const char * parse_int(const char * src) {
const char * pos = src;
while (is_digit_char(*pos)) {
pos++;
}
if (pos == src) {
throw std::runtime_error(std::string("expecting integer at ") + src);
}
return pos;
}

static std::pair<uint32_t, const char *> parse_char(const char * src) {
if (*src == '\\') {
switch (src[1]) {
Expand Down Expand Up @@ -137,6 +152,60 @@ namespace grammar_parser {
bool is_nested) {
size_t last_sym_start = out_elements.size();
const char * pos = src;

auto handle_repetitions = [&](int min_times, int max_times) {
HanClinto marked this conversation as resolved.
Show resolved Hide resolved

if (last_sym_start == out_elements.size()) {
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
}

// apply transformation to previous symbol (last_sym_start to end) according to
// the following rewrite rules:
// S{m,n} --> S S S (m times) S'(n-m)
// S'(x) ::= S S'(x-1) |
// (... n-m definitions of these S' rules ...)
// S'(1) ::= S |
// S{m,} --> S S S (m times) S'
// S' ::= S S' |
// S* --> S{0,}
// --> S' ::= S S' |
// S+ --> S{1,}
// --> S S'
// S' ::= S S' |
// S? --> S{0,1}
// --> S'
// S' ::= S |

std::vector<llama_grammar_element> previous_elements(out_elements.begin() + last_sym_start, out_elements.end());
if (min_times == 0) {
out_elements.resize(last_sym_start);
} else {
// Repeat the previous elements (min_times - 1) times
for (int i = 1; i < min_times; i++) {
out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end());
}
}

uint32_t last_rec_rule_id = 0;
auto n_opt = max_times < 0 ? 1 : max_times - min_times;

std::vector<llama_grammar_element> rec_rule(previous_elements);
for (int i = 0; i < n_opt; i++) {
rec_rule.resize(previous_elements.size());
uint32_t rec_rule_id = generate_symbol_id(state, rule_name);
if (i > 0 || max_times < 0) {
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
}
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
add_rule(state, rec_rule_id, rec_rule);
last_rec_rule_id = rec_rule_id;
}
if (n_opt > 0) {
out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
}
};

while (*pos) {
if (*pos == '"') { // literal string
pos++;
Expand Down Expand Up @@ -197,40 +266,47 @@ namespace grammar_parser {
throw std::runtime_error(std::string("expecting ')' at ") + pos);
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
if (last_sym_start == out_elements.size()) {
throw std::runtime_error(std::string("expecting preceding item to */+/? at ") + pos);
}
} else if (*pos == '*') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, -1);
} else if (*pos == '+') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(1, -1);
} else if (*pos == '?') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, 1);
} else if (*pos == '{') {
pos = parse_space(pos + 1, is_nested);

// apply transformation to previous symbol (last_sym_start to end) according to
// rewrite rules:
// S* --> S' ::= S S' |
// S+ --> S' ::= S S' | S
// S? --> S' ::= S |
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
std::vector<llama_grammar_element> sub_rule;
// add preceding symbol to generated rule
sub_rule.insert(
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
if (*pos == '*' || *pos == '+') {
// cause generated rule to recurse
sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
}
// mark start of alternate def
sub_rule.push_back({LLAMA_GRETYPE_ALT, 0});
if (*pos == '+') {
// add preceding symbol as alternate only for '+' (otherwise empty)
sub_rule.insert(
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
if (!is_digit_char(*pos)) {
throw std::runtime_error(std::string("expecting an int at ") + pos);
}
sub_rule.push_back({LLAMA_GRETYPE_END, 0});
add_rule(state, sub_rule_id, sub_rule);
const char * int_end = parse_int(pos);
int min_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);

// in original rule, replace previous symbol with reference to generated rule
out_elements.resize(last_sym_start);
out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
int max_times = -1;

pos = parse_space(pos + 1, is_nested);
if (*pos == '}') {
max_times = min_times;
pos = parse_space(pos + 1, is_nested);
} else if (*pos == ',') {
pos = parse_space(pos + 1, is_nested);

if (is_digit_char(*pos)) {
const char * int_end = parse_int(pos);
max_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
}

if (*pos != '}') {
throw std::runtime_error(std::string("expecting '}' at ") + pos);
}
pos = parse_space(pos + 1, is_nested);
} else {
throw std::runtime_error(std::string("expecting ',' at ") + pos);
}
handle_repetitions(min_times, max_times);
} else {
break;
}
Expand Down
78 changes: 20 additions & 58 deletions common/json-schema-to-grammar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,58 +16,27 @@ static std::string join(Iterator begin, Iterator end, const std::string & separa

static std::string repeat(const std::string & str, size_t n);

static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "", bool item_rule_is_literal = false) {
if (separator_rule.empty()) {
if (min_items == 0 && max_items == 1) {
return item_rule + "?";
} else if (min_items == 1 && max_items == std::numeric_limits<int>::max()) {
return item_rule + "+";
}
}
static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
auto has_max = max_items != std::numeric_limits<int>::max();

std::string result;
if (min_items > 0) {
if (item_rule_is_literal && separator_rule.empty()) {
result = "\"" + repeat(std::string(item_rule.begin() + 1, item_rule.end() - 1), min_items) + "\"";
} else {
std::vector<std::string> items(min_items, item_rule);
result = join(items.begin(), items.end(), separator_rule.empty() ? " " : " " + separator_rule + " ");
}
if (min_items == 0 && max_items == 1) {
return item_rule + "?";
}

std::function<std::string(int, bool)> opt_repetitions = [&](int up_to_n, bool prefix_with_sep) -> std::string {
auto content = prefix_with_sep && !separator_rule.empty() ? separator_rule + " " + item_rule : item_rule;

if (up_to_n == 0) {
return "";
} else if (up_to_n == 1) {
return "(" + content + ")?";
} else if (!separator_rule.empty() && !prefix_with_sep) {
return "(" + content + " " + opt_repetitions(up_to_n - 1, true) + ")?";
if (separator_rule.empty()) {
if (min_items == 1 && !has_max) {
return item_rule + "+";
} else if (min_items == 0 && !has_max) {
return item_rule + "*";
} else {
std::string res = repeat("(" + content + " ", up_to_n);
// strip trailing space
res = res.substr(0, res.length() - 1);
res += repeat(")?", up_to_n);
return res;
return item_rule + "{" + std::to_string(min_items) + "," + (has_max ? std::to_string(max_items) : "") + "}";
}
};

if (min_items > 0 && max_items != min_items) {
result += " ";
}

if (max_items != std::numeric_limits<int>::max()) {
result += opt_repetitions(max_items - min_items, min_items > 0);
} else {
std::string item_operator = "(" + (separator_rule.empty() ? "" : separator_rule + " ") + item_rule + ")";
if (min_items == 0 && !separator_rule.empty()) {
result = "(" + item_rule + " " + item_operator + "*)?";
} else {
result += item_operator + "*";
}
auto result = item_rule + " " + build_repetition("(" + separator_rule + " " + item_rule + ")", min_items == 0 ? 0 : min_items - 1, has_max ? max_items - 1 : max_items);
if (min_items == 0) {
result = "(" + result + ")?";
}

return result;
}

Expand All @@ -78,30 +47,24 @@ struct BuiltinRule {
std::vector<std::string> deps;
};

const std::string _up_to_15_digits = build_repetition("[0-9]", 0, 15);

std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = {
{"boolean", {"(\"true\" | \"false\") space", {}}},
{"decimal-part", {"[0-9] " + _up_to_15_digits, {}}},
{"integral-part", {"[0-9] | [1-9] " + _up_to_15_digits, {}}},
{"decimal-part", {"[0-9]{1,16}", {}}},
{"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}},
{"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)? space", {"integral-part", "decimal-part"}}},
{"integer", {"(\"-\"? integral-part) space", {"integral-part"}}},
{"value", {"object | array | string | number | boolean | null", {"object", "array", "string", "number", "boolean", "null"}}},
{"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space", {"string", "value"}}},
{"array", {"\"[\" space ( value (\",\" space value)* )? \"]\" space", {"value"}}},
{"uuid", {"\"\\\"\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
"\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
"\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
"\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
"\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] \"\\\"\" space", {}}},
{"char", {"[^\"\\\\] | \"\\\\\" ([\"\\\\/bfnrt] | \"u\" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])", {}}},
{"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space", {}}},
{"char", {"[^\"\\\\] | \"\\\\\" ([\"\\\\/bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}},
{"string", {"\"\\\"\" char* \"\\\"\" space", {"char"}}},
{"null", {"\"null\" space", {}}},
};

std::unordered_map<std::string, BuiltinRule> STRING_FORMAT_RULES = {
{"date", {"[0-9] [0-9] [0-9] [0-9] \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}},
{"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9] [0-9] [0-9] )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}},
{"date", {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}},
{"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}},
{"date-time", {"date \"T\" time", {"date", "time"}}},
{"date-string", {"\"\\\"\" date \"\\\"\" space", {"date"}}},
{"time-string", {"\"\\\"\" time \"\\\"\" space", {"time"}}},
Expand Down Expand Up @@ -385,8 +348,7 @@ class SchemaConverter {
sub_is_literal ? "\"" + sub + "\"" : sub,
min_times,
max_times,
"",
sub_is_literal
""
);
seq.back().second = false;
} else {
Expand Down
Loading
Loading