diff --git a/lmformatenforcer/__init__.py b/lmformatenforcer/__init__.py index 85510e7..5d3bfce 100644 --- a/lmformatenforcer/__init__.py +++ b/lmformatenforcer/__init__.py @@ -1,12 +1,14 @@ __all__ = ['CharacterLevelParser', 'StringParser', 'RegexParser', + 'UnionParser', + 'SequenceParser', 'JsonSchemaParser', 'TokenEnforcer', 'LMFormatEnforcerException', 'FormatEnforcerAnalyzer',] -from .characterlevelparser import CharacterLevelParser, StringParser +from .characterlevelparser import CharacterLevelParser, StringParser, UnionParser, SequenceParser from .regexparser import RegexParser from .jsonschemaparser import JsonSchemaParser from .tokenenforcer import TokenEnforcer diff --git a/lmformatenforcer/characterlevelparser.py b/lmformatenforcer/characterlevelparser.py index ccc2d53..939d8be 100644 --- a/lmformatenforcer/characterlevelparser.py +++ b/lmformatenforcer/characterlevelparser.py @@ -1,5 +1,5 @@ import abc -from typing import Hashable, Optional +from typing import Hashable, List, Optional class CharacterLevelParser(abc.ABC): @@ -11,7 +11,7 @@ def add_character(self, new_character: str) -> 'CharacterLevelParser': raise NotImplementedError() @abc.abstractmethod - def get_allowed_characters(self) ->str: + def get_allowed_characters(self) -> str: """Return a string containing all characters that are allowed at the current point in the parsing process.""" raise NotImplementedError() @@ -56,3 +56,77 @@ def get_allowed_characters(self) -> str: return "" def can_end(self) -> bool: return True + + +class UnionParser(CharacterLevelParser): + """A parser that allows a string that would be allowed by any of several different parsers""" + def __init__(self, parsers: List[CharacterLevelParser]): + self.parsers = parsers + + def add_character(self, new_character: str) -> CharacterLevelParser: + # This is a bit of a performance hit, as it means get_allowed_characters() is called twice. 
+ relevant_parsers = [parser for parser in self.parsers if new_character in parser.get_allowed_characters()] + next_parsers = [parser.add_character(new_character) for parser in relevant_parsers] + if len(next_parsers) == 1: + return next_parsers[0] + return UnionParser(next_parsers) + + def get_allowed_characters(self) -> str: + allowed = "".join([parser.get_allowed_characters() for parser in self.parsers]) + return "".join(set(allowed)) + + def can_end(self) -> bool: + return any([parser.can_end() for parser in self.parsers]) + + def shortcut_key(self) -> Optional[str]: + return self.parsers[0].shortcut_key() if len(self.parsers) == 1 else None + + def cache_key(self) -> Optional[Hashable]: + all_cache_keys = tuple(parser.cache_key() for parser in self.parsers) + if all(key is not None for key in all_cache_keys): + return ('union', all_cache_keys) + return None + + + class SequenceParser(CharacterLevelParser): + """A parser that is a sequence of multiple parsers.""" + def __init__(self, parsers: List[CharacterLevelParser]): + self.parsers = parsers + + def add_character(self, new_character: str) -> CharacterLevelParser: + legal_parsers = [] + # Tricky edge case: if the first parser can both end and accept the character, + # and the second parser can also accept, we don't know which scenario we are dealing + # with, so we need to return a UnionParser. 
+ for idx, parser in enumerate(self.parsers): + if new_character in parser.get_allowed_characters(): + updated_parser = parser.add_character(new_character) + next_parsers = [updated_parser] + self.parsers[idx+1:] + legal_parsers.append(SequenceParser(next_parsers)) + if not parser.can_end(): + break + if len(legal_parsers) == 1: + return legal_parsers[0] + return UnionParser(legal_parsers) + + def get_allowed_characters(self) -> str: + allowed_character_strs = [] + for parser in self.parsers: + allowed_character_strs.append(parser.get_allowed_characters()) + if not parser.can_end(): + break + return "".join(allowed_character_strs) + + def can_end(self) -> bool: + return all([parser.can_end() for parser in self.parsers]) + + def shortcut_key(self) -> Optional[str]: + return self.parsers[0].shortcut_key() if len(self.parsers) == 1 else None + + def cache_key(self) -> Optional[Hashable]: + all_cache_keys = tuple(parser.cache_key() for parser in self.parsers) + if all(key is not None for key in all_cache_keys): + return ('sequence', all_cache_keys) + return None + + diff --git a/tests/test_composite_parsers.py b/tests/test_composite_parsers.py new file mode 100644 index 0000000..2bfc9f1 --- /dev/null +++ b/tests/test_composite_parsers.py @@ -0,0 +1,31 @@ +from lmformatenforcer import UnionParser, SequenceParser, StringParser +from .common import assert_parser_with_string +from pydantic import BaseModel +from lmformatenforcer import JsonSchemaParser + + +def test_string_choice(): + parser = UnionParser([StringParser('aa'), StringParser('bb')]) + assert_parser_with_string('aa', parser, True) + assert_parser_with_string('bb', parser, True) + assert_parser_with_string('ab', parser, False) + assert_parser_with_string('aabb', parser, False) + + +def test_string_sequence(): + parser = SequenceParser([StringParser('aa'), StringParser('bb')]) + assert_parser_with_string('aa', parser, False) + assert_parser_with_string('bb', parser, False) + 
assert_parser_with_string('ab', parser, False) + assert_parser_with_string('aabb', parser, True) + assert_parser_with_string('bbaa', parser, False) + + +def test_json_markdown_sequence(): + class TestModel(BaseModel): + a: str + json_parser = JsonSchemaParser(TestModel.schema()) + parser = SequenceParser([StringParser("```json\n"), json_parser, StringParser('\n```')]) + assert_parser_with_string('```json\n{"a": "b"}\n```', parser, True) + assert_parser_with_string('{"a": "b"}', parser, False) + diff --git a/tests/test_jsonschemaparser.py b/tests/test_jsonschemaparser.py index 160b513..f1f824a 100644 --- a/tests/test_jsonschemaparser.py +++ b/tests/test_jsonschemaparser.py @@ -1,6 +1,6 @@ import json from typing import Dict, List, Optional -from pydantic import BaseModel, Field, conlist +from pydantic import BaseModel, Field from lmformatenforcer import JsonSchemaParser from enum import Enum import pytest