
Commit

Added SequenceParser and UnionParser to allow chaining of several parsers via certain rules
noamgat committed Nov 4, 2023
1 parent fcc349b commit 7a67f5a
Showing 4 changed files with 111 additions and 4 deletions.
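In short, SequenceParser accepts the concatenation of what its sub-parsers accept, while UnionParser accepts anything that any one of its sub-parsers accepts. A tiny illustration (not part of the diff itself, mirroring the new tests below):

from lmformatenforcer import SequenceParser, StringParser, UnionParser

choice = UnionParser([StringParser('aa'), StringParser('bb')])    # accepts 'aa' or 'bb'
chain = SequenceParser([StringParser('aa'), StringParser('bb')])  # accepts only 'aabb'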
4 changes: 3 additions & 1 deletion lmformatenforcer/__init__.py
@@ -1,12 +1,14 @@
__all__ = ['CharacterLevelParser',
'StringParser',
'RegexParser',
'UnionParser',
'SequenceParser',
'JsonSchemaParser',
'TokenEnforcer',
'LMFormatEnforcerException',
'FormatEnforcerAnalyzer',]

-from .characterlevelparser import CharacterLevelParser, StringParser
+from .characterlevelparser import CharacterLevelParser, StringParser, UnionParser, SequenceParser
from .regexparser import RegexParser
from .jsonschemaparser import JsonSchemaParser
from .tokenenforcer import TokenEnforcer
78 changes: 76 additions & 2 deletions lmformatenforcer/characterlevelparser.py
@@ -1,5 +1,5 @@
import abc
-from typing import Hashable, Optional
+from typing import Hashable, List, Optional


class CharacterLevelParser(abc.ABC):
@@ -11,7 +11,7 @@ def add_character(self, new_character: str) -> 'CharacterLevelParser':
raise NotImplementedError()

@abc.abstractmethod
-def get_allowed_characters(self) ->str:
+def get_allowed_characters(self) -> str:
"""Return a string containing all characters that are allowed at the current point in the parsing process."""
raise NotImplementedError()

@@ -56,3 +56,77 @@ def get_allowed_characters(self) -> str:
return ""
def can_end(self) -> bool:
return True


class UnionParser(CharacterLevelParser):
"""A parser that allows a string that would be allowed by any of several different parsers"""
def __init__(self, parsers: List[CharacterLevelParser]):
self.parsers = parsers

def add_character(self, new_character: str) -> CharacterLevelParser:
# This is a bit of a performance hit, as it means get_allowed_characters() is called twice.
relevant_parsers = [parser for parser in self.parsers if new_character in parser.get_allowed_characters()]
next_parsers = [parser.add_character(new_character) for parser in relevant_parsers]
if len(next_parsers) == 1:
return next_parsers[0]
return UnionParser(next_parsers)

def get_allowed_characters(self) -> str:
allowed = "".join([parser.get_allowed_characters() for parser in self.parsers])
return "".join(set(allowed))

def can_end(self) -> bool:
return any([parser.can_end() for parser in self.parsers])

def shortcut_key(self) -> Optional[str]:
return self.parsers[0].shortcut_key() if len(self.parsers) == 1 else None

def cache_key(self) -> Optional[Hashable]:
all_cache_keys = tuple(parser.cache_key() for parser in self.parsers)
if all(key is not None for key in all_cache_keys):
return ('union', all_cache_keys)
return None


class SequenceParser(CharacterLevelParser):
"""A parser that is a sequence of multiple parsers."""
def __init__(self, parsers: List[CharacterLevelParser]):
self.parsers = parsers

def add_character(self, new_character: str) -> CharacterLevelParser:
legal_parsers = []
# Tricky edge case: if the first parser can both end and accept the character,
# and the second parser can also accept, we don't know which scenario we are dealing
# with, so we need to return a UnionParser.
for idx, parser in enumerate(self.parsers):
if new_character in parser.get_allowed_characters():
updated_parser = parser.add_character(new_character)
next_parsers = [updated_parser] + self.parsers[idx+1:]
legal_parsers.append(SequenceParser(next_parsers))
if not parser.can_end():
break
if len(legal_parsers) == 1:
return legal_parsers[0]
return UnionParser(legal_parsers)

def get_allowed_characters(self) -> str:
allowed_character_strs = []
for parser in self.parsers:
allowed_character_strs.append(parser.get_allowed_characters())
if not parser.can_end():
break
return "".join([parser.get_allowed_characters() for parser in self.parsers])

def can_end(self) -> bool:
return all([parser.can_end() for parser in self.parsers])

def shortcut_key(self) -> Optional[str]:
return self.parsers[0].shortcut_key() if len(self.parsers) == 1 else None

def cache_key(self) -> Optional[Hashable]:
all_cache_keys = tuple(parser.cache_key() for parser in self.parsers)
if all(key is not None for key in all_cache_keys):
return ('sequence', all_cache_keys)
return None


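For reference, any CharacterLevelParser — including the new composites above — can be driven one character at a time through the add_character / get_allowed_characters / can_end interface. The helper below is only a sketch and is not part of this commit (it is likely similar in spirit to the assert_parser_with_string test helper used below):

from lmformatenforcer import CharacterLevelParser, SequenceParser, StringParser

def feed_string(parser: CharacterLevelParser, text: str) -> bool:
    # Advance the parser one character at a time; reject the string as soon as
    # a character is not currently allowed, otherwise report whether the parser
    # is in a state where it may legally end.
    for char in text:
        if char not in parser.get_allowed_characters():
            return False
        parser = parser.add_character(char)
    return parser.can_end()

print(feed_string(SequenceParser([StringParser('aa'), StringParser('bb')]), 'aabb'))  # True
print(feed_string(SequenceParser([StringParser('aa'), StringParser('bb')]), 'aab'))   # False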
31 changes: 31 additions & 0 deletions tests/test_composite_parsers.py
@@ -0,0 +1,31 @@
from lmformatenforcer import UnionParser, SequenceParser, StringParser
from .common import assert_parser_with_string
from pydantic import BaseModel
from lmformatenforcer import JsonSchemaParser


def test_string_choice():
parser = UnionParser([StringParser('aa'), StringParser('bb')])
assert_parser_with_string('aa', parser, True)
assert_parser_with_string('bb', parser, True)
assert_parser_with_string('ab', parser, False)
assert_parser_with_string('aabb', parser, False)


def test_string_sequence():
parser = SequenceParser([StringParser('aa'), StringParser('bb')])
assert_parser_with_string('aa', parser, False)
assert_parser_with_string('bb', parser, False)
assert_parser_with_string('ab', parser, False)
assert_parser_with_string('aabb', parser, True)
assert_parser_with_string('bbaa', parser, False)


def test_json_markdown_sequence():
class TestModel(BaseModel):
a: str
json_parser = JsonSchemaParser(TestModel.schema())
parser = SequenceParser([StringParser("```json\n"), json_parser, StringParser('\n```')])
assert_parser_with_string('```json\n{"a": "b"}\n```', parser, True)
assert_parser_with_string('{"a": "b"}', parser, False)

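As a further illustration (not part of this commit), the two composites nest: a UnionParser over a bare JsonSchemaParser and a fenced SequenceParser accepts the answer either as raw JSON or wrapped in a markdown code block. The model name AnswerModel is hypothetical:

from pydantic import BaseModel
from lmformatenforcer import JsonSchemaParser, SequenceParser, StringParser, UnionParser

class AnswerModel(BaseModel):
    a: str

# Either a bare JSON object, or the same object inside a ```json fence.
fenced = SequenceParser([StringParser("```json\n"),
                         JsonSchemaParser(AnswerModel.schema()),
                         StringParser("\n```")])
parser = UnionParser([JsonSchemaParser(AnswerModel.schema()), fenced])
# '{"a": "b"}' and '```json\n{"a": "b"}\n```' would both be accepted.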
2 changes: 1 addition & 1 deletion tests/test_jsonschemaparser.py
@@ -1,6 +1,6 @@
import json
from typing import Dict, List, Optional
-from pydantic import BaseModel, Field, conlist
+from pydantic import BaseModel, Field
from lmformatenforcer import JsonSchemaParser
from enum import Enum
import pytest

