Skip to content

Commit

Permalink
JsonSchemaParser list parser supports min / max items
Browse files Browse the repository at this point in the history
  • Loading branch information
noamgat committed Oct 31, 2023
1 parent 923e096 commit 1547797
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 6 deletions.
30 changes: 26 additions & 4 deletions lmformatenforcer/jsonschemaparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def get_parser(
elif value_schema.type == "array":
if value_schema.items is None:
raise LMFormatEnforcerException(f"List '{value_schema.title}' has no member type. Hint: If this is from a Pydantic Schema, use List[AAA] instead of list")
return ListParsingState(parsing_state, ending_characters, value_schema.items)
return ListParsingState(parsing_state, ending_characters, value_schema.items, value_schema.minItems, value_schema.maxItems)
else:
raise Exception("Unsupported type " + str(value_schema.type))

Expand Down Expand Up @@ -435,48 +435,70 @@ class ListParsingState(PrimitiveParsingState):
# Schema of a single list element; passed to get_parser() for every item.
list_member_type: JsonSchemaObject
# Set to True when the opening "[" is consumed by add_character().
seen_list_opener: bool = False
# Set to True when the closing "]" is consumed by add_character().
seen_list_closer: bool = False
# Number of items started so far: set to 1 on "[" and incremented on each ",".
num_items_seen: int = 0

def __init__(
    self,
    root: JsonSchemaParser,
    ending_characters: str,
    list_member_type: JsonSchemaObject,
    min_items: Optional[int],
    max_items: Optional[int],
):
    """Create the parsing state for a JSON array.

    Args:
        root: The parser that owns the object stack this state pushes to.
        ending_characters: Forwarded unchanged to the parent constructor.
        list_member_type: Schema used to build the parser of each item.
        min_items: Item count required before "]" is offered, or None
            for no lower bound.
        max_items: Item count after which "," is no longer offered, or
            None for no upper bound.
    """
    super().__init__(root, ending_characters)
    # Length bounds are consulted by get_allowed_control_characters().
    self.min_items, self.max_items = min_items, max_items
    self.list_member_type = list_member_type

def add_character(self, new_character: str):
    """Consume one character belonging to the list itself.

    "[" opens the list and pushes a parser for the first item, "," pushes
    a parser for the next item, and "]" marks the list as closed. Once the
    list is closed the character is also forwarded to the parent
    implementation.

    Args:
        new_character: The single character to consume.
    """
    if self.seen_list_closer:
        super().add_character(new_character)
    if new_character == "[":
        # TODO: We currently don't support empty arrays, due to needing to allow both the close array bracket
        # and the first character of the item at the same timestep, which is hard with the current design.
        # The counter is updated *before* the control characters are
        # computed so that they reflect the item about to be parsed.
        self.num_items_seen = 1
        self.seen_list_opener = True
        self.root.object_stack.append(
            get_parser(
                self.root,
                self.list_member_type,
                # Terminators depend on min/max items; the previously
                # hard-coded "]," is superseded by this computed value.
                self.get_allowed_control_characters(),
            )
        )
    elif new_character == "]":
        self.seen_list_closer = True
    elif new_character == ",":
        if not self.seen_list_closer:
            # Count the upcoming item first, then derive its terminators.
            self.num_items_seen += 1

            self.root.object_stack.append(
                get_parser(
                    self.root,
                    self.list_member_type,
                    self.get_allowed_control_characters(),
                )
            )

def _get_allowed_primitive_characters(self) -> str:
    """Return the characters the list state itself may accept next.

    Returns:
        "[" plus whitespace before the list is opened; the separator and/or
        closer permitted by the min/max item bounds plus whitespace while
        the list is open; an empty string once the list is closed.
    """
    if not self.seen_list_opener:
        return "[" + WHITESPACE_CHARACTERS
    elif not self.seen_list_closer:
        # Use the computed control characters rather than a fixed "],"
        # so that "," and "]" are only offered when the item count allows.
        return self.get_allowed_control_characters() + WHITESPACE_CHARACTERS
    else:
        # The parent function will take care of allowing the ending tokens.
        return ""

def can_end(self) -> bool:
return self.seen_list_closer

def get_allowed_control_characters(self):
ending_characters = ""
has_enough_items = self.min_items is None or self.num_items_seen >= self.min_items
can_add_another_item = self.max_items is None or self.num_items_seen < self.max_items

if can_add_another_item:
ending_characters += ","
if has_enough_items:
ending_characters += "]"
return ending_characters

17 changes: 15 additions & 2 deletions tests/test_jsonschemaparser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from typing import Dict, List, Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field, conlist
from lmformatenforcer import JsonSchemaParser
from enum import Enum
import pytest
Expand Down Expand Up @@ -43,7 +43,7 @@ class SampleModel(BaseModel):
num: int
dec: Optional[float] = None
message: Optional[str] = None
list_of_strings: Optional[List[str]] = None
list_of_strings: Optional[List[str]] = Field(None, min_length=2, max_length=3)
inner_dict: Optional[Dict[str, InnerModel]] = None
simple_dict: Optional[Dict[str, int]] = None
list_of_models: Optional[List[InnerModel]] = None
Expand Down Expand Up @@ -149,3 +149,16 @@ class DictModel(BaseModel):
# list is not valid because we don't know what the member type is, expect our exception
with pytest.raises(LMFormatEnforcerException):
_test_json_schema_parsing_with_string('{"num":1,"l":[1,2,3]}', DictModel.schema(), False)


def test_list_length_limitations():
    """Check that list_of_strings only parses with 2 or 3 items."""
    # list_of_strings is defined as having a min length of 2 and a max length of 3
    cases = [
        ('{"num":1,"list_of_strings":["a"]}', False),
        ('{"num":1,"list_of_strings":["a", "b"]}', True),
        ('{"num":1,"list_of_strings":["a","b","c"]}', True),
        ('{"num":1,"list_of_strings":["a","b","c","d"]}', False),
    ]
    for payload, should_parse in cases:
        _test_json_schema_parsing_with_string(payload, SampleModel.schema(), should_parse)

0 comments on commit 1547797

Please sign in to comment.