Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding new parser: MultiChoicesParser #133

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion lmformatenforcer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
'TokenEnforcer',
'TokenEnforcerTokenizerData',
'LMFormatEnforcerException',
'FormatEnforcerAnalyzer',]
'FormatEnforcerAnalyzer',
'MultiChoicesParser']

from .characterlevelparser import CharacterLevelParser, CharacterLevelParserConfig, StringParser, UnionParser, SequenceParser
from .multichoicesparser import MultiChoicesParser
from .regexparser import RegexParser
from .jsonschemaparser import JsonSchemaParser
from .tokenenforcer import TokenEnforcer, TokenEnforcerTokenizerData
Expand Down
4 changes: 1 addition & 3 deletions lmformatenforcer/characterlevelparser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
import os
from dataclasses import dataclass, field
from typing import Any, Hashable, List, Optional, TypeVar
from typing import Any, Hashable, List, Optional
from .consts import (COMPLETE_ALPHABET, CONFIG_ENV_VAR_DEFAULT_ALPHABET, WHITESPACE_CHARACTERS, DEFAULT_MAX_CONSECUTIVE_WHITESPACES,
DEFAULT_FORCE_JSON_FIELD_ORDER, CONFIG_ENV_VAR_MAX_CONSECUTIVE_WHITESPACES,
CONFIG_ENV_VAR_STRICT_JSON_FIELD_ORDER, CONFIG_ENV_VAR_MAX_JSON_ARRAY_LENGTH,
Expand Down Expand Up @@ -184,5 +184,3 @@ def cache_key(self) -> Optional[Hashable]:
if all(key is not None for key in all_cache_keys):
return ('sequence', all_cache_keys)
return None


63 changes: 63 additions & 0 deletions lmformatenforcer/multichoicesparser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from __future__ import annotations
from .characterlevelparser import CharacterLevelParser, CharacterLevelParserConfig
from multi_choices_parser import MultiChoicesParser as MCP

class MultiChoicesParser(CharacterLevelParser):
    """Character-level parser for multi-choice grammars.

    NOTE: This parser is based on the multi-choices-parser package:
    https://github.com/HichemAK/multi-choices-parser.

    An efficient incremental parser for multi-choice grammars. They are
    defined as grammars of the form:

        start: list1 list2 ... listn

        list1: choice1_1 | choice1_2 | ... | choice1_k1

        list2: choice2_1 | choice2_2 | ... | choice2_k2

        ...

        listn: choicen_1 | choicen_2 | ... | choicen_kn

    where choicex_y is a string and can possibly be empty.

    Example:
        start: det noun

        det: "the " | "an " | "a " | ""

        noun: "orange" | "apple" | "banana"

    This was particularly optimized for the case where the lists of choices
    are very large (up to the order of millions), which can be helpful to
    represent entities preceded (or not) by a determiner.
    For example, in Wikipedia, there are around 7 million entities.
    """

    def __init__(self, list_of_choices: list[list[str]] | None):
        """Build a parser for the given grammar.

        Args:
            list_of_choices: One list of alternatives per grammar position,
                in order. Passing ``None`` creates an empty shell with no
                underlying parser; ``copy()`` uses this and then fills in
                the ``parser`` and ``config`` attributes itself.
        """
        if list_of_choices is not None:
            self.parser = MCP(list_of_choices)
            config = CharacterLevelParserConfig()
        else:
            self.parser = None
            config = None

        super().__init__(config)

    def add_character(self, new_character: str) -> MultiChoicesParser:
        """Return a new parser advanced by ``new_character``.

        The receiver is left untouched: the step is applied to a copy.
        """
        advanced = self.copy()
        advanced.parser.step(new_character)
        return advanced

    def get_allowed_characters(self) -> str:
        """Return the characters that may come next, concatenated.

        The underlying parser's end symbol is filtered out because it is a
        marker rather than a character; ``can_end()`` reports it instead.
        """
        end_symbol = self.parser.end_symb
        return ''.join(x for x in self.parser.next() if x is not end_symbol)

    def can_end(self) -> bool:
        """Return True if the end symbol is allowed next or the parser is finished."""
        return self.parser.end_symb in self.parser.next() or self.parser.finished

    def cache_key(self):
        """Return a hashable key identifying the current parsing state.

        The key is a tagged tuple rather than a pre-computed ``hash(...)``
        integer: when used as a dict key, tuples are compared by value, so
        two distinct states whose hashes happen to collide are not mistaken
        for the same cache entry.
        """
        return (
            'multichoices',
            id(self.parser.current_state),
            self.parser.finished,
            self.parser.success,
        )

    def copy(self) -> MultiChoicesParser:
        """Return a parser holding a copy of the underlying parser and the same config."""
        clone = MultiChoicesParser(None)
        clone.parser = self.parser.copy()
        clone.config = self.config
        return clone
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@ packages = [
"Bug Tracker" = "https://github.com/noamgat/lm-format-enforcer/issues"

[tool.poetry.dependencies]
python = ">=3.8,<4.0"
python = ">=3.8.1,<4.0"
pydantic = ">=1.10.8"
interegular = ">=0.3.2"
packaging = "*"
pyyaml = "*"
multi-choices-parser = ">=0.10.0"

[tool.poetry.group.dev.dependencies]
mock = "5.1.0"
Expand Down
138 changes: 138 additions & 0 deletions tests/test_multichoicesparser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from typing import List, Optional

from lmformatenforcer import MultiChoicesParser, CharacterLevelParserConfig
from lmformatenforcer.consts import COMPLETE_ALPHABET
from .common import assert_parser_with_string

def _test_mcp_parsing_with_string(string: str, list_of_choices: List[List[str]], expect_success: bool, custom_alphabet: Optional[str] = None, profile_file_path: Optional[str] = None):
    """Check whether ``string`` is accepted by a MultiChoicesParser.

    Args:
        string: The text fed to the parser.
        list_of_choices: The grammar, one list of alternatives per position.
        expect_success: Whether ``string`` is expected to be accepted.
        custom_alphabet: If given (and non-empty), replaces the parser's
            config with one using this alphabet.
        profile_file_path: Forwarded to ``assert_parser_with_string``.
    """
    # NOTE: ``typing.List`` is used instead of ``list[...]`` because this
    # module has no ``from __future__ import annotations`` and builtin
    # generics cannot be subscripted at definition time on Python 3.8.
    parser = MultiChoicesParser(list_of_choices)
    if custom_alphabet:
        parser.config = CharacterLevelParserConfig(alphabet=custom_alphabet)
    assert_parser_with_string(string, parser, expect_success, profile_file_path=profile_file_path)


def test_parsing_exact_string():
    """A string identical to the only choice is accepted."""
    only_choice = 'abc123'
    _test_mcp_parsing_with_string(only_choice, [[only_choice]], True)

def test_parsing_exact_string_failure():
    """A string differing from the only choice in its last character is rejected."""
    grammar = [['abc123']]
    _test_mcp_parsing_with_string('abc124', grammar, False)

def test_parsing_exact_string_not_reaching_end():
    """A strict prefix of the only choice is rejected."""
    grammar = [['abc1234']]
    _test_mcp_parsing_with_string('abc123', grammar, False)


def test_parsing_letter_options():
    """Only letters offered by the middle choice list are accepted between 'ab' and '123'."""
    accepted_letters = 'cdef'
    for candidate in 'cdefghif':
        # The grammar is rebuilt on every iteration so each parser gets fresh lists.
        grammar = [['ab'], list(accepted_letters), ['123']]
        _test_mcp_parsing_with_string(
            f'ab{candidate}123',
            grammar,
            candidate in accepted_letters)


def test_parsing_digits():
    """Only a digit is accepted between 'ab' and '123'."""
    for candidate in '0123abcd':
        # The grammar is rebuilt on every iteration so each parser gets fresh lists.
        grammar = [['ab'], list('0123456789'), ['123']]
        should_pass = candidate.isnumeric()
        _test_mcp_parsing_with_string(
            f'ab{candidate}123',
            grammar,
            should_pass)


def test_parsing_repeat():
    """One mandatory 'c' followed by optional ones: zero repeats is rejected, any more is accepted."""
    for count in range(20):
        # One required 'c', then (count - 1) positions where 'c' may be omitted.
        # For count == 0 the multiplier is negative and contributes nothing.
        repeated_part = [['c']] + [["c", ""]] * (count - 1)
        grammar = [['ab'], *repeated_part, ['123']]
        text = f'ab{"c" * count}123'
        _test_mcp_parsing_with_string(
            string=text,
            list_of_choices=grammar,
            expect_success=count > 0)


def test_any_character():
    """Repeat each character as many times as its index; only the zero-repeat case is rejected."""
    chars = list('0123456789abcdefghij') + ['']
    for count, character in enumerate(chars[:-1]):
        # First position requires a real character; the remaining
        # (count - 1) positions also allow the empty choice.
        mandatory = chars[:-1]
        optional_positions = [chars] * (count - 1)
        grammar = [['ab'], mandatory, *optional_positions, ['123']]
        _test_mcp_parsing_with_string(
            string=f'ab{character * count}123',
            list_of_choices=grammar,
            expect_success=count > 0,
            profile_file_path=None)


def test_dates():
    """Accept two-digit/two-digit/four-digit dates; reject a three-digit first field."""
    first_field = [f'{n:02d}' for n in range(1, 32)]
    second_field = [f'{n:02d}' for n in range(1, 13)]
    third_field = [f'{n:04d}' for n in range(3000)]
    date_lcs = [first_field, ["/"], second_field, ["/"], third_field]

    cases = (
        ('01/01/2020', True),
        ('29/04/1986', True),
        ('001/01/2020', False),
    )
    for text, should_pass in cases:
        _test_mcp_parsing_with_string(text, date_lcs, should_pass)


def test_string_choice():
    """Each listed word is accepted; a mix of their characters is rejected."""
    lcs = [['abc', 'def', 'ghi']]
    expectations = (
        ('abc', True),
        ('def', True),
        ('ghi', True),
        ('aei', False),
    )
    for text, should_pass in expectations:
        _test_mcp_parsing_with_string(text, lcs, should_pass)


def test_increasing_alphabet():
    """'Σ' is rejected under the default alphabet and accepted once it is added to the choices."""
    default_three = [COMPLETE_ALPHABET] * 3
    _test_mcp_parsing_with_string('abc', default_three, True)
    _test_mcp_parsing_with_string('abΣ', default_three, False)

    # Extend both the grammar and the parser config with the extra character.
    extended = COMPLETE_ALPHABET + 'Σ'
    extended_three = [extended] * 3
    _test_mcp_parsing_with_string('abΣ', extended_three, True, custom_alphabet=extended)

def test_phone_number():
    """Accept the '(ddd)ddd-dddd' layout and reject a number written without parentheses."""
    three_digits = [f'{n:03d}' for n in range(1000)]
    three_digits_again = [f'{n:03d}' for n in range(1000)]
    four_digits = [f'{n:04d}' for n in range(10000)]
    phone_lcs = [["("], three_digits, [")"], three_digits_again, ['-'], four_digits]

    _test_mcp_parsing_with_string('(312)011-2444', phone_lcs, True)
    _test_mcp_parsing_with_string('312-011-2444', phone_lcs, False)

# def test_negative_matching():
# # https://github.com/noamgat/lm-format-enforcer/issues/70
# pattern = r'- Keywords: [^;:,/\n\r]+; [^;:,/\n\r]+; [^;:,/\n\r]+; [^;:,/\n\r]+; [^;:,/\n\r]+'
# text = '- Keywords: intranasal vaccine, long-lasting immunity, mucosal antibody response, T cells, adjuvants'
# _test_regex_parsing_with_string(text, pattern, False)
# correct_text = text.replace(',', ';')
# _test_regex_parsing_with_string(correct_text, pattern, True)