diff --git a/lmformatenforcer/consts.py b/lmformatenforcer/consts.py index 300f27a..42c8824 100644 --- a/lmformatenforcer/consts.py +++ b/lmformatenforcer/consts.py @@ -1,3 +1,6 @@ COMPLETE_ALPHABET = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!@#$%^&*()_+-=[]{};:,./<>? `'\"" MAX_CONSECUTIVE_WHITESPACES = 12 -WHITESPACE_CHARACTERS = " \t\n\r" \ No newline at end of file +WHITESPACE_CHARACTERS = " \t\n\r" +BACKSLASH = "\\" +BACKSLASH_ESCAPING_CHARACTERS = '"\\/bfnrt' # Characters allowed after an escaping backslash, except unicode +BACKSLACH_UNICODE_ESCAPE = "u" diff --git a/lmformatenforcer/jsonschemaparser.py b/lmformatenforcer/jsonschemaparser.py index 1e4c3ba..4fbb81f 100644 --- a/lmformatenforcer/jsonschemaparser.py +++ b/lmformatenforcer/jsonschemaparser.py @@ -5,8 +5,8 @@ from .external.jsonschemaobject import JsonSchemaObject from .exceptions import LMFormatEnforcerException -from .characterlevelparser import CharacterLevelParser, ForceStopParser, UnionParser -from .consts import COMPLETE_ALPHABET, MAX_CONSECUTIVE_WHITESPACES, WHITESPACE_CHARACTERS +from .characterlevelparser import CharacterLevelParser, ForceStopParser, SequenceParser, StringParser, UnionParser +from .consts import BACKSLASH, BACKSLASH_ESCAPING_CHARACTERS, COMPLETE_ALPHABET, MAX_CONSECUTIVE_WHITESPACES, WHITESPACE_CHARACTERS class JsonSchemaParser(CharacterLevelParser): @@ -418,6 +418,14 @@ def add_character(self, new_character: str): else: self.seen_closing_quote = True self.parsed_string = self.parsed_string[:-1] + if new_character == BACKSLASH: + # After a backslash we immediately have the escaping character, and if it's 'u', we have 4 hex digits + escaping_character_parsers: List[CharacterLevelParser] = [StringParser(c) for c in BACKSLASH_ESCAPING_CHARACTERS] + hex_digit_parser: CharacterLevelParser = UnionParser([StringParser(c) for c in "0123456789abcdefABCDEF"]) + unicode_components: List[CharacterLevelParser] = list([StringParser("u")] + [hex_digit_parser] * 4) + 
unicode_escape_parser: CharacterLevelParser = SequenceParser(unicode_components) + json_escaping_parser = UnionParser(escaping_character_parsers + [unicode_escape_parser]) + self.root.context.active_parser.object_stack.append(json_escaping_parser) return self def get_allowed_characters(self) -> str: @@ -439,7 +447,7 @@ def get_allowed_characters(self) -> str: allowed_next_characters.extend(WHITESPACE_CHARACTERS) return "".join(allowed_next_characters) else: - return COMPLETE_ALPHABET + return COMPLETE_ALPHABET + BACKSLASH def can_end(self) -> bool: if self.require_closing_quote: diff --git a/tests/test_jsonschemaparser.py b/tests/test_jsonschemaparser.py index 62923e4..86973f0 100644 --- a/tests/test_jsonschemaparser.py +++ b/tests/test_jsonschemaparser.py @@ -4,6 +4,7 @@ from lmformatenforcer import JsonSchemaParser from enum import Enum import pytest +from lmformatenforcer.consts import BACKSLASH, BACKSLASH_ESCAPING_CHARACTERS from lmformatenforcer.exceptions import LMFormatEnforcerException @@ -188,3 +189,25 @@ class ListOfNoMinLengthModel(BaseModel): _test_json_schema_parsing_with_string(no_strings, ListOfNoMinLengthModel.schema(), True) _test_json_schema_parsing_with_string(one_string, ListOfNoMinLengthModel.schema(), True) _test_json_schema_parsing_with_string(two_strings, ListOfNoMinLengthModel.schema(), False) + + +def test_string_escaping(): + for escaping_character in BACKSLASH_ESCAPING_CHARACTERS: + test_string = f'{{"num":1,"message":"hello {BACKSLASH}{escaping_character} world"}}' + _test_json_schema_parsing_with_string(test_string, SampleModel.schema(), True) + for non_escaping_character in 'a1?': + test_string = f'{{"num":1,"message":"hello {BACKSLASH}{non_escaping_character} world"}}' + _test_json_schema_parsing_with_string(test_string, SampleModel.schema(), False) + + # Unicode + test_string = f'{{"num":1,"message":"hello {BACKSLASH}uf9f0 world"}}' + _test_json_schema_parsing_with_string(test_string, SampleModel.schema(), True) + + # Not enough 
unicode digits + test_string = f'{{"num":1,"message":"hello {BACKSLASH}uf9f world"}}' + _test_json_schema_parsing_with_string(test_string, SampleModel.schema(), False) + + # Unicode digit outside of hex range + test_string = f'{{"num":1,"message":"hello {BACKSLASH}uf9fP world"}}' + _test_json_schema_parsing_with_string(test_string, SampleModel.schema(), False) + \ No newline at end of file