diff --git a/CHANGELOG.md b/CHANGELOG.md
index ee06eba..0493d1d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,8 @@
 # LM Format Enforcer Changelog
 
+## v0.5.1
+- Made it easier to report bugs in the library
+
 ## v0.5.0
 - Introduced FormatEnforcerAnalyzer to allow all inference engines to be analyzed in a unified way. (Was previously only available for transformers)
 - Added support for the analyser in llama.cpp, updated example notebook
diff --git a/lmformatenforcer/jsonschemaparser.py b/lmformatenforcer/jsonschemaparser.py
index 41c953a..97ca2ef 100644
--- a/lmformatenforcer/jsonschemaparser.py
+++ b/lmformatenforcer/jsonschemaparser.py
@@ -102,7 +102,7 @@ def get_parser(
     ending_characters: str,
 ) -> BaseParsingState:
     if value_schema is None:
-        raise Exception("JsonSchemaParser: Value schema is None")
+        raise LMFormatEnforcerException("value schema is None. This may be a bug in the library, please open an issue at https://github.com/noamgat/lm-format-enforcer/issues")
     # Sometimes the schema is a union of a type and null, so we need to get the first type
     if value_schema.anyOf and len(value_schema.anyOf) == 2 and value_schema.anyOf[1].type == 'null':
         value_schema = value_schema.anyOf[0]
diff --git a/lmformatenforcer/tokenenforcer.py b/lmformatenforcer/tokenenforcer.py
index ef461a2..8d47c58 100644
--- a/lmformatenforcer/tokenenforcer.py
+++ b/lmformatenforcer/tokenenforcer.py
@@ -1,6 +1,8 @@
 from dataclasses import dataclass, field
 from typing import Callable, Dict, Hashable, List, Optional, Tuple
 import logging
+
+from .exceptions import LMFormatEnforcerException
 from .characterlevelparser import CharacterLevelParser, ForceStopParser
 from .tokenizerprefixtree import TokenizerPrefixTree, TokenizerPrefixTreeNode
 
@@ -66,22 +68,36 @@ def get_allowed_tokens(self, token_sequence: List[int]) -> List[int]:
         return new_state.allowed_tokens
 
     def _compute_allowed_tokens(self, state: 'TokenEnforcer.OutputTensorState'):
-        allowed_tokens: List[int] = []
-        cache_key = state.parser.cache_key()
-        if cache_key is not None and cache_key in self.allowed_token_cache:
-            state.allowed_tokens = self.allowed_token_cache[cache_key]
-            return
-        shortcut_key = state.parser.shortcut_key()
-        self._collect_allowed_tokens(state.parser, self.tokenizer_tree.root, allowed_tokens, shortcut_key)
-        if state.parser.can_end():
-            allowed_tokens.append(self.eos_token_id)
-        if not allowed_tokens:
-            raise ValueError(f"Parser reached state with no allowed tokens")
-        # root_state = next(state for state in self.prefix_states.values() if state.parser == self.root_parser)
-        # print(f"Allowing {len(allowed_tokens)} tokens after {state.str_so_far[len(root_state.str_so_far):]}")
-        state.allowed_tokens = allowed_tokens
-        if cache_key is not None:
-            self.allowed_token_cache[cache_key] = allowed_tokens
+        try:
+            allowed_tokens: List[int] = []
+            cache_key = state.parser.cache_key()
+            if cache_key is not None and cache_key in self.allowed_token_cache:
+                state.allowed_tokens = self.allowed_token_cache[cache_key]
+                return
+            shortcut_key = state.parser.shortcut_key()
+            self._collect_allowed_tokens(state.parser, self.tokenizer_tree.root, allowed_tokens, shortcut_key)
+            if state.parser.can_end():
+                allowed_tokens.append(self.eos_token_id)
+            if not allowed_tokens:
+                raise ValueError(f"Parser reached state with no allowed tokens")
+            # root_state = next(state for state in self.prefix_states.values() if state.parser == self.root_parser)
+            # print(f"Allowing {len(allowed_tokens)} tokens after {state.str_so_far[len(root_state.str_so_far):]}")
+            state.allowed_tokens = allowed_tokens
+            if cache_key is not None:
+                self.allowed_token_cache[cache_key] = allowed_tokens
+        except LMFormatEnforcerException:
+            # Getting an LMFormatEnforcerException means that we know what the user did wrong,
+            # and we can give a nice error message for them to fix.
+            raise
+        except Exception:
+            # Other exceptions are potential bugs and should be reported
+            root_state = next(state for state in self.prefix_states.values() if state.parser == self.root_parser)
+            characters_in_root_node = state.str_so_far[len(root_state.str_so_far):]
+            logging.exception(f"Unknown LMFormatEnforcer Problem. Prefix: '{characters_in_root_node}'\n"
+                              "Terminating the parser. Please open an issue at \n"
+                              "https://github.com/noamgat/lm-format-enforcer/issues with the prefix and "
+                              "CharacterLevelParser parameters")
+            state.allowed_tokens = [self.eos_token_id]
 
     def _collect_allowed_tokens(self, parser: CharacterLevelParser, tree_node: TokenizerPrefixTreeNode, allowed_tokens: List[int], shortcut_key: Optional[str]):
         allowed_tokens.extend(tree_node.tokens)
diff --git a/pyproject.toml b/pyproject.toml
index d8574b7..21eae0b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "lm-format-enforcer"
-version = "0.5.0"
+version = "0.5.1"
 description = "Enforce the output format (JSON Schema, Regex etc) of a language model"
 authors = ["Noam Gat "]
 license = "MIT"
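
In plain terms, the `tokenenforcer.py` change splits failures into two classes: an `LMFormatEnforcerException` signals a known, user-fixable problem and is re-raised with its explanatory message intact, while any other exception is treated as a potential library bug, logged together with the generated prefix and the issue-tracker URL, and then absorbed by allowing only the EOS token so generation terminates cleanly instead of crashing the inference engine. A minimal sketch of that contract follows; it is not part of the diff, and the helper name `compute_allowed_tokens_safely` and its parameters are illustrative stand-ins, not the library's API:

```python
import logging
from typing import Callable, List


class LMFormatEnforcerException(Exception):
    """Stand-in for lmformatenforcer.exceptions.LMFormatEnforcerException."""


def compute_allowed_tokens_safely(compute_fn: Callable[[], List[int]],
                                  prefix: str,
                                  eos_token_id: int) -> List[int]:
    try:
        return compute_fn()
    except LMFormatEnforcerException:
        # Known user-facing error: the message already explains the fix,
        # so it propagates unchanged to the caller.
        raise
    except Exception:
        # Potential library bug: record the generated prefix and stack trace
        # so they can be attached to a bug report, then terminate generation
        # gracefully by allowing only the EOS token.
        logging.exception("Unknown problem. Prefix: '%s'", prefix)
        return [eos_token_id]
```

Forcing EOS on unknown errors is the key design choice here: a bug in a parser degrades into a truncated output plus a logged, reportable stack trace, rather than an exception propagating into the middle of a generation loop.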