From bc5d899aefb22b3554a48083d23f045db38b2698 Mon Sep 17 00:00:00 2001 From: Noam Gat Date: Wed, 1 Nov 2023 15:39:49 +0200 Subject: [PATCH] Refactor: Extracted analysis logic from transformers integration so it can be used in other inference engines --- lmformatenforcer/__init__.py | 9 ++- lmformatenforcer/analyzer.py | 70 ++++++++++++++++++ lmformatenforcer/integrations/transformers.py | 73 +++++++------------ lmformatenforcer/tokenenforcer.py | 1 + pyproject.toml | 1 + 5 files changed, 106 insertions(+), 48 deletions(-) create mode 100644 lmformatenforcer/analyzer.py diff --git a/lmformatenforcer/__init__.py b/lmformatenforcer/__init__.py index db8857c..85510e7 100644 --- a/lmformatenforcer/__init__.py +++ b/lmformatenforcer/__init__.py @@ -3,10 +3,17 @@ 'RegexParser', 'JsonSchemaParser', 'TokenEnforcer', - 'LMFormatEnforcerException'] + 'LMFormatEnforcerException', + 'FormatEnforcerAnalyzer',] from .characterlevelparser import CharacterLevelParser, StringParser from .regexparser import RegexParser from .jsonschemaparser import JsonSchemaParser from .tokenenforcer import TokenEnforcer from .exceptions import LMFormatEnforcerException +try: + from .analyzer import FormatEnforcerAnalyzer +except ImportError as e: + import logging + logging.warning(e) + FormatEnforcerAnalyzer = None diff --git a/lmformatenforcer/analyzer.py b/lmformatenforcer/analyzer.py new file mode 100644 index 0000000..f7decf1 --- /dev/null +++ b/lmformatenforcer/analyzer.py @@ -0,0 +1,70 @@ +from typing import Dict, Hashable, List +try: + import numpy as np + import numpy.typing as npt +except ImportError as e: + raise ImportError('numpy is not installed. FormatEnforcerAnalyzer will not be available') from e +from . 
import TokenEnforcer
+
+
+class FormatEnforcerAnalyzer:
+    """A helper class to help analyze the format enforcer's behavior."""
+    def __init__(self, token_enforcer: TokenEnforcer):
+        self.token_enforcer = token_enforcer
+        self.raw_logits: Dict[Hashable, npt.ArrayLike] = {}
+
+    def report_raw_logits(self, output_tokens: List[int], logits: npt.ArrayLike):
+        """Report what logits were generated for a specific token sequence. The logits must be before any processing / filtering."""
+        self.raw_logits[tuple(output_tokens)] = logits
+
+    def generate_report_dict(self, output_tokens: List[int]) -> dict:
+        """Generate a report dict containing the analysis results."""
+        scores_matrix: List[npt.ArrayLike] = []
+        allowed_tokens_matrix: List[List[int]] = []
+        for idx in range(len(output_tokens)):
+            prefix = output_tokens[:idx]
+            prefix_tuple = tuple(prefix)
+            if prefix_tuple in self.raw_logits:
+                scores_matrix.append(self.raw_logits[prefix_tuple])
+                allowed_tokens_matrix.append(self.token_enforcer.get_allowed_tokens(prefix))
+
+        logits = np.array(scores_matrix) # n_tokens * vocab_size
+        softmax_logits = _softmax(logits) # n_tokens * vocab_size
+        original_indices = softmax_logits.argmax(axis=1) # n_tokens
+        original_scores = _select_array(softmax_logits, original_indices) # n_tokens
+
+        single_token_dict: Dict[int, str] = dict(self.token_enforcer.regular_tokens)
+        def single_token_decoder(token_id: int) -> str:
+            if token_id in single_token_dict:
+                return single_token_dict[token_id]
+            return self.token_enforcer.decoder([token_id])
+
+        original_tokens = [single_token_decoder(idx) for idx in original_indices]
+
+        penalty_matrix = np.full_like(softmax_logits, -np.inf)
+        for row in range(penalty_matrix.shape[0]):
+            penalty_matrix[row][allowed_tokens_matrix[row]] = 0
+        enforced_softmax_logits = softmax_logits + penalty_matrix
+
+        enforced_indices = enforced_softmax_logits.argmax(axis=1)
+        enforced_scores = _select_array(enforced_softmax_logits, enforced_indices)
+
+        enforced_tokens = 
[single_token_decoder(idx) for idx in enforced_indices] + df_dict = {} # In order to minimize the package's dependencies, we don't create a dataframe, but create a dataframe-like dictionary instead. + df_dict['generated_token'] = enforced_tokens + df_dict['generated_token_idx'] = enforced_indices.tolist() + df_dict['generated_score'] = enforced_scores.tolist() + df_dict['leading_token'] = original_tokens + df_dict['leading_token_idx'] = original_indices.tolist() + df_dict['leading_score'] = original_scores.tolist() + + return df_dict + +def _softmax(arr: np.ndarray) -> np.ndarray: + """Compute softmax values for each sets of scores in arr.""" + e_arr = np.exp(arr) + return e_arr / np.sum(e_arr, axis=1, keepdims=True) + +def _select_array(arr: np.ndarray, index_array: np.ndarray) -> np.ndarray: + # https://numpy.org/doc/stable/reference/generated/numpy.argmax.html + return np.take_along_axis(arr, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1) \ No newline at end of file diff --git a/lmformatenforcer/integrations/transformers.py b/lmformatenforcer/integrations/transformers.py index db5e79a..1a50d8e 100644 --- a/lmformatenforcer/integrations/transformers.py +++ b/lmformatenforcer/integrations/transformers.py @@ -1,10 +1,10 @@ -from typing import Callable, List, Tuple, Union +from typing import Any, Callable, List, Tuple, Union try: from transformers import AutoModelForCausalLM from transformers.generation.logits_process import LogitsWarper, PrefixConstrainedLogitsProcessor from transformers.tokenization_utils import PreTrainedTokenizerBase except ImportError: - raise ImportError('transformers is not installed. Please install it with "pip install transformers"') + raise ImportError('transformers is not installed. 
Please install it with "pip install transformers[torch]"') try: import torch @@ -13,30 +13,34 @@ from ..characterlevelparser import CharacterLevelParser from ..tokenenforcer import TokenEnforcer +from ..analyzer import FormatEnforcerAnalyzer class LogitsSaverWarper(LogitsWarper): - def __init__(self) -> None: - self.scores = [] + def __init__(self, analyzer: FormatEnforcerAnalyzer) -> None: + self.analyzer = analyzer def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - cpu_scores = scores.clone().detach() - self.scores.append(cpu_scores) + cpu_inputs = input_ids.tolist() + cpu_scores = scores.tolist() + for single_batch_inputs, single_batch_scores in zip(cpu_inputs, cpu_scores): + self.analyzer.report_raw_logits(single_batch_inputs, single_batch_scores) return scores class LogitsSaverManager: warper: LogitsSaverWarper - def __init__(self, model: AutoModelForCausalLM): + def __init__(self, model: AutoModelForCausalLM, analyzer: FormatEnforcerAnalyzer): self.model = model self.warper = None self.old_warper = None + self.analyzer = analyzer def replace_logits_warper(self, filter_func = None): self.old_warper = self.model._get_logits_warper def new_logits_warper(generation_config): warpers = self.old_warper(generation_config) - self.warper = LogitsSaverWarper() + self.warper = LogitsSaverWarper(self.analyzer) warpers.insert(0, self.warper) if filter_func is not None: processor = PrefixConstrainedLogitsProcessor(filter_func, 1) @@ -47,22 +51,6 @@ def new_logits_warper(generation_config): def unreplace_logits_warper(self): self.model._get_logits_warper = self.old_warper - def get_generated_scores(self, token_sequence): - relevant_tokens = token_sequence[-len(self.warper.scores):] - scores_matrix = torch.concat(self.warper.scores) # n_tokens * vocab_size - probs = torch.softmax(scores_matrix, dim=1) # n_tokens * vocab_size - token_probs = probs[torch.arange(probs.size(0)), relevant_tokens] # n_tokens - return 
token_probs.to('cpu').tolist() - - def get_leading_scores(self) -> Tuple[List[int], List[float]]: - scores_matrix = torch.concat(self.warper.scores) # n_tokens * vocab_size - probs = torch.softmax(scores_matrix, dim=1) # n_tokens * vocab_size - best_tokens = torch.argmax(scores_matrix, dim=1) - token_probs = probs[torch.arange(probs.size(0)), best_tokens] # n_tokens - token_probs_list = token_probs.to('cpu').tolist() - return best_tokens.tolist(), token_probs_list - - def _build_regular_tokens_list(tokenizer: PreTrainedTokenizerBase) -> List[Tuple[int, str]]: token_0 = tokenizer.encode("0")[-1] regular_tokens = [] @@ -75,18 +63,22 @@ def _build_regular_tokens_list(tokenizer: PreTrainedTokenizerBase) -> List[Tuple return regular_tokens +class TransformersPrefixAllowedTokensFn: + def __init__(self, token_enforcer: TokenEnforcer): + self.token_enforcer = token_enforcer + + def __call__(self, batch_id: int, sent: torch.Tensor) -> List[int]: + token_sequence = sent.tolist() + return self.token_enforcer.get_allowed_tokens(token_sequence) + + def build_transformers_prefix_allowed_tokens_fn(tokenizer: PreTrainedTokenizerBase, - character_level_parser: CharacterLevelParser) -> Callable[[int, torch.Tensor], List[int]]: + character_level_parser: CharacterLevelParser) -> TransformersPrefixAllowedTokensFn: """Build the prefix allowed tokens function that transformers will use to filter the tokens generated by the model. 
The result can be passed to the prefix_allowed_tokens_fn parameter of the generate() method of transformers models or pipeline configurations.""" regular_tokens = _build_regular_tokens_list(tokenizer) token_enforcer = TokenEnforcer(regular_tokens, character_level_parser, tokenizer.decode, tokenizer.eos_token_id) - - def transformers_filter_allowed_tokens(batch_id: int, sent: torch.Tensor) -> List[int]: - token_sequence = sent.tolist() - return token_enforcer.get_allowed_tokens(token_sequence) - - return transformers_filter_allowed_tokens + return TransformersPrefixAllowedTokensFn(token_enforcer) def generate_enforced(model: AutoModelForCausalLM, @@ -111,7 +103,8 @@ def generate_enforced(model: AutoModelForCausalLM, should_run_in_advanced_mode = return_dict_in_generate and output_scores and support_diagnostics if should_run_in_advanced_mode: - logits_saver = LogitsSaverManager(model) + analyzer = FormatEnforcerAnalyzer(transformers_filter_allowed_tokens.token_enforcer) + logits_saver = LogitsSaverManager(model, analyzer) logits_saver.replace_logits_warper(transformers_filter_allowed_tokens) generate_kwargs = kwargs @@ -120,21 +113,7 @@ def generate_enforced(model: AutoModelForCausalLM, finally: logits_saver.unreplace_logits_warper() - sequence = output.sequences[0] if return_dict_in_generate else output[0] - sequence = output.sequences[0] - generated_scores = logits_saver.get_generated_scores(sequence) - generated_tokens = sequence[-len(generated_scores):].to('cpu').tolist() - single_token_strs = [tokenizer.convert_ids_to_tokens([token], skip_special_tokens=False)[0] for token in generated_tokens] - - leading_tokens, leading_scores = logits_saver.get_leading_scores() - leading_token_strs = [tokenizer.convert_ids_to_tokens([token], skip_special_tokens=False)[0] for token in leading_tokens] - df_dict = {} # In order to minimize the package's dependencies, we don't create a dataframe, but create a dataframe-like dictionary instead. 
- df_dict['generated_token'] = single_token_strs - df_dict['generated_token_idx'] = generated_tokens - df_dict['generated_score'] = generated_scores - df_dict['leading_token'] = leading_token_strs - df_dict['leading_token_idx'] = leading_tokens - df_dict['leading_score'] = leading_scores + df_dict = analyzer.generate_report_dict(output['sequences'][0].tolist()) output.enforced_scores = df_dict else: output = model.generate(**kwargs, prefix_allowed_tokens_fn=transformers_filter_allowed_tokens) diff --git a/lmformatenforcer/tokenenforcer.py b/lmformatenforcer/tokenenforcer.py index 5558315..ef461a2 100644 --- a/lmformatenforcer/tokenenforcer.py +++ b/lmformatenforcer/tokenenforcer.py @@ -32,6 +32,7 @@ def __init__(self, regular_tokens: List[Tuple[int, str]], self.decoder = decoder self.eos_token_id = eos_token_id self.allowed_token_cache: Dict[Hashable, List[int]] = {} + self.regular_tokens = regular_tokens def get_allowed_tokens(self, token_sequence: List[int]) -> List[int]: """ diff --git a/pyproject.toml b/pyproject.toml index 75a926f..491aeab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ pytest = {version = "6.2.5", python = ">=3.8"} coverage = {version = "^7.3.1", python = ">=3.8", extras = ["toml"]} transformers = ">=4.28.1" torch = {version = "^2.1.0+cpu", source = "pytorch"} +numpy = "^1.21.0" [tool.poetry.group.samples.dependencies] Flask = {version = "2.3.2", python = ">=3.8"}