Add ability to config JsonSchemaParser heuristics via environment variables / config objects (#97)

* Added configuration options that can be controlled via env vars - max consecutive whitespaces and force json order

* Added documentation on how to use env var / configuration objects

* More documentation

* Whitespace cleanup
noamgat authored May 4, 2024
1 parent 9ef1c90 commit 4167131
Showing 8 changed files with 116 additions and 18 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,8 @@
# LM Format Enforcer Changelog

## v0.10.1
- Allowing control of LM Format Enforcer's heuristics via env var / configuration objects. See the 'Configuration options' section of the README.

## v0.9.10
- [#95] Added anyOf support to JsonSchemaParser, making function calls possible.

19 changes: 19 additions & 0 deletions README.md
@@ -186,6 +186,25 @@ idx | generated_token | generated_token_idx | generated_score | leading_token |
You can see that the model "wanted" to start the answer using ```Sure```, but the format enforcer forced it to use ```Michael``` - there was a big gap in token 1. Afterwards, almost all of the leading scores are within the allowed token set, meaning the model likely did not hallucinate due to the token forcing. The only exception was timestep 4 - " Born" was forced while the LLM wanted to choose "born". This is a hint for the prompt engineer to change the prompt to use a lowercase b instead.


## Configuration options

LM Format Enforcer uses several heuristics to avoid edge cases that may occur when LLMs generate structured outputs.
There are two ways to control these heuristics:

### Option 1: via Environment Variables

Several environment variables can be set to change how the library operates. This method is useful when you don't want to modify the code, for example when using the library through the vLLM OpenAI server.

- `LMFE_MAX_CONSECUTIVE_WHITESPACES` - How many consecutive whitespaces are allowed when parsing JsonSchemaObjects. Default: 12.
- `LMFE_FORCE_JSON_FIELD_ORDER` - Should the JsonSchemaParser force the properties to appear in the same order as they appear in the 'required' list of the JsonSchema? (Note: this is consistent with the order of declaration in Pydantic models). Default: False.

### Option 2: via the CharacterLevelParserConfig class
When using the library through code, the constructor of any `CharacterLevelParser` (`JsonSchemaParser`, `RegexParser`, etc.) accepts an optional `CharacterLevelParserConfig` object.

Therefore, to configure the heuristics of a single parser, instantiate a `CharacterLevelParserConfig` object, modify its values and pass it to the `CharacterLevelParser`'s constructor.
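
The two options can be illustrated with a minimal, self-contained sketch. It does not import the library: `ConfigSketch` and `_int_from_env` are hypothetical stand-ins that mirror only the two settings named above, on the assumption that a setting falls back first to its environment variable and then to the documented default (12 and False).

```python
import os
from dataclasses import dataclass, field


def _int_from_env(name: str, default: int):
    # Hypothetical helper: read the variable when the object is created.
    return field(default_factory=lambda: int(os.environ.get(name, default)))


@dataclass
class ConfigSketch:
    """Stand-in for CharacterLevelParserConfig (illustration only)."""
    max_consecutive_whitespaces: int = _int_from_env("LMFE_MAX_CONSECUTIVE_WHITESPACES", 12)
    force_json_field_order: bool = False


# Option 1: set the environment variable, then rely on the default.
os.environ["LMFE_MAX_CONSECUTIVE_WHITESPACES"] = "4"
from_env = ConfigSketch()

# Option 2: instantiate, modify the values, and hand the object to a parser.
explicit = ConfigSketch()
explicit.max_consecutive_whitespaces = 2
explicit.force_json_field_order = True

# Clean up so later code sees the documented default again.
del os.environ["LMFE_MAX_CONSECUTIVE_WHITESPACES"]
after_cleanup = ConfigSketch()
```

With the real library, the object from Option 2 would be passed to the parser's constructor. The README text does not name the constructor parameter, so check the signature of the parser you are using.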



## Known issues and limitations

- LM Format Enforcer requires a python API to process the output logits of the language model. This means that until the APIs are extended, it can not be used with OpenAI ChatGPT and similar API based solutions.
30 changes: 26 additions & 4 deletions lmformatenforcer/characterlevelparser.py
@@ -1,12 +1,34 @@
 import abc
-from dataclasses import dataclass
-from typing import Hashable, List, Optional
-from .consts import COMPLETE_ALPHABET, WHITESPACE_CHARACTERS
+import os
+from dataclasses import dataclass, field
+from typing import Hashable, List, Optional, TypeVar
+from .consts import (COMPLETE_ALPHABET, WHITESPACE_CHARACTERS, DEFAULT_MAX_CONSECUTIVE_WHITESPACES,
+                     DEFAULT_FORCE_JSON_FIELD_ORDER, CONFIG_ENV_VAR_MAX_CONSECUTIVE_WHITESPACES,
+                     CONFIG_ENV_VAR_LMFE_FORCE_JSON_FIELD_ORDER)


+def _parse_bool(s: str) -> bool:
+    return bool(s) and (s.strip().lower() in ['true', '1'])
+
+
+def _env_or_default_field(env_var: str, default_val):
+    default_val_type = type(default_val)
+    parser_func = _parse_bool if default_val_type == bool else default_val_type
+    def factory_func():
+        return parser_func(os.environ.get(env_var, str(default_val)))
+    return field(default_factory=factory_func)
+
+
 @dataclass
 class CharacterLevelParserConfig:
-    alphabet: str = COMPLETE_ALPHABET
+    alphabet: str = COMPLETE_ALPHABET
+    max_consecutive_whitespaces: int = _env_or_default_field(CONFIG_ENV_VAR_MAX_CONSECUTIVE_WHITESPACES,
+                                                             DEFAULT_MAX_CONSECUTIVE_WHITESPACES)
+    """How many consecutive whitespaces the JsonSchemaParser will allow"""
+    force_json_field_order: bool = _env_or_default_field(CONFIG_ENV_VAR_LMFE_FORCE_JSON_FIELD_ORDER,
+                                                         DEFAULT_FORCE_JSON_FIELD_ORDER)
+    """Whether the JsonSchemaParser will force fields to appear in the
+    order of the 'required' field in the schema"""


 class CharacterLevelParser(abc.ABC):
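
A note on the `default_factory` indirection in this hunk: it means the environment is read each time a config object is created, not once at import time. The standalone sketch below shows that behaviour together with the strings the boolean parser accepts. The names `DEMO_FLAG` and `Demo` are made up for illustration, and the helpers are simplified copies rather than the library's own code.

```python
import os
from dataclasses import dataclass, field


def parse_bool(s: str) -> bool:
    # Only 'true' or '1' (any case, surrounding spaces ignored) count as True.
    return bool(s) and s.strip().lower() in ("true", "1")


def env_or_default_field(env_var: str, default_val):
    # Pick the parser from the default's type. bool needs special handling
    # because bool("False") would evaluate to True.
    parser = parse_bool if isinstance(default_val, bool) else type(default_val)
    return field(default_factory=lambda: parser(os.environ.get(env_var, str(default_val))))


@dataclass
class Demo:
    flag: bool = env_or_default_field("DEMO_FLAG", False)


before = Demo().flag            # variable unset, so the default applies
os.environ["DEMO_FLAG"] = " TRUE "
after = Demo().flag             # re-read at instantiation time
os.environ["DEMO_FLAG"] = "yes"
unrecognized = Demo().flag      # anything other than 'true' or '1' parses as False
del os.environ["DEMO_FLAG"]
```

One consequence of this design is that a value such as `yes` or `on` silently parses as False, so it is safest to stick to `true` or `1` when setting boolean variables.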
11 changes: 10 additions & 1 deletion lmformatenforcer/consts.py
@@ -1,6 +1,15 @@
 COMPLETE_ALPHABET = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!@#$%^&*()_+-=[]{};:,./<>? `'\""
-MAX_CONSECUTIVE_WHITESPACES = 12
+DEFAULT_MAX_CONSECUTIVE_WHITESPACES = 12
+DEFAULT_FORCE_JSON_FIELD_ORDER = False
 WHITESPACE_CHARACTERS = " \t\n\r"
 BACKSLASH = "\\"
 BACKSLASH_ESCAPING_CHARACTERS = '"\\/bfnrt' # Characters allowed after an escaping backslash, except unicode
 BACKSLACH_UNICODE_ESCAPE = "u"
+
+CONFIG_ENV_VAR_MAX_CONSECUTIVE_WHITESPACES = 'LMFE_MAX_CONSECUTIVE_WHITESPACES'
+"""Environment variable for externally controlling how many consecutive whitespaces the
+JsonSchemaParser will allow. Default: 12"""
+
+CONFIG_ENV_VAR_LMFE_FORCE_JSON_FIELD_ORDER = 'LMFE_FORCE_JSON_FIELD_ORDER'
+"""Environment variable for externally controlling whether the JsonSchemaParser will force
+fields to appear in the order of the 'required' field in the schema. Default: false"""
21 changes: 11 additions & 10 deletions lmformatenforcer/jsonschemaparser.py
@@ -7,7 +7,7 @@
 from .external.jsonschemaobject import JsonSchemaObject, json_schema_data_formats
 from .exceptions import LMFormatEnforcerException
 from .characterlevelparser import CharacterLevelParser, CharacterLevelParserConfig, ForceStopParser, SequenceParser, StringParser, UnionParser
-from .consts import BACKSLASH, BACKSLASH_ESCAPING_CHARACTERS, MAX_CONSECUTIVE_WHITESPACES, WHITESPACE_CHARACTERS
+from .consts import BACKSLASH, BACKSLASH_ESCAPING_CHARACTERS, WHITESPACE_CHARACTERS
 from .regexparser import RegexParser

 # No need to include the 'integer' option in the anyOf, as it is a subset of 'number'
@@ -121,7 +121,7 @@ def get_allowed_characters(self) -> str:
             # characters when the object stack is empty (= we are done parsing)
             allowed_characters = WHITESPACE_CHARACTERS

-        if self.num_consecutive_whitespaces >= MAX_CONSECUTIVE_WHITESPACES:
+        if self.num_consecutive_whitespaces >= self.config.max_consecutive_whitespaces:
             # print("Filtering whitespace characters")
             allowed_characters = "".join(c for c in allowed_characters if c not in WHITESPACE_CHARACTERS)
         return allowed_characters
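
The effect of the configurable cap can be shown in isolation. The sketch below is a standalone illustration, not the parser's code: `filter_whitespace` is a hypothetical free function that takes the counter and the limit as plain arguments instead of reading them from the parser state.

```python
WHITESPACE_CHARACTERS = " \t\n\r"


def filter_whitespace(allowed: str, num_consecutive_whitespaces: int,
                      max_consecutive_whitespaces: int) -> str:
    """Drop whitespace from the allowed set once the cap has been reached."""
    if num_consecutive_whitespaces >= max_consecutive_whitespaces:
        return "".join(c for c in allowed if c not in WHITESPACE_CHARACTERS)
    return allowed


allowed = ' \n"}'
below_cap = filter_whitespace(allowed, 3, 12)   # whitespace still permitted
at_cap = filter_whitespace(allowed, 12, 12)     # only '"' and '}' remain
```

Because the comparison is `>=`, a limit of N lets the model emit N whitespace characters in a row; the next character must be a non-whitespace one.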
@@ -302,10 +302,15 @@ def add_character(self, new_character: str) -> CharacterLevelParser:
             if new_character == '"':
                 possible_keys = None
                 if not self.is_dictionary:
-                    possible_keys = list(self.schema_object.properties.keys())
-                    possible_keys = list(
-                        set(possible_keys).difference(self.existing_keys)
-                    )
+                    required_keys = self.schema_object.required or []
+                    next_required_key = next((key for key in required_keys if key not in self.existing_keys), None)
+                    if self.root.config.force_json_field_order and next_required_key:
+                        possible_keys = [next_required_key]
+                    else:
+                        possible_keys = list(self.schema_object.properties.keys())
+                        possible_keys = list(
+                            set(possible_keys).difference(self.existing_keys)
+                        )
                 # We send require_opening_quote=True and then add_character('"') instead of require_opening_quote=False
                 # Because there is a difference between "don't need a quote" and "received it before creating the parser"
                 key_parser = StringParsingState(
@@ -325,10 +330,6 @@ def add_character(self, new_character: str) -> CharacterLevelParser:
                     else:
                         value_schema = JsonSchemaParser.ANY_JSON_OBJECT_SCHEMA
                 else:
-                    possible_keys = list(self.schema_object.properties.keys())
-                    possible_keys = list(
-                        set(possible_keys).difference(self.existing_keys)
-                    )
                     value_schema = self.schema_object.properties[self.current_key]
                 self.current_key_parser = get_parser(
                     self.root, value_schema
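
The key-selection rule introduced in this file can be sketched as a standalone function. This is an illustration under simplifying assumptions: `possible_keys` is a hypothetical free function over plain dicts and lists, and it returns the remaining keys in declaration order, whereas the diff uses a set difference whose order is unspecified.

```python
from typing import Dict, List, Optional


def possible_keys(properties: Dict[str, dict], required: Optional[List[str]],
                  existing_keys: List[str], force_order: bool) -> List[str]:
    """Return the keys that may legally start the next object member."""
    required_keys = required or []
    # First required key that has not been emitted yet, if any.
    next_required_key = next((k for k in required_keys if k not in existing_keys), None)
    if force_order and next_required_key:
        return [next_required_key]
    # Otherwise any key that has not been used yet is allowed.
    return [k for k in properties if k not in existing_keys]


props = {"a": {"type": "integer"}, "b": {"type": "string"}, "c": {"type": "integer"}}
required = ["a", "b"]

first_forced = possible_keys(props, required, [], True)             # only "a"
second_forced = possible_keys(props, required, ["a"], True)         # only "b"
after_required = possible_keys(props, required, ["a", "b"], True)   # optional keys
unforced = possible_keys(props, required, [], False)                # any unused key
```

Note that forcing applies only while required keys are outstanding. Once they have all been emitted, the optional keys are offered without any ordering constraint.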
1 change: 0 additions & 1 deletion lmformatenforcer/regexparser.py
@@ -3,7 +3,6 @@
 from interegular.fsm import anything_else

 from .characterlevelparser import CharacterLevelParser, CharacterLevelParserConfig
-from .consts import COMPLETE_ALPHABET

 class RegexParser(CharacterLevelParser):
     """RegexParser is an example CharacterLevelParser that only allows strings that match a given regular expression."""
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "lm-format-enforcer"
-version = "0.9.10"
+version = "0.10.1"
 description = "Enforce the output format (JSON Schema, Regex etc) of a language model"
 authors = ["Noam Gat <[email protected]>"]
 license = "MIT"
47 changes: 46 additions & 1 deletion tests/test_jsonschemaparser.py
@@ -1,10 +1,12 @@
 import json
+import os
+from contextlib import contextmanager
 from typing import Dict, List, Optional, Union
 from pydantic import BaseModel, Field
 from lmformatenforcer import JsonSchemaParser
 from enum import Enum
 import pytest
-from lmformatenforcer.consts import BACKSLASH, BACKSLASH_ESCAPING_CHARACTERS
+from lmformatenforcer.consts import BACKSLASH, BACKSLASH_ESCAPING_CHARACTERS, CONFIG_ENV_VAR_LMFE_FORCE_JSON_FIELD_ORDER, CONFIG_ENV_VAR_MAX_CONSECUTIVE_WHITESPACES

 from .common import assert_parser_with_string, CharacterNotAllowedException

@@ -498,3 +500,46 @@ def test_top_level_array_object():
     invalid_result = valid_result[:-1]
     _test_json_schema_parsing_with_string(valid_result, test_schema, True)
     _test_json_schema_parsing_with_string(invalid_result, test_schema, False)
+
+
+@contextmanager
+def _temp_replace_env_var(env_var_name, temp_value):
+    try:
+        prev_env_var = os.environ.get(env_var_name, None)
+        if prev_env_var is not None:
+            os.environ.pop(env_var_name)
+        if temp_value is not None:
+            os.environ[env_var_name] = str(temp_value)
+        yield None
+    finally:
+        if prev_env_var is None:
+            if temp_value is not None:
+                os.environ.pop(env_var_name)
+        else:
+            os.environ[env_var_name] = prev_env_var
+
+
+def test_control_json_force_field_order_via_env_var():
+    class TwoRequiredModel(BaseModel):
+        a: int
+        b: str
+        c: int = 1
+    schema = TwoRequiredModel.model_json_schema()
+    env_var_name = CONFIG_ENV_VAR_LMFE_FORCE_JSON_FIELD_ORDER
+    with _temp_replace_env_var(env_var_name, None):
+        # Check that the default is false
+        _test_json_schema_parsing_with_string('{"b": "X", "a": 1}', schema, True)
+    with _temp_replace_env_var(env_var_name, 'True'):
+        # Check that setting to true behaves correctly
+        _test_json_schema_parsing_with_string('{"b": "X", "a": 1}', schema, False)
+        _test_json_schema_parsing_with_string('{"a": 1, "b": "X"}', schema, True)
+
+
+def test_max_whitespaces_via_env_var():
+    env_var_name = CONFIG_ENV_VAR_MAX_CONSECUTIVE_WHITESPACES
+    schema = SampleModel.model_json_schema()
+    base_answer = '{"num":$1}'
+    with _temp_replace_env_var(env_var_name, '8'):
+        for num_spaces in range(12):
+            expect_success = num_spaces <= 8
+            _test_json_schema_parsing_with_string(base_answer.replace("$", " " * num_spaces), schema, expect_success)
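
As a possible alternative to the hand-written `_temp_replace_env_var` helper, the standard library's `unittest.mock.patch.dict` can scope environment changes to a `with` block and restore the previous state on exit. This is a sketch of that design choice, not what the commit does; only the variable name is taken from the diff.

```python
import os
from unittest import mock

VAR = "LMFE_MAX_CONSECUTIVE_WHITESPACES"

# Remember the starting state so the restore can be checked afterwards.
original = os.environ.get(VAR)

# Temporarily set the variable.
with mock.patch.dict(os.environ, {VAR: "8"}):
    inside = os.environ[VAR]

# Temporarily make sure the variable is unset.
with mock.patch.dict(os.environ):
    os.environ.pop(VAR, None)
    cleared = VAR not in os.environ

# Outside both blocks the environment is back to its starting state.
restored = os.environ.get(VAR)
```

The second form covers the "unset for the duration of the test" case that the helper handles with `temp_value=None`.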
