From d3dc0ad3868af17f9d83c6b0557467574c14a61a Mon Sep 17 00:00:00 2001 From: Noam Gat Date: Mon, 15 Jul 2024 18:37:46 +0300 Subject: [PATCH] Adding default max json array length, preventing LLM infinite loops in some cases --- README.md | 1 + lmformatenforcer/characterlevelparser.py | 7 +++- lmformatenforcer/consts.py | 7 +++- lmformatenforcer/jsonschemaparser.py | 3 ++ tests/test_jsonschemaparser.py | 41 +++++++++++++++++++++++- 5 files changed, 56 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 47b49b4..da5b84b 100644 --- a/README.md +++ b/README.md @@ -197,6 +197,7 @@ There are several environment variables that can be set, that affect the operati - `LMFE_MAX_CONSECUTIVE_WHITESPACES` - How many consecutive whitespaces are allowed when parsing JsonSchemaObjects. Default: 12. - `LMFE_STRICT_JSON_FIELD_ORDER` - Should the JsonSchemaParser force the properties to appear in the same order as they appear in the 'required' list of the JsonSchema? (Note: this is consistent with the order of declaration in Pydantic models). Default: False. +- `LMFE_MAX_JSON_ARRAY_LENGTH` - What is the maximal JSON array length, if not specified by the schema. Helps LLM Avoid infinite loops. Default: 20. ### Option 2: via the CharacterLevelParserConfig class When using the library through code, any `CharacterLevelParser` (`JsonSchemaParser`, `RegexParser` etc) constructor receives an optional `CharacterLevelParserConfig` object. diff --git a/lmformatenforcer/characterlevelparser.py b/lmformatenforcer/characterlevelparser.py index cd9aee0..55c03e7 100644 --- a/lmformatenforcer/characterlevelparser.py +++ b/lmformatenforcer/characterlevelparser.py @@ -4,7 +4,8 @@ from typing import Hashable, List, Optional, TypeVar from .consts import (COMPLETE_ALPHABET, WHITESPACE_CHARACTERS, DEFAULT_MAX_CONSECUTIVE_WHITESPACES, DEFAULT_FORCE_JSON_FIELD_ORDER, CONFIG_ENV_VAR_MAX_CONSECUTIVE_WHITESPACES, - CONFIG_ENV_VAR_STRICT_JSON_FIELD_ORDER) + CONFIG_ENV_VAR_STRICT_JSON_FIELD_ORDER, CONFIG_ENV_VAR_MAX_JSON_ARRAY_LENGTH, + DEFAULT_MAX_JSON_ARRAY_LENGTH) def _parse_bool(s: str) -> bool: @@ -29,6 +30,10 @@ class CharacterLevelParserConfig: DEFAULT_FORCE_JSON_FIELD_ORDER) """Whether the JsonSchemaParser will force fields to appear in the order of the 'required' field in the schema""" + max_json_array_length: int = _env_or_default_field(CONFIG_ENV_VAR_MAX_JSON_ARRAY_LENGTH, + DEFAULT_MAX_JSON_ARRAY_LENGTH) + """What is the maximum json array length if not specified by the schema. Helps the LLM + avoid infinite loops.""" class CharacterLevelParser(abc.ABC): diff --git a/lmformatenforcer/consts.py b/lmformatenforcer/consts.py index b3e89ea..620ad73 100644 --- a/lmformatenforcer/consts.py +++ b/lmformatenforcer/consts.py @@ -1,6 +1,7 @@ COMPLETE_ALPHABET = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!@#$%^&*()_+-=[]{};:,./<>? `'\"" DEFAULT_MAX_CONSECUTIVE_WHITESPACES = 12 DEFAULT_FORCE_JSON_FIELD_ORDER = False +DEFAULT_MAX_JSON_ARRAY_LENGTH = 20 WHITESPACE_CHARACTERS = " \t\n\r" BACKSLASH = "\\" BACKSLASH_ESCAPING_CHARACTERS = '"\\/bfnrt' # Characters allowed after an escaping backslash, except unicode @@ -12,4 +13,8 @@ CONFIG_ENV_VAR_STRICT_JSON_FIELD_ORDER = 'LMFE_STRICT_JSON_FIELD_ORDER' """Environment variable for externally controlling whether the JsonSchemaParser will force -fields to appear in the order of the 'required' field in the schema. Default: false""" \ No newline at end of file +fields to appear in the order of the 'required' field in the schema. Default: false""" + +CONFIG_ENV_VAR_MAX_JSON_ARRAY_LENGTH = 'LMFE_MAX_JSON_ARRAY_LENGTH' +"""Environment variable for externally controlling what is the maximal JSON array length, +if not specified by the schema. Default: 20""" diff --git a/lmformatenforcer/jsonschemaparser.py b/lmformatenforcer/jsonschemaparser.py index e1307fe..86df79f 100644 --- a/lmformatenforcer/jsonschemaparser.py +++ b/lmformatenforcer/jsonschemaparser.py @@ -611,6 +611,9 @@ def __init__( self.list_member_type = list_member_type self.min_items = min_items self.max_items = max_items + default_max = root.config.max_json_array_length + if self.max_items is None and default_max > 0 and (min_items is None or min_items < default_max): + self.max_items = default_max def _clone(self) -> PrimitiveParsingState: new = ListParsingState(self.root, self.list_member_type, self.min_items, self.max_items) diff --git a/tests/test_jsonschemaparser.py b/tests/test_jsonschemaparser.py index c8ef6d0..16321d0 100644 --- a/tests/test_jsonschemaparser.py +++ b/tests/test_jsonschemaparser.py @@ -6,7 +6,7 @@ from lmformatenforcer import JsonSchemaParser from enum import Enum import pytest -from lmformatenforcer.consts import BACKSLASH, BACKSLASH_ESCAPING_CHARACTERS, CONFIG_ENV_VAR_STRICT_JSON_FIELD_ORDER, CONFIG_ENV_VAR_MAX_CONSECUTIVE_WHITESPACES +from lmformatenforcer.consts import BACKSLASH, BACKSLASH_ESCAPING_CHARACTERS, CONFIG_ENV_VAR_STRICT_JSON_FIELD_ORDER, CONFIG_ENV_VAR_MAX_CONSECUTIVE_WHITESPACES, CONFIG_ENV_VAR_MAX_JSON_ARRAY_LENGTH from .common import assert_parser_with_string, CharacterNotAllowedException @@ -634,3 +634,42 @@ def test_max_whitespaces_via_env_var(): for num_spaces in range(12): expect_success = num_spaces <= 8 _test_json_schema_parsing_with_string(base_answer.replace("$", " " * num_spaces), schema, expect_success) + + +def test_max_json_array_length_via_env_var(): + env_var_name = CONFIG_ENV_VAR_MAX_JSON_ARRAY_LENGTH + + class IntListModel(BaseModel): + nums: List[int] + + schema = IntListModel.model_json_schema() + with _temp_replace_env_var(env_var_name, '8'): + for num_numbers in range(12): + instance = IntListModel(nums=list(range(num_numbers))) + instance_str = instance.model_dump_json() + expect_success = num_numbers <= 8 + _test_json_schema_parsing_with_string(instance_str, schema, expect_success) + + +def test_top_level_object_inheritance(): + schema = { + "$defs": { + "ParentObject": { + "properties": { + "child": { + "type": "string" + } + }, + "type": "object" + } + }, + "properties": { + "parent": { + "$ref": "#/$defs/ParentObject" + } + }, + "type": "object" + } + valid_object = '{"parent": {"child": "test"}}' + _test_json_schema_parsing_with_string(valid_object, schema, True) +