Skip to content

Commit

Permalink
Llama.cpp integration also supports analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
noamgat committed Nov 1, 2023
1 parent bc5d899 commit da7b722
Show file tree
Hide file tree
Showing 3 changed files with 769 additions and 36 deletions.
13 changes: 10 additions & 3 deletions lmformatenforcer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@
import numpy as np
import numpy.typing as npt
except ImportError as e:
raise ImportError('numpy is not installed. FormatEnforcerAnalyzer will not be available') from e
from . import TokenEnforcer
class FormatEnforcerAnalyzer: # type: ignore
def __init__(self, *args, **kwargs):
pass
def report_raw_logits(self, *args, **kwargs):
pass
def generate_report_dict(self, *args, **kwargs):
return {}
raise ImportError('FormatEnforcerAnalyzer not available because numpy is not installed. Please install it with "pip install numpy"') from e

from . import TokenEnforcer

class FormatEnforcerAnalyzer:
"""A helper class to help analyze the format enforcer's behavior."""
Expand All @@ -18,7 +25,7 @@ def report_raw_logits(self, output_tokens: List[int], logits: npt.ArrayLike):
self.raw_logits[tuple(output_tokens)] = logits

def generate_report_dict(self, output_tokens: List[int]) -> dict:
"""Generate a report dict containing the analysis results."""
"""Generate a report dict containing the analysis results for a specific output token sequence."""
scores_matrix: List[npt.ArrayLike] = []
allowed_tokens_matrix: List[List[int]] = []
for idx in range(len(output_tokens)):
Expand Down
28 changes: 17 additions & 11 deletions lmformatenforcer/integrations/llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from llama_cpp import Llama, LogitsProcessor
except ImportError:
raise ImportError('llama-cpp-python is not installed. Please install it with "pip install llama-cpp-python"')
from lmformatenforcer import CharacterLevelParser, TokenEnforcer
from lmformatenforcer import CharacterLevelParser, TokenEnforcer, FormatEnforcerAnalyzer
import numpy as np
import numpy.typing as npt
from typing import Tuple, List
Expand All @@ -24,23 +24,29 @@ def _build_regular_tokens_list(llm: Llama) -> List[Tuple[int, str]]:
return regular_tokens


class LlamaCppLogitsProcessor:
    """Callable logits processor for llama.cpp.

    Masks out every token the TokenEnforcer does not allow at the current
    step by setting its score to -inf, and optionally records the raw
    (pre-mask) logits so a FormatEnforcerAnalyzer can report on them later.
    """

    def __init__(self, token_enforcer: TokenEnforcer, analyze):
        # Enforcer that decides which token ids are legal after a given prefix.
        self.token_enforcer = token_enforcer
        # Analyzer is only instantiated when analysis was requested; otherwise
        # the __call__ path skips reporting entirely.
        self.analyzer = FormatEnforcerAnalyzer(token_enforcer) if analyze else None

    def __call__(self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]) -> npt.NDArray[np.single]:
        """Return `scores` with all disallowed tokens set to -inf (in place)."""
        generated = input_ids.tolist()
        if self.analyzer:
            # Report the untouched logits before any masking happens.
            self.analyzer.report_raw_logits(generated, scores.tolist())
        permitted = self.token_enforcer.get_allowed_tokens(generated)
        # Boolean mask: start with everything blocked, then unblock the
        # permitted token ids and knock out the rest.
        blocked = np.full(scores.shape, True)
        blocked[permitted] = False
        scores[blocked] = float('-inf')
        return scores
def build_llamacpp_logits_processor(llm: Llama, character_level_parser: CharacterLevelParser, analyze: bool=False) -> LlamaCppLogitsProcessor:
    """Build the logits processor that llama.cpp will use to filter generated tokens.

    The returned object can be passed in the ``logits_processor`` list given to
    the call or generate() method of llama.cpp models. When ``analyze`` is True,
    the processor also records raw logits for later analysis.
    """
    vocabulary = _build_regular_tokens_list(llm)

    def _decode(token_ids: List[int]) -> str:
        # llama.cpp detokenizes to bytes; the enforcer operates on text.
        return llm.detokenize(token_ids).decode('utf-8')

    enforcer = TokenEnforcer(vocabulary, character_level_parser, _decode, llm.token_eos())
    return LlamaCppLogitsProcessor(enforcer, analyze)


# Public API of this integration module; the processor class itself is reached
# through the builder's return value.
__all__ = ['build_llamacpp_logits_processor']
Loading

0 comments on commit da7b722

Please sign in to comment.