diff --git a/lmformatenforcer/tokenizerprefixtree.py b/lmformatenforcer/tokenizerprefixtree.py index fac9a70..a56b30e 100644 --- a/lmformatenforcer/tokenizerprefixtree.py +++ b/lmformatenforcer/tokenizerprefixtree.py @@ -1,5 +1,5 @@ from typing import Dict, List, Tuple - +import json class TokenizerPrefixTreeNode: def __init__(self): @@ -19,6 +19,13 @@ def __init__(self, regular_tokens: List[Tuple[int, str]]): has_newline = "\n" in decoded or "\r" in decoded if not (has_quote_before_end or has_newline): + if '\\' in decoded[:-1]: + # If there is a backslash that is not trailing, we might be in an illegal json territory. Need to verify + # that it is a legal json character sequence + try: + json.loads(f'"{decoded}"') + except json.decoder.JSONDecodeError: + continue self.json_freetext_tokens.append(token_idx) def _add_token_to_tree(self, token_str: str, token_idx: int, node: TokenizerPrefixTreeNode): diff --git a/tests/test_jsonschemaparser.py b/tests/test_jsonschemaparser.py index f351d69..e606444 100644 --- a/tests/test_jsonschemaparser.py +++ b/tests/test_jsonschemaparser.py @@ -213,7 +213,16 @@ class SomeSchema(BaseModel): test_string = '{"key": "val",' with pytest.raises(CharacterNotAllowedException): _test_json_schema_parsing_with_string(test_string, SomeSchema.schema(), True) - + + +def test_single_quote_must_not_be_escaped(): + class SomeSchema(BaseModel): + key: str + + test_string = '{"key": "I\\\'m a string"}' + with pytest.raises(CharacterNotAllowedException): + _test_json_schema_parsing_with_string(test_string, SomeSchema.schema(), True) + def test_string_length_limitation(): class SomeSchema(BaseModel):