
Commit

Allow passing custom vocab_size to build_token_enforcer_tokenizer_data, allowing usage of models with smaller vocab sizes than their tokenizers (Llama 3.2 Vision models)
noamgat committed Oct 8, 2024
1 parent 0a174f7 commit c61f00c
Showing 2 changed files with 12 additions and 7 deletions.
12 changes: 7 additions & 5 deletions lmformatenforcer/integrations/transformers.py
@@ -1,5 +1,5 @@
 import functools
-from typing import Any, Callable, List, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 try:
     from transformers import AutoModelForCausalLM
     from transformers.generation.logits_process import LogitsWarper, PrefixConstrainedLogitsProcessor
@@ -53,10 +53,10 @@ def unreplace_logits_warper(self):
         self.model._get_logits_warper = self.old_warper


-def _build_regular_tokens_list(tokenizer: PreTrainedTokenizerBase) -> List[Tuple[int, str, bool]]:
+def _build_regular_tokens_list(tokenizer: PreTrainedTokenizerBase, vocab_size: int) -> List[Tuple[int, str, bool]]:
     token_0 = tokenizer.encode("0")[-1]
     regular_tokens = []
-    for token_idx in range(len(tokenizer)):
+    for token_idx in range(vocab_size):
         if token_idx in tokenizer.all_special_ids:
             continue
         # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
@@ -73,8 +73,10 @@ def _decode_function(tokenizer: PreTrainedTokenizerBase, tokens: List[int]) -> str:
     return cleaned


-def build_token_enforcer_tokenizer_data(tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData:
-    regular_tokens = _build_regular_tokens_list(tokenizer)
+def build_token_enforcer_tokenizer_data(tokenizer: PreTrainedTokenizerBase,
+                                        vocab_size: Optional[int] = None) -> TokenEnforcerTokenizerData:
+    vocab_size = vocab_size or len(tokenizer)
+    regular_tokens = _build_regular_tokens_list(tokenizer, vocab_size)
     decode_fn = functools.partial(_decode_function, tokenizer)
     return TokenEnforcerTokenizerData(regular_tokens, decode_fn, tokenizer.eos_token_id)

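As a usage sketch of the new argument (not part of this commit; the checkpoint name and the config.text_config.vocab_size attribute path are assumptions), a caller working with a model whose language head emits fewer logits than len(tokenizer) could pass the model-side vocab size explicitly:

from transformers import AutoConfig, AutoTokenizer
from lmformatenforcer.integrations.transformers import build_token_enforcer_tokenizer_data

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"  # assumed example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
config = AutoConfig.from_pretrained(model_id)

# The vision models' language head can emit fewer logits than len(tokenizer), so pass
# the model-side vocab size; fall back to the top-level config if text_config is absent.
model_vocab_size = getattr(config, "text_config", config).vocab_size
tokenizer_data = build_token_enforcer_tokenizer_data(tokenizer, vocab_size=model_vocab_size)

With vocab_size omitted, behavior is unchanged: the function falls back to len(tokenizer).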
7 changes: 5 additions & 2 deletions lmformatenforcer/integrations/vllm.py
@@ -34,13 +34,16 @@ def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:

 def build_vllm_token_enforcer_tokenizer_data(tokenizer: Union[vllm.LLM, PreTrainedTokenizerBase]) -> TokenEnforcerTokenizerData:
     # There are many classes that can be passed here, this logic should work on all of them.
+    vocab_size = None
+    if hasattr(tokenizer, 'llm_engine'):
+        vocab_size = tokenizer.llm_engine.get_model_config().get_vocab_size()
     if hasattr(tokenizer, 'get_tokenizer'):
         tokenizer = tokenizer.get_tokenizer()
     if isinstance(tokenizer, MistralTokenizer):
-        return build_token_enforcer_tokenizer_data(tokenizer)
+        return build_token_enforcer_tokenizer_data(tokenizer, vocab_size)
     if hasattr(tokenizer, 'tokenizer'):
         tokenizer = tokenizer.tokenizer
-    return build_token_enforcer_tokenizer_data(tokenizer)
+    return build_token_enforcer_tokenizer_data(tokenizer, vocab_size)


 def build_vllm_logits_processor(llm: Union[vllm.LLM, PreTrainedTokenizerBase, TokenEnforcerTokenizerData],
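A corresponding sketch on the vLLM side (the checkpoint name is an example only): when a vllm.LLM is passed, the vocab size is now picked up from the engine's model config, so callers need no extra argument:

import vllm
from lmformatenforcer.integrations.vllm import build_vllm_token_enforcer_tokenizer_data

llm = vllm.LLM(model="meta-llama/Llama-3.2-11B-Vision-Instruct")  # assumed example checkpoint
# vocab_size is read internally via llm.llm_engine.get_model_config().get_vocab_size().
tokenizer_data = build_vllm_token_enforcer_tokenizer_data(llm)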
