Skip to content

Commit

Permalink
Adding default max json array length, preventing LLM infinite loops i…
Browse files Browse the repository at this point in the history
…n some cases
  • Loading branch information
noamgat committed Jul 15, 2024
1 parent f1dd75b commit d3dc0ad
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ There are several environment variables that can be set, that affect the operati

- `LMFE_MAX_CONSECUTIVE_WHITESPACES` - How many consecutive whitespaces are allowed when parsing JsonSchemaObjects. Default: 12.
- `LMFE_STRICT_JSON_FIELD_ORDER` - Should the JsonSchemaParser force the properties to appear in the same order as they appear in the 'required' list of the JsonSchema? (Note: this is consistent with the order of declaration in Pydantic models). Default: False.
- `LMFE_MAX_JSON_ARRAY_LENGTH` - What is the maximal JSON array length, if not specified by the schema. Helps LLM Avoid infinite loops. Default: 20.

### Option 2: via the CharacterLevelParserConfig class
When using the library through code, any `CharacterLevelParser` (`JsonSchemaParser`, `RegexParser` etc) constructor receives an optional `CharacterLevelParserConfig` object.
Expand Down
7 changes: 6 additions & 1 deletion lmformatenforcer/characterlevelparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from typing import Hashable, List, Optional, TypeVar
from .consts import (COMPLETE_ALPHABET, WHITESPACE_CHARACTERS, DEFAULT_MAX_CONSECUTIVE_WHITESPACES,
DEFAULT_FORCE_JSON_FIELD_ORDER, CONFIG_ENV_VAR_MAX_CONSECUTIVE_WHITESPACES,
CONFIG_ENV_VAR_STRICT_JSON_FIELD_ORDER)
CONFIG_ENV_VAR_STRICT_JSON_FIELD_ORDER, CONFIG_ENV_VAR_MAX_JSON_ARRAY_LENGTH,
DEFAULT_MAX_JSON_ARRAY_LENGTH)


def _parse_bool(s: str) -> bool:
Expand All @@ -29,6 +30,10 @@ class CharacterLevelParserConfig:
DEFAULT_FORCE_JSON_FIELD_ORDER)
"""Whether the JsonSchemaParser will force fields to appear in the
order of the 'required' field in the schema"""
max_json_array_length: int = _env_or_default_field(CONFIG_ENV_VAR_MAX_JSON_ARRAY_LENGTH,
DEFAULT_MAX_JSON_ARRAY_LENGTH)
"""What is the maximum json array length if not specified by the schema. Helps the LLM
avoid infinite loops."""


class CharacterLevelParser(abc.ABC):
Expand Down
7 changes: 6 additions & 1 deletion lmformatenforcer/consts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
COMPLETE_ALPHABET = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!@#$%^&*()_+-=[]{};:,./<>? `'\""
DEFAULT_MAX_CONSECUTIVE_WHITESPACES = 12
DEFAULT_FORCE_JSON_FIELD_ORDER = False
DEFAULT_MAX_JSON_ARRAY_LENGTH = 20
WHITESPACE_CHARACTERS = " \t\n\r"
BACKSLASH = "\\"
BACKSLASH_ESCAPING_CHARACTERS = '"\\/bfnrt' # Characters allowed after an escaping backslash, except unicode
Expand All @@ -12,4 +13,8 @@

CONFIG_ENV_VAR_STRICT_JSON_FIELD_ORDER = 'LMFE_STRICT_JSON_FIELD_ORDER'
"""Environment variable for externally controlling whether the JsonSchemaParser will force
fields to appear in the order of the 'required' field in the schema. Default: false"""
fields to appear in the order of the 'required' field in the schema. Default: false"""

CONFIG_ENV_VAR_MAX_JSON_ARRAY_LENGTH = 'LMFE_MAX_JSON_ARRAY_LENGTH'
"""Environment variable for externally controlling what is the maximal JSON array length,
if not specified by the schema. Default: 20"""
3 changes: 3 additions & 0 deletions lmformatenforcer/jsonschemaparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,9 @@ def __init__(
self.list_member_type = list_member_type
self.min_items = min_items
self.max_items = max_items
default_max = root.config.max_json_array_length
if self.max_items is None and default_max > 0 and (min_items is None or min_items < default_max):
self.max_items = default_max

def _clone(self) -> PrimitiveParsingState:
new = ListParsingState(self.root, self.list_member_type, self.min_items, self.max_items)
Expand Down
41 changes: 40 additions & 1 deletion tests/test_jsonschemaparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from lmformatenforcer import JsonSchemaParser
from enum import Enum
import pytest
from lmformatenforcer.consts import BACKSLASH, BACKSLASH_ESCAPING_CHARACTERS, CONFIG_ENV_VAR_STRICT_JSON_FIELD_ORDER, CONFIG_ENV_VAR_MAX_CONSECUTIVE_WHITESPACES
from lmformatenforcer.consts import BACKSLASH, BACKSLASH_ESCAPING_CHARACTERS, CONFIG_ENV_VAR_STRICT_JSON_FIELD_ORDER, CONFIG_ENV_VAR_MAX_CONSECUTIVE_WHITESPACES, CONFIG_ENV_VAR_MAX_JSON_ARRAY_LENGTH

from .common import assert_parser_with_string, CharacterNotAllowedException

Expand Down Expand Up @@ -634,3 +634,42 @@ def test_max_whitespaces_via_env_var():
for num_spaces in range(12):
expect_success = num_spaces <= 8
_test_json_schema_parsing_with_string(base_answer.replace("$", " " * num_spaces), schema, expect_success)


def test_max_json_array_length_via_env_var():
env_var_name = CONFIG_ENV_VAR_MAX_JSON_ARRAY_LENGTH

class IntListModel(BaseModel):
nums: List[int]

schema = IntListModel.model_json_schema()
with _temp_replace_env_var(env_var_name, '8'):
for num_numbers in range(12):
instance = IntListModel(nums=list(range(num_numbers)))
instance_str = instance.model_dump_json()
expect_success = num_numbers <= 8
_test_json_schema_parsing_with_string(instance_str, schema, expect_success)


def test_top_level_object_inheritance():
schema = {
"$defs": {
"ParentObject": {
"properties": {
"child": {
"type": "string"
}
},
"type": "object"
}
},
"properties": {
"parent": {
"$ref": "#/$defs/ParentObject"
}
},
"type": "object"
}
valid_object = '{"parent": {"child": "test"}}'
_test_json_schema_parsing_with_string(valid_object, schema, True)

0 comments on commit d3dc0ad

Please sign in to comment.