Skip to content

Commit

Permalink
json: refine constraint for whitespace to avoid runaways yet allow pr…
Browse files Browse the repository at this point in the history
…etty print (ggerganov#7866)
  • Loading branch information
ochafik authored Jun 11, 2024
1 parent 396b18d commit b61eb96
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 45 deletions.
2 changes: 1 addition & 1 deletion common/json-schema-to-grammar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ static std::string build_repetition(const std::string & item_rule, int min_items
return result;
}

const std::string SPACE_RULE = "\" \"?";
const std::string SPACE_RULE = "| \" \" | \"\\n\" [ \\t]{0,20}";

struct BuiltinRule {
std::string content;
Expand Down
5 changes: 2 additions & 3 deletions examples/json_schema_to_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@ def __init__(self, content: str, deps: list = None):
self.content = content
self.deps = deps or []

# whitespace is constrained to a single space char to prevent model "running away" in
# whitespace. Also maybe improves generation quality?
SPACE_RULE = '" "?'
# Constraining spaces to prevent model "running away".
SPACE_RULE = '| " " | "\\n" [ \\t]{0,20}'

PRIMITIVE_RULES = {
'boolean' : BuiltinRule('("true" | "false") space', []),
Expand Down
2 changes: 1 addition & 1 deletion examples/server/public/json-schema-to-grammar.mjs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// WARNING: This file was ported from json_schema_to_grammar.py, please fix bugs / add features there first.
const SPACE_RULE = '" "?';
const SPACE_RULE = '| " " | "\\n" [ \\t]{0,20}';

function _buildRepetition(itemRule, minItems, maxItems, opts={}) {
if (minItems === 0 && maxItems === 1) {
Expand Down
2 changes: 1 addition & 1 deletion grammars/json.gbnf
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ string ::=
number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [0-9] [1-9]{0,15})? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= [ \t\n]{0,20}
ws ::= | " " | "\n" [ \t]{0,20}
2 changes: 1 addition & 1 deletion grammars/json_arr.gbnf
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ string ::=
number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [1-9] [0-9]{0,15})? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= [ \t\n]{0,20}
ws ::= | " " | "\n" [ \t]{0,20}
76 changes: 38 additions & 38 deletions tests/test-json-schema-to-grammar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
root ::= object
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
value ::= object | array | string | number | boolean | null
)"""
Expand All @@ -135,7 +135,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
date-time ::= date "T" time
date-time-string ::= "\"" date-time "\"" space
root ::= "[" space tuple-0 "," space uuid "," space tuple-2 "," space tuple-3 "]" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
time ::= ([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )
time-string ::= "\"" time "\"" space
tuple-0 ::= date-string
Expand All @@ -154,7 +154,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
R"""(
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "\"" char* "\"" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});

Expand All @@ -168,7 +168,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
R"""(
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "\"" char+ "\"" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});

Expand All @@ -182,7 +182,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
R"""(
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "\"" char{3,} "\"" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});

Expand All @@ -196,7 +196,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
R"""(
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "\"" char{0,3} "\"" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});

Expand All @@ -211,7 +211,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
R"""(
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "\"" char{1,4} "\"" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});

Expand All @@ -223,7 +223,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
root ::= ("true" | "false") space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});

Expand All @@ -236,7 +236,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
R"""(
integral-part ::= [0] | [1-9] [0-9]{0,15}
root ::= ("-"? integral-part) space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});

Expand All @@ -248,7 +248,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
root ::= "\"foo\""
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});

Expand All @@ -260,7 +260,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
root ::= "123"
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});

Expand All @@ -272,7 +272,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
root ::= "\"red\"" | "\"amber\"" | "\"green\"" | "null" | "42" | "[\"foo\"]"
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});

Expand All @@ -285,7 +285,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
R"""(
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "[" space string "]" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
});
Expand All @@ -302,7 +302,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "[" space string "," space number "]" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
});
Expand All @@ -317,7 +317,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
decimal-part ::= [0-9]{1,16}
integral-part ::= [0] | [1-9] [0-9]{0,15}
root ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});

Expand All @@ -333,7 +333,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
R"""(
boolean ::= ("true" | "false") space
root ::= "[" space boolean ("," space boolean)+ "]" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});

Expand All @@ -349,7 +349,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
R"""(
boolean ::= ("true" | "false") space
root ::= "[" space boolean? "]" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});

Expand All @@ -365,7 +365,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
R"""(
boolean ::= ("true" | "false") space
root ::= "[" space (boolean ("," space boolean)?)? "]" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});

Expand All @@ -386,7 +386,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
item ::= number | integer
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "[" space item ("," space item){2,4} "]" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});

Expand All @@ -399,7 +399,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
root ::= "\"" "ab" "c"? "d"* "ef" "g"+ ("hij")? "kl" "\"" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});

Expand All @@ -412,7 +412,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
root ::= "\"" "[]{}()|+*?" "\"" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});

Expand All @@ -425,7 +425,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
root ::= "\"" "\"" "\"" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});

Expand All @@ -440,7 +440,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
dot ::= [^\x0A\x0D]
root ::= "\"" ("(" root-1{1,3} ")")? root-1{3,3} "-" root-1{4,4} " " "a"{3,5} "nd" dot dot dot "\"" space
root-1 ::= [0-9]
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});

Expand Down Expand Up @@ -468,7 +468,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
c-kv ::= "\"c\"" space ":" space string
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "{" space b-kv "," space c-kv "," space a-kv "}" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
});
Expand All @@ -488,7 +488,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
a-kv ::= "\"a\"" space ":" space string
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "{" space (a-kv )? "}" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
});
Expand All @@ -512,7 +512,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
c-kv ::= "\"c\"" space ":" space string
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
root ::= "{" space (a-kv a-rest | b-kv b-rest | c-kv )? "}" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
});
Expand All @@ -538,7 +538,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
d-kv ::= "\"d\"" space ":" space string
d-rest ::= ( "," space c-kv )?
root ::= "{" space b-kv "," space a-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
});
Expand All @@ -559,7 +559,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "{" space (additional-kvs )? "}" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
});
Expand All @@ -581,7 +581,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
root ::= object
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
value ::= object | array | string | number | boolean | null
)"""
Expand All @@ -603,7 +603,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
root ::= object
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
value ::= object | array | string | number | boolean | null
)"""
Expand All @@ -618,7 +618,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
})""",
R"""(
root ::= "{" space "}" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});

Expand All @@ -642,7 +642,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "{" space a-kv ( "," space ( additional-kvs ) )? "}" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
});
Expand All @@ -667,7 +667,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "{" space (a-kv a-rest | additional-kvs )? "}" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
});
Expand Down Expand Up @@ -695,7 +695,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "{" space a-kv ( "," space ( b-kv b-rest | additional-kvs ) )? "}" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
});
Expand Down Expand Up @@ -725,7 +725,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
foo ::= "{" space foo-a-kv "}" space
foo-a-kv ::= "\"a\"" space ":" space string
root ::= foo
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
string ::= "\"" char* "\"" space
)"""
});
Expand Down Expand Up @@ -759,7 +759,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= alternative-0 | alternative-1
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});

Expand Down Expand Up @@ -803,7 +803,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
integral-part ::= [0] | [1-9] [0-9]{0,15}
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
root ::= "{" space a-kv "," space b-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});

Expand Down Expand Up @@ -851,7 +851,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
number-number-kv ::= "\"number\"" space ":" space number-number
number-number-root-kv ::= "\"root\"" space ":" space number
root ::= "{" space number-kv "}" space
space ::= " "?
space ::= | " " | "\n" [ \t]{0,20}
)"""
});
}
Expand Down

0 comments on commit b61eb96

Please sign in to comment.