
Commit

Refactor: Extracted analysis logic from transformers integration so it can be used in other inference engines
noamgat committed Nov 1, 2023
1 parent a9b7f24 commit bc5d899
Showing 5 changed files with 106 additions and 48 deletions.
9 changes: 8 additions & 1 deletion lmformatenforcer/__init__.py
@@ -3,10 +3,17 @@
'RegexParser',
'JsonSchemaParser',
'TokenEnforcer',
'LMFormatEnforcerException']
'LMFormatEnforcerException',
'FormatEnforcerAnalyzer',]

from .characterlevelparser import CharacterLevelParser, StringParser
from .regexparser import RegexParser
from .jsonschemaparser import JsonSchemaParser
from .tokenenforcer import TokenEnforcer
from .exceptions import LMFormatEnforcerException
try:
from .analyzer import FormatEnforcerAnalyzer
except ImportError as e:
import logging
logging.warning(e)
FormatEnforcerAnalyzer = None
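
Because the analyzer import above is optional, the exported name is None when its dependencies are missing. A minimal, illustrative availability check for downstream code (the message text is just a placeholder):

from lmformatenforcer import FormatEnforcerAnalyzer

if FormatEnforcerAnalyzer is None:
    # The optional import failed (e.g. numpy is not installed), so diagnostics are unavailable.
    print("FormatEnforcerAnalyzer unavailable; install numpy to enable analysis reports")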
70 changes: 70 additions & 0 deletions lmformatenforcer/analyzer.py
@@ -0,0 +1,70 @@
from typing import Dict, Hashable, List
try:
import numpy as np
import numpy.typing as npt
except ImportError as e:
raise ImportError('numpy is not installed. FormatEnforcerAnalyzer will not be available') from e
from . import TokenEnforcer


class FormatEnforcerAnalyzer:
"""A helper class to help analyze the format enforcer's behavior."""
def __init__(self, token_enforcer: TokenEnforcer):
self.token_enforcer = token_enforcer
self.raw_logits: Dict[Hashable, npt.ArrayLike] = {}

def report_raw_logits(self, output_tokens: List[int], logits: npt.ArrayLike):
"""Report what logits were generated for a specific token sequence. The logits must be before any processing / filtering."""
self.raw_logits[tuple(output_tokens)] = logits

def generate_report_dict(self, output_tokens: List[int]) -> dict:
"""Generate a report dict containing the analysis results."""
scores_matrix: List[npt.ArrayLike] = []
allowed_tokens_matrix: List[List[int]] = []
for idx in range(len(output_tokens)):
prefix = output_tokens[:idx]
prefix_tuple = tuple(prefix)
if prefix_tuple in self.raw_logits:
scores_matrix.append(self.raw_logits[prefix_tuple])
allowed_tokens_matrix.append(self.token_enforcer.get_allowed_tokens(prefix))

logits = np.array(scores_matrix) # n_tokens * vocab_size
softmax_logits = _softmax(logits) # n_tokens * vocab_size
original_indices = softmax_logits.argmax(axis=1) # n_tokens
original_scores = _select_array(softmax_logits, original_indices) # n_tokens

single_token_dict: Dict[int, str] = dict(self.token_enforcer.regular_tokens)
def single_token_decoder(token_id: int) -> str:
if token_id in single_token_dict:
return single_token_dict[token_id]
return self.token_enforcer.decoder([token_id])

original_tokens = [single_token_decoder(idx) for idx in original_indices]

penalty_matrix = np.full_like(softmax_logits, -np.inf)
for row in range(penalty_matrix.shape[0]):
penalty_matrix[row][allowed_tokens_matrix[row]] = 0
        enforced_softmax_logits = softmax_logits + penalty_matrix

        enforced_indices = enforced_softmax_logits.argmax(axis=1)
        enforced_scores = _select_array(enforced_softmax_logits, enforced_indices)

enforced_tokens = [single_token_decoder(idx) for idx in enforced_indices]
df_dict = {} # In order to minimize the package's dependencies, we don't create a dataframe, but create a dataframe-like dictionary instead.
df_dict['generated_token'] = enforced_tokens
df_dict['generated_token_idx'] = enforced_indices.tolist()
df_dict['generated_score'] = enforced_scores.tolist()
df_dict['leading_token'] = original_tokens
df_dict['leading_token_idx'] = original_indices.tolist()
df_dict['leading_score'] = original_scores.tolist()

return df_dict

def _softmax(arr: np.ndarray) -> np.ndarray:
"""Compute softmax values for each sets of scores in arr."""
e_arr = np.exp(arr)
return e_arr / np.sum(e_arr, axis=1, keepdims=True)

def _select_array(arr: np.ndarray, index_array: np.ndarray) -> np.ndarray:
# https://numpy.org/doc/stable/reference/generated/numpy.argmax.html
return np.take_along_axis(arr, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1)
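
The class above is deliberately engine-agnostic, which is the point of this commit: any inference loop that can expose its raw, pre-filtering logits can feed the analyzer. Below is a minimal sketch of such a loop, assuming a TokenEnforcer has already been built for the engine's tokenizer; my_model_forward and sample_from are hypothetical placeholders, not part of this library.

from typing import List

from lmformatenforcer import FormatEnforcerAnalyzer, TokenEnforcer


def generate_with_diagnostics(token_enforcer: TokenEnforcer,
                              prompt_tokens: List[int],
                              max_new_tokens: int = 100) -> dict:
    analyzer = FormatEnforcerAnalyzer(token_enforcer)
    output_tokens = list(prompt_tokens)
    for _ in range(max_new_tokens):
        # Raw, unfiltered logits for the current prefix (engine-specific call, hypothetical here).
        logits = my_model_forward(output_tokens)
        # Report them before any masking, keyed by the exact token prefix.
        analyzer.report_raw_logits(output_tokens, logits)
        # Restrict sampling to the tokens the format enforcer allows (sampling helper is hypothetical).
        allowed = token_enforcer.get_allowed_tokens(output_tokens)
        next_token = sample_from(logits, allowed)
        output_tokens.append(next_token)
        if next_token == token_enforcer.eos_token_id:
            break
    # Dataframe-like dict comparing the enforced choice with the unconstrained
    # "leading" token at every generated position.
    return analyzer.generate_report_dict(output_tokens)

The returned dict has the columns built above: generated_token, generated_token_idx, generated_score, leading_token, leading_token_idx, and leading_score.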
73 changes: 26 additions & 47 deletions lmformatenforcer/integrations/transformers.py
@@ -1,10 +1,10 @@
from typing import Callable, List, Tuple, Union
from typing import Any, Callable, List, Tuple, Union
try:
from transformers import AutoModelForCausalLM
from transformers.generation.logits_process import LogitsWarper, PrefixConstrainedLogitsProcessor
from transformers.tokenization_utils import PreTrainedTokenizerBase
except ImportError:
raise ImportError('transformers is not installed. Please install it with "pip install transformers"')
raise ImportError('transformers is not installed. Please install it with "pip install transformers[torch]"')

try:
import torch
@@ -13,30 +13,34 @@

from ..characterlevelparser import CharacterLevelParser
from ..tokenenforcer import TokenEnforcer
from ..analyzer import FormatEnforcerAnalyzer

class LogitsSaverWarper(LogitsWarper):
def __init__(self) -> None:
self.scores = []
def __init__(self, analyzer: FormatEnforcerAnalyzer) -> None:
self.analyzer = analyzer

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cpu_scores = scores.clone().detach()
self.scores.append(cpu_scores)
cpu_inputs = input_ids.tolist()
cpu_scores = scores.tolist()
for single_batch_inputs, single_batch_scores in zip(cpu_inputs, cpu_scores):
self.analyzer.report_raw_logits(single_batch_inputs, single_batch_scores)
return scores

class LogitsSaverManager:
warper: LogitsSaverWarper

def __init__(self, model: AutoModelForCausalLM):
def __init__(self, model: AutoModelForCausalLM, analyzer: FormatEnforcerAnalyzer):
self.model = model
self.warper = None
self.old_warper = None
self.analyzer = analyzer

def replace_logits_warper(self, filter_func = None):
self.old_warper = self.model._get_logits_warper

def new_logits_warper(generation_config):
warpers = self.old_warper(generation_config)
self.warper = LogitsSaverWarper()
self.warper = LogitsSaverWarper(self.analyzer)
warpers.insert(0, self.warper)
if filter_func is not None:
processor = PrefixConstrainedLogitsProcessor(filter_func, 1)
@@ -47,22 +51,6 @@ def new_logits_warper(generation_config):
def unreplace_logits_warper(self):
self.model._get_logits_warper = self.old_warper

def get_generated_scores(self, token_sequence):
relevant_tokens = token_sequence[-len(self.warper.scores):]
scores_matrix = torch.concat(self.warper.scores) # n_tokens * vocab_size
probs = torch.softmax(scores_matrix, dim=1) # n_tokens * vocab_size
token_probs = probs[torch.arange(probs.size(0)), relevant_tokens] # n_tokens
return token_probs.to('cpu').tolist()

def get_leading_scores(self) -> Tuple[List[int], List[float]]:
scores_matrix = torch.concat(self.warper.scores) # n_tokens * vocab_size
probs = torch.softmax(scores_matrix, dim=1) # n_tokens * vocab_size
best_tokens = torch.argmax(scores_matrix, dim=1)
token_probs = probs[torch.arange(probs.size(0)), best_tokens] # n_tokens
token_probs_list = token_probs.to('cpu').tolist()
return best_tokens.tolist(), token_probs_list


def _build_regular_tokens_list(tokenizer: PreTrainedTokenizerBase) -> List[Tuple[int, str]]:
token_0 = tokenizer.encode("0")[-1]
regular_tokens = []
@@ -75,18 +63,22 @@ def _build_regular_tokens_list(tokenizer: PreTrainedTokenizerBase) -> List[Tuple
return regular_tokens


class TransformersPrefixAllowedTokensFn:
def __init__(self, token_enforcer: TokenEnforcer):
self.token_enforcer = token_enforcer

def __call__(self, batch_id: int, sent: torch.Tensor) -> List[int]:
token_sequence = sent.tolist()
return self.token_enforcer.get_allowed_tokens(token_sequence)


def build_transformers_prefix_allowed_tokens_fn(tokenizer: PreTrainedTokenizerBase,
character_level_parser: CharacterLevelParser) -> Callable[[int, torch.Tensor], List[int]]:
character_level_parser: CharacterLevelParser) -> TransformersPrefixAllowedTokensFn:
"""Build the prefix allowed tokens function that transformers will use to filter the tokens generated by the model. The result
can be passed to the prefix_allowed_tokens_fn parameter of the generate() method of transformers models or pipeline configurations."""
regular_tokens = _build_regular_tokens_list(tokenizer)
token_enforcer = TokenEnforcer(regular_tokens, character_level_parser, tokenizer.decode, tokenizer.eos_token_id)

def transformers_filter_allowed_tokens(batch_id: int, sent: torch.Tensor) -> List[int]:
token_sequence = sent.tolist()
return token_enforcer.get_allowed_tokens(token_sequence)

return transformers_filter_allowed_tokens
return TransformersPrefixAllowedTokensFn(token_enforcer)


def generate_enforced(model: AutoModelForCausalLM,
@@ -111,7 +103,8 @@ def generate_enforced(model: AutoModelForCausalLM,
should_run_in_advanced_mode = return_dict_in_generate and output_scores and support_diagnostics

if should_run_in_advanced_mode:
logits_saver = LogitsSaverManager(model)
analyzer = FormatEnforcerAnalyzer(transformers_filter_allowed_tokens.token_enforcer)
logits_saver = LogitsSaverManager(model, analyzer)
logits_saver.replace_logits_warper(transformers_filter_allowed_tokens)
generate_kwargs = kwargs

@@ -120,21 +113,7 @@
finally:
logits_saver.unreplace_logits_warper()

sequence = output.sequences[0] if return_dict_in_generate else output[0]
sequence = output.sequences[0]
generated_scores = logits_saver.get_generated_scores(sequence)
generated_tokens = sequence[-len(generated_scores):].to('cpu').tolist()
single_token_strs = [tokenizer.convert_ids_to_tokens([token], skip_special_tokens=False)[0] for token in generated_tokens]

leading_tokens, leading_scores = logits_saver.get_leading_scores()
leading_token_strs = [tokenizer.convert_ids_to_tokens([token], skip_special_tokens=False)[0] for token in leading_tokens]
df_dict = {} # In order to minimize the package's dependencies, we don't create a dataframe, but create a dataframe-like dictionary instead.
df_dict['generated_token'] = single_token_strs
df_dict['generated_token_idx'] = generated_tokens
df_dict['generated_score'] = generated_scores
df_dict['leading_token'] = leading_token_strs
df_dict['leading_token_idx'] = leading_tokens
df_dict['leading_score'] = leading_scores
df_dict = analyzer.generate_report_dict(output['sequences'][0].tolist())
output.enforced_scores = df_dict
else:
output = model.generate(**kwargs, prefix_allowed_tokens_fn=transformers_filter_allowed_tokens)
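
For the transformers path itself, the fully shown build_transformers_prefix_allowed_tokens_fn is the simplest entry point. A hedged usage sketch based on its docstring; the model id and JSON schema below are illustrative placeholders only:

from transformers import AutoModelForCausalLM, AutoTokenizer

from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn

model_id = "gpt2"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Illustrative schema: force a JSON object with a string "name" field.
parser = JsonSchemaParser({"type": "object", "properties": {"name": {"type": "string"}}})
prefix_fn = build_transformers_prefix_allowed_tokens_fn(tokenizer, parser)

inputs = tokenizer("Generate a JSON object: ", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=50,
                            prefix_allowed_tokens_fn=prefix_fn)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

The generate_enforced wrapper in the same module follows the same path internally; per the diff above, when it is called with return_dict_in_generate=True and output_scores=True it attaches the analyzer's report dict to the result as output.enforced_scores.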
1 change: 1 addition & 0 deletions lmformatenforcer/tokenenforcer.py
@@ -32,6 +32,7 @@ def __init__(self, regular_tokens: List[Tuple[int, str]],
self.decoder = decoder
self.eos_token_id = eos_token_id
self.allowed_token_cache: Dict[Hashable, List[int]] = {}
self.regular_tokens = regular_tokens

def get_allowed_tokens(self, token_sequence: List[int]) -> List[int]:
"""
1 change: 1 addition & 0 deletions pyproject.toml
@@ -57,6 +57,7 @@ pytest = {version = "6.2.5", python = ">=3.8"}
coverage = {version = "^7.3.1", python = ">=3.8", extras = ["toml"]}
transformers = ">=4.28.1"
torch = {version = "^2.1.0+cpu", source = "pytorch"}
numpy = "^1.21.0"

[tool.poetry.group.samples.dependencies]
Flask = {version = "2.3.2", python = ">=3.8"}
