Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes in JsonSchema oneOf parsing, better multilingual testing #144

Merged
merged 3 commits into from
Sep 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions lmformatenforcer/characterlevelparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,11 @@ def get_allowed_characters(self) -> str:
def can_end(self) -> bool:
    """Return True if at least one of the wrapped parsers may legally end here."""
    return any(parser.can_end() for parser in self.parsers)

def shortcut_key(self) -> Optional[Hashable]:
    """Return the single shortcut key shared by every sub-parser.

    If all wrapped parsers report the same shortcut key, that key is the
    union's shortcut key as well; any disagreement (or an empty parser
    list) yields None, disabling the shortcut.
    """
    distinct_keys = {parser.shortcut_key() for parser in self.parsers}
    if len(distinct_keys) != 1:
        return None
    (common_key,) = distinct_keys
    return common_key

def cache_key(self) -> Optional[Hashable]:
all_cache_keys = tuple(parser.cache_key() for parser in self.parsers)
Expand Down
15 changes: 8 additions & 7 deletions lmformatenforcer/jsonschemaparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ def __init__(self, root: JsonSchemaParser):


def _merge_object_schemas(base_schema: JsonSchemaObject, option_schema: JsonSchemaObject) -> JsonSchemaObject:
for property_name, property_value in base_schema.properties.items():
base_schema_properties = base_schema.properties or {}
for property_name, property_value in base_schema_properties.items():
# We assume that if a property exists in both base and option, the option version will be
# more specific, therefore we only take missing entries
if property_name not in option_schema.properties:
Expand Down Expand Up @@ -201,13 +202,13 @@ def get_parser(
max_length=value_schema.maxLength,
pattern=value_schema.pattern,
)
if value_schema.oneOf:
# We create a combined object schema for each option that includes the information from the parent
# And then create a UnionParser based on the combined options
merged_schemas = [_merge_object_schemas(value_schema, option_schema) for option_schema in value_schema.oneOf]
object_parsing_options = [ObjectParsingState(merged_schema, parsing_state) for merged_schema in merged_schemas]
return UnionParser(object_parsing_options)
elif value_schema.type == "object":
if value_schema.oneOf:
# We create a combined object schema for each option that includes the information from the parent
# And then create a UnionParser based on the combined options
merged_schemas = [_merge_object_schemas(value_schema, option_schema) for option_schema in value_schema.oneOf]
object_parsing_options = [ObjectParsingState(merged_schema, parsing_state) for merged_schema in merged_schemas]
return UnionParser(object_parsing_options)
return ObjectParsingState(value_schema, parsing_state)
elif value_schema.type == None and value_schema.ref:
value_class_name = value_schema.ref.split('/')[-1]
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@ types-setuptools = "68.1.0.1"
[tool.poetry.group.tests.dependencies]
pytest = {version = "6.2.5", python = ">=3.8"}
coverage = {version = "^7.3.1", python = ">=3.8", extras = ["toml"]}
transformers = ">=4.28.1"
transformers = ">=4.37.0"
torch = {version = "^2.1.0+cpu", source = "pytorch"}
numpy = "^1.21.0"

[tool.poetry.group.samples.dependencies]
Flask = {version = "2.3.2", python = ">=3.8"}
transformers = ">=4.28.1"
transformers = ">=4.37.0"
tokenizers = ">=0.13.3"


Expand Down
17 changes: 13 additions & 4 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from pstats import Stats
from typing import Optional
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from lmformatenforcer import CharacterLevelParser
from lmformatenforcer.exceptions import LMFormatEnforcerException
from lmformatenforcer.tokenenforcer import TokenEnforcer, TokenEnforcerTokenizerData
from lmformatenforcer.integrations.transformers import build_token_enforcer_tokenizer_data

import logging


_tokenizer: Optional[PreTrainedTokenizerBase] = None
_tokenizer_data: Optional[TokenEnforcerTokenizerData] = None
Expand Down Expand Up @@ -40,10 +40,19 @@ def assert_parser_with_string_direct(string: str, parser: CharacterLevelParser,
def assert_parser_with_string_token_enforcer(string: str, parser: CharacterLevelParser, expect_success: bool, profile_file_path: Optional[str]):
global _tokenizer
if _tokenizer is None:
model_id = 'TheBloke/Llama-2-7b-Chat-GPTQ'
model_id = 'Qwen/Qwen2.5-72B-Instruct'
_tokenizer = AutoTokenizer.from_pretrained(model_id)

global _tokenizer_data

# For testing, we make sure that all letters exist individually in the tokenizer
encoded_0 = _tokenizer.encode("0")
for word in set(string):
encoded_word = _tokenizer.encode(word)
if len(encoded_word) > len(encoded_0):
logging.basicConfig(level=logging.INFO)
logging.warning("Encountered out-of-tokenizer character, LMFE does not deal with this well")

if _tokenizer_data is None:
_tokenizer_data = build_token_enforcer_tokenizer_data(_tokenizer)

Expand Down
45 changes: 34 additions & 11 deletions tests/test_jsonschemaparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,37 @@
from lmformatenforcer import JsonSchemaParser
from enum import Enum
import pytest
from lmformatenforcer.consts import BACKSLASH, BACKSLASH_ESCAPING_CHARACTERS, CONFIG_ENV_VAR_STRICT_JSON_FIELD_ORDER, CONFIG_ENV_VAR_MAX_CONSECUTIVE_WHITESPACES, CONFIG_ENV_VAR_MAX_JSON_ARRAY_LENGTH
from lmformatenforcer.characterlevelparser import CharacterLevelParserConfig
from lmformatenforcer.consts import BACKSLASH, BACKSLASH_ESCAPING_CHARACTERS, COMPLETE_ALPHABET, CONFIG_ENV_VAR_STRICT_JSON_FIELD_ORDER, CONFIG_ENV_VAR_MAX_CONSECUTIVE_WHITESPACES, CONFIG_ENV_VAR_MAX_JSON_ARRAY_LENGTH

from .common import assert_parser_with_string, CharacterNotAllowedException


def _test_json_schema_parsing_with_string(string: str,
                                          schema_dict: Optional[dict],
                                          expect_success: bool,
                                          profile_file_path: Optional[str] = None,
                                          ensure_ascii_in_json_dumps: bool = False):
    """Drive a JsonSchemaParser over `string` and assert the expected outcome.

    The parser's alphabet is extended with every character appearing in the
    test string (and in its minified JSON form) so that multilingual /
    non-ASCII payloads are not rejected merely for containing characters
    outside COMPLETE_ALPHABET. When success is expected, the minified and
    pretty-printed renderings of the same JSON are verified as well.

    :param string: The raw JSON text to feed through the parser.
    :param schema_dict: JSON schema to enforce, or None for schemaless JSON.
    :param expect_success: Whether parsing is expected to complete legally.
    :param profile_file_path: Optional path to dump profiling stats into.
    :param ensure_ascii_in_json_dumps: Passed to json.dumps for the
        minified / pretty-printed re-renderings.
    """
    alphabet = COMPLETE_ALPHABET
    # Make sure every character of the test payload is parseable.
    for letter in set(string):
        if letter not in alphabet and letter != '\n':
            alphabet += letter
    if expect_success:
        try:
            # The minified form may contain characters (e.g. unicode escapes
            # resolved by json.loads) not present in the raw string.
            minified = json.dumps(json.loads(string), separators=(',', ':'), ensure_ascii=False)
            for letter in set(minified):
                if letter not in alphabet and letter != '\n':
                    alphabet += letter
        except ValueError:
            # json.JSONDecodeError (a ValueError): the payload is not valid
            # JSON despite expect_success — best effort, keep the base alphabet.
            pass
    config = CharacterLevelParserConfig(alphabet=alphabet)
    parser = JsonSchemaParser(schema_dict, config=config)
    assert_parser_with_string(string, parser, expect_success, profile_file_path)
    if expect_success:
        # If expecting success, also check minified and pretty-printed
        minified = json.dumps(json.loads(string), separators=(',', ':'), ensure_ascii=ensure_ascii_in_json_dumps)
        assert_parser_with_string(minified, parser, expect_success)
        pretty_printed = json.dumps(json.loads(string), indent=2, ensure_ascii=ensure_ascii_in_json_dumps)
        assert_parser_with_string(pretty_printed, parser, expect_success)


Expand Down Expand Up @@ -190,22 +208,22 @@ class ListOfNoMinLengthModel(BaseModel):
def test_string_escaping():
    """Backslash escapes inside JSON strings: legal escape characters and
    \\uXXXX sequences must parse; illegal escapes and malformed unicode
    escapes must be rejected."""
    def _check(escape_body: str, should_succeed: bool) -> None:
        # Embed the escape sequence in a message field of the sample schema.
        payload = f'{{"num":1,"message":"hello {escape_body} world"}}'
        _test_json_schema_parsing_with_string(
            payload, SampleModel.model_json_schema(), should_succeed,
            ensure_ascii_in_json_dumps=True)

    for escaping_character in BACKSLASH_ESCAPING_CHARACTERS:
        _check(f'{BACKSLASH}{escaping_character}', True)
    for non_escaping_character in 'a1?':
        _check(f'{BACKSLASH}{non_escaping_character}', False)

    # Unicode
    _check(f'{BACKSLASH}uf9f0', True)
    # Not enough unicode digits
    _check(f'{BACKSLASH}uf9f', False)
    # Unicode digit outside of hex range
    _check(f'{BACKSLASH}uf9fP', False)


def test_comma_after_all_object_keys_fails():
Expand Down Expand Up @@ -774,4 +792,9 @@ def test_invalid_number_formats_with_leading_zeros(test_input):
('{"value": -9007199254740992}', True),
])
def test_number_edge_cases(test_input, expected_success):
_test_json_schema_parsing_with_string(test_input, schema, expected_success)
_test_json_schema_parsing_with_string(test_input, schema, expected_success)

def test_chinese_oneof_schema():
    # Regression test for oneOf parsing combined with non-ASCII (Chinese) enum
    # values: an array of event objects where each item must match one of two
    # object options, discriminated by the "event_type" enum.
    test_schema = { "$schema": "http://json-schema.org/draft-07/schema#", "type": "array", "items": { "oneOf": [ { "type": "object", "properties": { "trigger": { "type": "string" }, "event_type": { "enum": [ "公司上市" ] }, "arguments": { "type": "array", "items": { "type": "object", "properties": { "role": { "enum": [ "上市公司", "证券代码", "环节", "披露时间", "发行价格", "事件时间", "市值", "募资金额" ] }, "argument": { "type": "string" } }, "required": [ "role", "argument" ] } } }, "required": [ "trigger", "event_type", "arguments" ] }, { "type": "object", "properties": { "trigger": { "type": "string" }, "event_type": { "enum": [ "被约谈" ] }, "arguments": { "type": "array", "items": { "type": "object", "properties": { "role": { "enum": [ "公司名称", "披露时间", "被约谈时间", "约谈机构" ] }, "argument": { "type": "string" } }, "required": [ "role", "argument" ] } } }, "required": [ "trigger", "event_type", "arguments" ] } ] } }
    # A valid instance matching the first oneOf option; must be accepted.
    correct_output = """[{"trigger": "IPO", "event_type": "公司上市", "arguments": [{"role": "上市公司", "argument": "理想汽车"}, {"role": "披露时间", "argument": "30日"}, {"role": "发行价格", "argument": "8-10美元"}, {"role": "环节", "argument": "筹备上市"}]}]"""
    _test_json_schema_parsing_with_string(correct_output, test_schema, True)
Loading