diff --git a/lmformatenforcer/jsonschemaparser.py b/lmformatenforcer/jsonschemaparser.py index 4fbb81f..9c19f7c 100644 --- a/lmformatenforcer/jsonschemaparser.py +++ b/lmformatenforcer/jsonschemaparser.py @@ -274,7 +274,7 @@ def get_allowed_characters(self) -> str: ) required_keys = self.schema_object.required or [] can_end = set(self.existing_keys).issuperset(required_keys) - can_parse_key = self.is_dictionary or set(possible_keys).issuperset( + can_parse_key = self.is_dictionary or set(possible_keys).difference( self.existing_keys ) diff --git a/tests/common.py b/tests/common.py index c2d8dcd..a7d2d71 100644 --- a/tests/common.py +++ b/tests/common.py @@ -2,6 +2,10 @@ from lmformatenforcer.exceptions import LMFormatEnforcerException +class CharacterNotAllowedException(LMFormatEnforcerException): + pass + + def assert_parser_with_string(string: str, parser: CharacterLevelParser, expect_success: bool): for idx, character in enumerate(string): try: @@ -9,7 +13,7 @@ def assert_parser_with_string(string: str, parser: CharacterLevelParser, expect_ parser = parser.add_character(character) else: if expect_success: - raise ValueError(f"Parser does not allow '{character}' at index {idx}") + raise CharacterNotAllowedException(f"Parser does not allow '{character}' at index {idx}") else: return # Success except LMFormatEnforcerException: diff --git a/tests/test_jsonschemaparser.py b/tests/test_jsonschemaparser.py index 86973f0..11d0f44 100644 --- a/tests/test_jsonschemaparser.py +++ b/tests/test_jsonschemaparser.py @@ -8,7 +8,7 @@ from lmformatenforcer.exceptions import LMFormatEnforcerException -from .common import assert_parser_with_string +from .common import assert_parser_with_string, CharacterNotAllowedException def _test_json_schema_parsing_with_string(string: str, schema_dict: dict, expect_success: bool): @@ -210,4 +210,15 @@ def test_string_escaping(): # Unicode digit outside of hex range test_string = f'{{"num":1,"message":"hello {BACKSLASH}uf9fP world"}}' _test_json_schema_parsing_with_string(test_string, SampleModel.schema(), False) + + +def test_comma_after_all_object_keys_fails(): + class SomeSchema(BaseModel): + key: str + + test_string = '{"key": "val",' + with pytest.raises(CharacterNotAllowedException): + _test_json_schema_parsing_with_string(test_string, SomeSchema.schema(), True) + + \ No newline at end of file