Skip to content

Commit

Permalink
Added support for string length limitation
Browse files Browse the repository at this point in the history
  • Loading branch information
noamgat committed Nov 15, 2023
1 parent fc99de1 commit 8490345
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
19 changes: 18 additions & 1 deletion lmformatenforcer/jsonschemaparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class _Context:
# We store the active parser in the context, so that if a node adds to the stack, it knows
# to which parser's stack to add.
active_parser: "JsonSchemaParser"
alphabet_without_quotes: str

object_stack: List[CharacterLevelParser]
context: _Context
Expand All @@ -34,6 +35,7 @@ def __init__(self,
self.context = JsonSchemaParser._Context()
self.context.model_class = JsonSchemaObject(**json_schema)
self.context.active_parser = self
self.context.alphabet_without_quotes = self.config.alphabet.replace('"', '')

self.num_consecutive_whitespaces = num_consecutive_whitespaces
if existing_stack is None:
Expand Down Expand Up @@ -93,7 +95,8 @@ def shortcut_key(self) -> Optional[str]:
if self.object_stack:
current_parser = self.object_stack[-1]
if isinstance(current_parser, StringParsingState):
if not current_parser.allowed_strings and current_parser.seen_opening_quote and not current_parser.seen_closing_quote:
if not current_parser.allowed_strings and current_parser.seen_opening_quote and not current_parser.seen_closing_quote \
and current_parser.min_length is None and current_parser.max_length is None:
# Performance optimization: When we are parsing a string that is not from a list of allowed strings, most tokens
# are legal. The exploration can be more costly than the LM itself for large tokenizers (because this is pure python),
# so we signal that we are in a "freetext" mode, and reuse the allowed token list throughout the run.
Expand All @@ -120,6 +123,8 @@ def get_parser(
parsing_state,
value_schema.enum,
require_opening_quote=True,
min_length=value_schema.minLength,
max_length=value_schema.maxLength,
)
elif value_schema.type == "object":
return ObjectParsingState(value_schema, parsing_state)
Expand Down Expand Up @@ -385,27 +390,35 @@ class StringParsingState(PrimitiveParsingState):
parsed_string: str
seen_closing_quote: bool
seen_opening_quote: bool
min_length: Optional[int]
max_length: Optional[int]

def __init__(
self,
root: JsonSchemaParser,
allowed_strings: List[str],
require_opening_quote: bool,
require_closing_quote: bool = True,
min_length: Optional[int]=None,
max_length: Optional[int]=None,
):
super().__init__(root)
self.allowed_strings = allowed_strings
self.seen_closing_quote = False
self.seen_opening_quote = not require_opening_quote
self.require_closing_quote = require_closing_quote
self.require_opening_quote = require_opening_quote
self.min_length = min_length
self.max_length = max_length

def _clone(self) -> "StringParsingState":
clone = StringParsingState(
self.root,
self.allowed_strings,
self.require_opening_quote,
self.require_closing_quote,
self.min_length,
self.max_length
)
clone.parsed_string = self.parsed_string
clone.seen_closing_quote = self.seen_closing_quote
Expand Down Expand Up @@ -452,6 +465,10 @@ def get_allowed_characters(self) -> str:
allowed_next_characters.extend(WHITESPACE_CHARACTERS)
return "".join(allowed_next_characters)
else:
if self.min_length is not None and len(self.parsed_string) < self.min_length:
return self.root.context.alphabet_without_quotes + BACKSLASH
if self.max_length is not None and len(self.parsed_string) >= self.max_length:
return '"'
return self.root.config.alphabet + BACKSLASH

def can_end(self) -> bool:
Expand Down
15 changes: 12 additions & 3 deletions tests/test_jsonschemaparser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from typing import Dict, List, Optional
from pydantic import BaseModel, Field
from typing import Annotated, Dict, List, Optional
from pydantic import BaseModel, Field, StringConstraints
from lmformatenforcer import JsonSchemaParser
from enum import Enum
import pytest
Expand Down Expand Up @@ -220,5 +220,14 @@ class SomeSchema(BaseModel):
with pytest.raises(CharacterNotAllowedException):
_test_json_schema_parsing_with_string(test_string, SomeSchema.schema(), True)



def test_string_length_limitation():
class SomeSchema(BaseModel):
key: Annotated[str, StringConstraints(min_length=2, max_length=3)]

for str_length in range(10):
test_string = f'{{"key": "{str_length * "a"}"}}'
expect_sucess = 2 <= str_length <= 3
_test_json_schema_parsing_with_string(test_string, SomeSchema.schema(), expect_sucess)


0 comments on commit 8490345

Please sign in to comment.