From 463339e61b3f04ea3165c073572e133d994e50db Mon Sep 17 00:00:00 2001 From: Noam Gat Date: Wed, 6 Dec 2023 17:18:53 +0200 Subject: [PATCH] Reduced cuda memory allocations in vllm integration, CPU allocations in llamacpp integration. --- lmformatenforcer/integrations/llamacpp.py | 11 ++++++++--- lmformatenforcer/integrations/vllm.py | 22 ++++++++++++++++------ 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/lmformatenforcer/integrations/llamacpp.py b/lmformatenforcer/integrations/llamacpp.py index bd4eeee..5d7cfa0 100644 --- a/lmformatenforcer/integrations/llamacpp.py +++ b/lmformatenforcer/integrations/llamacpp.py @@ -31,16 +31,21 @@ class LlamaCppLogitsProcessor: def __init__(self, token_enforcer: TokenEnforcer, analyze): self.token_enforcer = token_enforcer self.analyzer = FormatEnforcerAnalyzer(token_enforcer) if analyze else None + self.mask = None def __call__(self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]) -> npt.NDArray[np.single]: token_sequence = input_ids.tolist() if self.analyzer: self.analyzer.report_raw_logits(token_sequence, scores.tolist()) allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence) - mask = np.ones(scores.shape, bool) - mask[allowed_tokens] = False - scores[mask] = float('-inf') + if self.mask is None: + self.mask = np.ones(scores.shape, bool) + else: + self.mask.fill(True) + self.mask[allowed_tokens] = False + scores[self.mask] = float('-inf') return scores + def build_llamacpp_logits_processor(llm: Llama, character_level_parser: CharacterLevelParser, analyze: bool=False) -> LlamaCppLogitsProcessor: """Build the logits processor function that llama.cpp will use to filter the tokens generated by the model. The result diff --git a/lmformatenforcer/integrations/vllm.py b/lmformatenforcer/integrations/vllm.py index 040e2c4..c7aa597 100644 --- a/lmformatenforcer/integrations/vllm.py +++ b/lmformatenforcer/integrations/vllm.py @@ -1,11 +1,12 @@ try: import torch import vllm + from transformers import PreTrainedTokenizerBase except ImportError: raise ImportError('vllm is not installed. Please install it with "pip install vllm"') from lmformatenforcer import CharacterLevelParser, TokenEnforcer, FormatEnforcerAnalyzer from lmformatenforcer.integrations.transformers import build_regular_tokens_list -from typing import List +from typing import List, Optional, Union import math @@ -13,23 +14,32 @@ class VLLMLogitsProcessor: def __init__(self, token_enforcer: TokenEnforcer, analyze): self.token_enforcer = token_enforcer self.analyzer = FormatEnforcerAnalyzer(token_enforcer) if analyze else None + self.mask: Optional[torch.Tensor] = None def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: token_sequence = input_ids if self.analyzer: self.analyzer.report_raw_logits(token_sequence, scores.tolist()) allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence) - mask = torch.full_like(scores, -math.inf) - mask[allowed_tokens] = 0 - scores = scores + mask + if self.mask is not None: + self.mask.fill_(-math.inf) + else: + # We create it here because full_like() also copies the device and dtype + self.mask = torch.full_like(scores, -math.inf) + self.mask[allowed_tokens] = 0 + scores = scores + self.mask return scores -def build_vllm_logits_processor(llm: vllm.LLM, character_level_parser: CharacterLevelParser, analyze: bool=False) -> VLLMLogitsProcessor: +def build_vllm_logits_processor(llm: Union[vllm.LLM, PreTrainedTokenizerBase], + character_level_parser: CharacterLevelParser, + analyze: bool=False) -> VLLMLogitsProcessor: """Build the logits processor function that llama.cpp will use to filter the tokens generated by the model. The result can be passed in the logits_processor list that is sent to the call or generate() method of llama.cpp models.""" - tokenizer = llm.get_tokenizer() + tokenizer = llm.get_tokenizer() if isinstance(llm, vllm.LLM) else llm regular_tokens = build_regular_tokens_list(tokenizer) + if tokenizer.eos_token_id is None: + raise ValueError('The tokenizer must have an EOS token') token_enforcer = TokenEnforcer(regular_tokens, character_level_parser, tokenizer.decode, tokenizer.eos_token_id) return VLLMLogitsProcessor(token_enforcer, analyze)