From 22de027b12b8c8de8e05110cc483d1525acd3263 Mon Sep 17 00:00:00 2001 From: Guillaume Calmettes Date: Wed, 25 Sep 2024 13:27:51 +0200 Subject: [PATCH] feat: use the vllm Mistral tokenizer wrapper as tokenizer --- lmformatenforcer/integrations/vllm.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lmformatenforcer/integrations/vllm.py b/lmformatenforcer/integrations/vllm.py index 1c1ebbb..b458757 100644 --- a/lmformatenforcer/integrations/vllm.py +++ b/lmformatenforcer/integrations/vllm.py @@ -1,6 +1,7 @@ try: import torch import vllm + from vllm.transformers_utils.tokenizer import MistralTokenizer from transformers import PreTrainedTokenizerBase except ImportError: raise ImportError('vllm is not installed. Please install it with "pip install vllm"') @@ -35,6 +36,8 @@ def build_vllm_token_enforcer_tokenizer_data(tokenizer: Union[vllm.LLM, PreTrain # There are many classes that can be passed here, this logic should work on all of them. if hasattr(tokenizer, 'get_tokenizer'): tokenizer = tokenizer.get_tokenizer() + if isinstance(tokenizer, MistralTokenizer): + return build_token_enforcer_tokenizer_data(tokenizer) if hasattr(tokenizer, 'tokenizer'): tokenizer = tokenizer.tokenizer return build_token_enforcer_tokenizer_data(tokenizer)