Skip to content

Commit

Permalink
json: fix additionalProperties, allow space after enum/const (ggerg…
Browse files Browse the repository at this point in the history
…anov#7840)

* json: default additionalProperty to true

* json: don't force additional props after normal properties!

* json: allow space after enum/const

* json: update pydantic example to set additionalProperties: false

* json: prevent additional props to redefine a typed prop

* port not_strings to python, add trailing space

* fix not_strings & port to js+py

* Update json-schema-to-grammar.cpp

* fix _not_strings for substring overlaps

* json: fix additionalProperties default, uncomment tests

* json: add integ. test case for additionalProperties

* json: nit: simplify condition

* reformat grammar integ tests w/ R"""()""" strings where there's escapes

* update # tokens in server test: consts can now have trailing space
  • Loading branch information
ochafik authored Jun 26, 2024
1 parent 163d50a commit 6777c54
Show file tree
Hide file tree
Showing 7 changed files with 497 additions and 245 deletions.
99 changes: 86 additions & 13 deletions common/json-schema-to-grammar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,75 @@ class SchemaConverter {
return _add_rule(name, "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space");
}

/*
Returns a rule that matches a JSON string that is none of the provided strings
not_strings({"a"})
-> ["] ( [a] char+ | [^"a] char* )? ["] space
not_strings({"and", "also"})
-> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space
*/
std::string _not_strings(const std::vector<std::string> & strings) {

struct TrieNode {
std::map<char, TrieNode> children;
bool is_end_of_string;

TrieNode() : is_end_of_string(false) {}

void insert(const std::string & string) {
auto node = this;
for (char c : string) {
node = &node->children[c];
}
node->is_end_of_string = true;
}
};

TrieNode trie;
for (const auto & s : strings) {
trie.insert(s);
}

std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
std::ostringstream out;
out << "[\"] ( ";
std::function<void(const TrieNode &)> visit = [&](const TrieNode & node) {
std::ostringstream rejects;
auto first = true;
for (const auto & kv : node.children) {
rejects << kv.first;
if (first) {
first = false;
} else {
out << " | ";
}
out << "[" << kv.first << "]";
if (!kv.second.children.empty()) {
out << " (";
visit(kv.second);
out << ")";
} else if (kv.second.is_end_of_string) {
out << " " << char_rule << "+";
}
}
if (!node.children.empty()) {
if (!first) {
out << " | ";
}
out << "[^\"" << rejects.str() << "] " << char_rule << "*";
}
};
visit(trie);

out << " )";
if (!trie.is_end_of_string) {
out << "?";
}
out << " [\"] space";
return out.str();
}

std::string _resolve_ref(const std::string & ref) {
std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
Expand All @@ -634,6 +703,7 @@ class SchemaConverter {
std::vector<std::string> required_props;
std::vector<std::string> optional_props;
std::unordered_map<std::string, std::string> prop_kv_rule_names;
std::vector<std::string> prop_names;
for (const auto & kv : properties) {
const auto &prop_name = kv.first;
const auto &prop_schema = kv.second;
Expand All @@ -648,11 +718,18 @@ class SchemaConverter {
} else {
optional_props.push_back(prop_name);
}
prop_names.push_back(prop_name);
}
if (additional_properties.is_object() || (additional_properties.is_boolean() && additional_properties.get<bool>())) {
if (!(additional_properties.is_boolean() && !additional_properties.get<bool>())) {
std::string sub_name = name + (name.empty() ? "" : "-") + "additional";
std::string value_rule = visit(additional_properties.is_object() ? additional_properties : json::object(), sub_name + "-value");
std::string kv_rule = _add_rule(sub_name + "-kv", _add_primitive("string", PRIMITIVE_RULES.at("string")) + " \":\" space " + value_rule);
std::string value_rule =
additional_properties.is_object() ? visit(additional_properties, sub_name + "-value")
: _add_primitive("value", PRIMITIVE_RULES.at("value"));

auto key_rule =
prop_names.empty() ? _add_primitive("string", PRIMITIVE_RULES.at("string"))
: _add_rule(sub_name + "-k", _not_strings(prop_names));
std::string kv_rule = _add_rule(sub_name + "-kv", key_rule + " \":\" space " + value_rule);
prop_kv_rule_names["*"] = kv_rule;
optional_props.push_back("*");
}
Expand All @@ -678,15 +755,11 @@ class SchemaConverter {
}
std::string k = ks[0];
std::string kv_rule_name = prop_kv_rule_names[k];
if (k == "*") {
res = _add_rule(
name + (name.empty() ? "" : "-") + "additional-kvs",
kv_rule_name + " ( \",\" space " + kv_rule_name + " )*"
);
} else if (first_is_optional) {
res = "( \",\" space " + kv_rule_name + " )?";
std::string comma_ref = "( \",\" space " + kv_rule_name + " )";
if (first_is_optional) {
res = comma_ref + (k == "*" ? "*" : "?");
} else {
res = kv_rule_name;
res = kv_rule_name + (k == "*" ? " " + comma_ref + "*" : "");
}
if (ks.size() > 1) {
res += " " + _add_rule(
Expand Down Expand Up @@ -824,13 +897,13 @@ class SchemaConverter {
}
return _add_rule(rule_name, _generate_union_rule(name, schema_types));
} else if (schema.contains("const")) {
return _add_rule(rule_name, _generate_constant_rule(schema["const"]));
return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space");
} else if (schema.contains("enum")) {
std::vector<std::string> enum_values;
for (const auto & v : schema["enum"]) {
enum_values.push_back(_generate_constant_rule(v));
}
return _add_rule(rule_name, join(enum_values.begin(), enum_values.end(), " | "));
return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space");
} else if ((schema_type.is_null() || schema_type == "object")
&& (schema.contains("properties") ||
(schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
Expand Down
6 changes: 5 additions & 1 deletion examples/json-schema-pydantic-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#! pip install pydantic
#! python json-schema-pydantic-example.py

from pydantic import BaseModel, TypeAdapter
from pydantic import BaseModel, Extra, TypeAdapter
from annotated_types import MinLen
from typing import Annotated, List, Optional
import json, requests
Expand Down Expand Up @@ -50,12 +50,16 @@ def create_completion(*, response_model=None, endpoint="http://localhost:8080/v1
if __name__ == '__main__':

class QAPair(BaseModel):
class Config:
extra = 'forbid' # triggers additionalProperties: false in the JSON schema
question: str
concise_answer: str
justification: str
stars: Annotated[int, Field(ge=1, le=5)]

class PyramidalSummary(BaseModel):
class Config:
extra = 'forbid' # triggers additionalProperties: false in the JSON schema
title: str
summary: str
question_answers: Annotated[List[QAPair], MinLen(2)]
Expand Down
76 changes: 60 additions & 16 deletions examples/json_schema_to_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import json
import re
import sys
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

from typing import Any, List, Optional, Set, Tuple, Union

def _build_repetition(item_rule, min_items, max_items, separator_rule=None):

Expand Down Expand Up @@ -276,6 +275,51 @@ def recurse(i: int):

return ''.join(('(', *recurse(0), ')'))

def _not_strings(self, strings):
class TrieNode:
def __init__(self):
self.children = {}
self.is_end_of_string = False

def insert(self, string):
node = self
for c in string:
node = node.children.setdefault(c, TrieNode())
node.is_end_of_string = True

trie = TrieNode()
for s in strings:
trie.insert(s)

char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
out = ['["] ( ']

def visit(node):
rejects = []
first = True
for c in sorted(node.children.keys()):
child = node.children[c]
rejects.append(c)
if first:
first = False
else:
out.append(' | ')
out.append(f'[{c}]')
if child.children:
out.append(f' (')
visit(child)
out.append(')')
elif child.is_end_of_string:
out.append(f' {char_rule}+')
if node.children:
if not first:
out.append(' | ')
out.append(f'[^"{"".join(rejects)}] {char_rule}*')
visit(trie)

out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space')
return ''.join(out)

def _add_rule(self, name, rule):
esc_name = INVALID_RULE_CHARS_RE.sub('-', name)
if esc_name not in self._rules or self._rules[esc_name] == rule:
Expand Down Expand Up @@ -524,10 +568,10 @@ def visit(self, schema, name):
return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type]))

elif 'const' in schema:
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']))
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space')

elif 'enum' in schema:
rule = ' | '.join((self._generate_constant_rule(v) for v in schema['enum']))
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space'
return self._add_rule(rule_name, rule)

elif schema_type in (None, 'object') and \
Expand Down Expand Up @@ -632,7 +676,7 @@ def _add_primitive(self, name: str, rule: BuiltinRule):
self._add_primitive(dep, dep_rule)
return n

def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]):
def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Optional[Union[bool, Any]]):
prop_order = self._prop_order
# sort by position in prop_order (if specified) then by original order
sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))]
Expand All @@ -647,12 +691,16 @@ def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[st
required_props = [k for k in sorted_props if k in required]
optional_props = [k for k in sorted_props if k not in required]

if additional_properties == True or isinstance(additional_properties, dict):
if additional_properties != False:
sub_name = f'{name}{"-" if name else ""}additional'
value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value')
value_rule = self.visit(additional_properties, f'{sub_name}-value') if isinstance(additional_properties, dict) else \
self._add_primitive('value', PRIMITIVE_RULES['value'])
key_rule = self._add_primitive('string', PRIMITIVE_RULES['string']) if not sorted_props \
else self._add_rule(f'{sub_name}-k', self._not_strings(sorted_props))

prop_kv_rule_names["*"] = self._add_rule(
f'{sub_name}-kv',
self._add_primitive('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}'
f'{key_rule} ":" space {value_rule}'
)
optional_props.append("*")

Expand All @@ -667,15 +715,11 @@ def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[st
def get_recursive_refs(ks, first_is_optional):
[k, *rest] = ks
kv_rule_name = prop_kv_rule_names[k]
if k == '*':
res = self._add_rule(
f'{name}{"-" if name else ""}additional-kvs',
f'{kv_rule_name} ( "," space ' + kv_rule_name + ' )*'
)
elif first_is_optional:
res = f'( "," space {kv_rule_name} )?'
comma_ref = f'( "," space {kv_rule_name} )'
if first_is_optional:
res = comma_ref + ('*' if k == '*' else '?')
else:
res = kv_rule_name
res = kv_rule_name + (' ' + comma_ref + "*" if k == '*' else '')
if len(rest) > 0:
res += ' ' + self._add_rule(
f'{name}{"-" if name else ""}{k}-rest',
Expand Down
Loading

0 comments on commit 6777c54

Please sign in to comment.