diff --git a/lmformatenforcer/integrations/transformers.py b/lmformatenforcer/integrations/transformers.py
index 3dcfa60..0bdfd2f 100644
--- a/lmformatenforcer/integrations/transformers.py
+++ b/lmformatenforcer/integrations/transformers.py
@@ -1,5 +1,5 @@
 import functools
-from typing import Any, Callable, List, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 try:
     from transformers import AutoModelForCausalLM
     from transformers.generation.logits_process import LogitsWarper, PrefixConstrainedLogitsProcessor
@@ -53,10 +53,10 @@ def unreplace_logits_warper(self):
         self.model._get_logits_warper = self.old_warper
 
 
-def _build_regular_tokens_list(tokenizer: PreTrainedTokenizerBase) -> List[Tuple[int, str, bool]]:
+def _build_regular_tokens_list(tokenizer: PreTrainedTokenizerBase, vocab_size: int) -> List[Tuple[int, str, bool]]:
     token_0 = tokenizer.encode("0")[-1]
     regular_tokens = []
-    for token_idx in range(len(tokenizer)):
+    for token_idx in range(vocab_size):
         if token_idx in tokenizer.all_special_ids:
             continue
         # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
@@ -73,8 +73,10 @@ def _decode_function(tokenizer: PreTrainedTokenizerBase, tokens: List[int]) -> s
     return cleaned
 
 
-def build_token_enforcer_tokenizer_data(tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData:
-    regular_tokens = _build_regular_tokens_list(tokenizer)
+def build_token_enforcer_tokenizer_data(tokenizer: PreTrainedTokenizerBase,
+                                        vocab_size: Optional[int] = None) -> TokenEnforcerTokenizerData:
+    vocab_size = vocab_size or len(tokenizer)
+    regular_tokens = _build_regular_tokens_list(tokenizer, vocab_size)
     decode_fn = functools.partial(_decode_function, tokenizer)
     return TokenEnforcerTokenizerData(regular_tokens, decode_fn, tokenizer.eos_token_id)
 
diff --git a/lmformatenforcer/integrations/vllm.py b/lmformatenforcer/integrations/vllm.py
index b458757..94f6c2c 100644
--- a/lmformatenforcer/integrations/vllm.py
+++ b/lmformatenforcer/integrations/vllm.py
@@ -34,13 +34,16 @@ def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
 
 def build_vllm_token_enforcer_tokenizer_data(tokenizer: Union[vllm.LLM, PreTrainedTokenizerBase]) -> TokenEnforcerTokenizerData:
     # There are many classes that can be passed here, this logic should work on all of them.
+    vocab_size = None
+    if hasattr(tokenizer, 'llm_engine'):
+        vocab_size = tokenizer.llm_engine.get_model_config().get_vocab_size()
     if hasattr(tokenizer, 'get_tokenizer'):
         tokenizer = tokenizer.get_tokenizer()
     if isinstance(tokenizer, MistralTokenizer):
-        return build_token_enforcer_tokenizer_data(tokenizer)
+        return build_token_enforcer_tokenizer_data(tokenizer, vocab_size)
     if hasattr(tokenizer, 'tokenizer'):
         tokenizer = tokenizer.tokenizer
-    return build_token_enforcer_tokenizer_data(tokenizer)
+    return build_token_enforcer_tokenizer_data(tokenizer, vocab_size)
 
 
 def build_vllm_logits_processor(llm: Union[vllm.LLM, PreTrainedTokenizerBase, TokenEnforcerTokenizerData],
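
For context, a minimal usage sketch of the new parameter (not part of this diff; the model id is a placeholder): when the model's logits dimension is larger than `len(tokenizer)`, for example because the embedding table is padded, the caller can now pass the model's vocabulary size explicitly so the regular-token list lines up with the logits.

```python
# Hypothetical usage sketch: pass an explicit vocab_size when the model's
# logits width differs from len(tokenizer) (e.g. padded embeddings).
from transformers import AutoModelForCausalLM, AutoTokenizer
from lmformatenforcer.integrations.transformers import build_token_enforcer_tokenizer_data

model_id = "some-org/some-model"  # placeholder model id
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# model.config.vocab_size reflects the logits dimension; len(tokenizer)
# may be smaller when the embedding matrix is padded past the tokenizer.
tokenizer_data = build_token_enforcer_tokenizer_data(
    tokenizer, vocab_size=model.config.vocab_size)
```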
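
On the vLLM side, a sketch of how the new path is exercised (again with a placeholder model id): when a `vllm.LLM` is passed, the vocab size is read from the engine's model config rather than inferred from the tokenizer, and threaded through to `build_token_enforcer_tokenizer_data`.

```python
# Hypothetical sketch: with a vllm.LLM, the vocab size now comes from
# llm.llm_engine.get_model_config().get_vocab_size() per this diff.
import vllm
from lmformatenforcer.integrations.vllm import build_vllm_token_enforcer_tokenizer_data

llm = vllm.LLM(model="some-org/some-model")  # placeholder model id
tokenizer_data = build_vllm_token_enforcer_tokenizer_data(llm)
```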