diff --git a/datasets/.gitkeep b/datasets/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/src/vieval/tools/wrapper/TGIWrapper.py b/src/vieval/tools/wrapper/TGIWrapper.py
index 66b179d..88ae739 100644
--- a/src/vieval/tools/wrapper/TGIWrapper.py
+++ b/src/vieval/tools/wrapper/TGIWrapper.py
@@ -1,25 +1,22 @@
-from vllm import LLM, SamplingParams
-from typing import Dict, List
+import backoff
+import requests
+from transformers import AutoTokenizer
+import warnings
+import os
 import copy
 from .BaseWrapper import BaseWrapper
 from ..utils.chat_template import apply_chat_template
 
 
-class VLLMWrapper(BaseWrapper):
-    def __init__(self, config, generation_config, template: Dict = None):
-        generation_config["max_tokens"] = generation_config.pop("max_new_tokens")
-        generation_config["frequency_penalty"] = generation_config.pop(
-            "repetition_penalty"
-        )
-        self.model = LLM(
-            model=config.model_name,
-            cpu_offload_gb=config.cpu_offload_gb,
-            dtype=config.dtype,
-        )
-        self.generation_config = SamplingParams(
-            **generation_config, logprobs=1, prompt_logprobs=0
-        )
+class TGIWrapper(BaseWrapper):
+    def __init__(self, generation_config, template=""):
+        self.api_endpoint = os.getenv("TGI_ENDPOINT")
+        self.generation_config = generation_config
         self.model_template = template
+        self.model_info = self.get_model_info()
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.model_info["model_id"], trust_remote_code=True
+        )
 
     def __call__(self, prompts, return_probs=False):
         generations = []
@@ -27,53 +24,99 @@ def __call__(self, prompts, return_probs=False):
         num_generated_tokens = []
         prompts = copy.deepcopy(prompts)
         prompts = apply_chat_template(prompts, self.model_template)
-        try:
-            outputs = self.model.generate(prompts, self.generation_config)
-            for output in outputs:
-                generations.append(output.outputs[0].text)
-                generations_probs.append(
-                    [
-                        list(logprob.values())[0].logprob
-                        for logprob in output.outputs[0].logprobs
-                    ]
+        for prompt in prompts:
+            try:
+                generate_dict = self.generate_with_backoff(
+                    {
+                        "inputs": prompt,
+                        "parameters": {
+                            "truncate": self.model_info["max_input_tokens"],
+                            "details": True,
+                            **self.generation_config,
+                        },
+                    }
                 )
-                num_generated_tokens.append(len(output.outputs[0].logprobs))
-        except Exception as e:
-            print(prompts)
-            raise e
+            except Exception as e:
+                print(e)
+                print(prompt)
+                raise e
+            (
+                generation,
+                generation_probs,
+                num_generated_token,
+            ) = self.get_text_logprobs_tgi(generate_dict)
+
+            num_generated_tokens.extend(num_generated_token)
+            generations.extend(generation)
+
+            if return_probs:
+                # Include probabilities of '' token
+                generations_probs.extend(generation_probs)
+
         return generations, generations_probs, num_generated_tokens
 
     def compute_logprob_and_length(self, prompts, completions):
-        tokenizer = self.model.get_tokenizer()
         completions_num_tokens = []
         completions_logprobs = []
         prompts = copy.deepcopy(prompts)
         prompts = apply_chat_template(prompts, self.model_template)
-        tokenized_prompts = tokenizer(prompts)["input_ids"]
-        len_tokenized_prompts = [len(p) for p in tokenized_prompts]
-        completed_prompts = [
-            prompt + str(completion) + tokenizer.eos_token
-            for prompt, completion in zip(prompts, completions)
-        ]
-        outputs = self.model.generate(
-            completed_prompts,
-            SamplingParams(
-                max_tokens=1,
-                prompt_logprobs=0,
-                ignore_eos=False,
-                skip_special_tokens=False,
-            ),
-        )
-        for output, len_tokenized_prompt in zip(outputs, len_tokenized_prompts):
-            completions_num_tokens.append(
-                len(output.prompt_logprobs) - len_tokenized_prompt
-            )
-            completions_logprobs.append(
-                [
-                    [
-                        list(logprob.values())[0].logprob
-                        for logprob in output.prompt_logprobs[len_tokenized_prompt:]
-                    ]
-                ]
-            )
+        # tokenized_prompts = self.tokenizer(prompts)["input_ids"]
+        # len_tokenized_prompts = [len(p) for p in tokenized_prompts]
+        for prompt, completion in zip(prompts, completions):
+            try:
+                prompt_tokens = self.generate_with_backoff(
+                    {
+                        "inputs": prompt,
+                        "parameters": {
+                            "truncate": self.model_info["max_input_tokens"],
+                            "decoder_input_details": True,
+                            "max_new_tokens": 1,
+                        },
+                    }
+                )["details"]["prefill"]
+                completion_w_prompt = self.generate_with_backoff(
+                    {
+                        "inputs": prompt + completion + self.tokenizer.eos_token,
+                        "parameters": {
+                            "truncate": self.model_info["max_input_tokens"],
+                            "decoder_input_details": True,
+                            "max_new_tokens": 1,
+                        },
+                    }
+                )["details"]["prefill"]
+            except Exception as e:
+                print(e)
+                print(prompt)
+                raise e
+            logprobs = [
+                list(
+                    map(
+                        lambda x: x["logprob"],
+                        completion_w_prompt[len(prompt_tokens) :],
+                    )
+                )
+            ]
+            completions_logprobs.append(logprobs)
+            completions_num_tokens.append(len(logprobs[0]))
+
         return completions_logprobs, completions_num_tokens
+
+    def get_model_info(self):
+        info = requests.get(self.api_endpoint + "/info", verify=False)
+        return info.json()
+
+    @backoff.on_exception(
+        backoff.expo, requests.exceptions.RequestException, max_tries=10
+    )
+    def generate_with_backoff(self, inputs):
+        generate_obj = requests.post(
+            self.api_endpoint + "/generate", json=inputs, verify=False
+        )
+        return generate_obj.json()
+
+    def get_text_logprobs_tgi(self, res):
+        return (
+            [res["generated_text"]],
+            [list(map(lambda x: x["logprob"], res["details"]["tokens"]))],
+            [res["details"]["generated_tokens"]],
+        )
\ No newline at end of file
diff --git a/src/vieval/tools/wrapper/VLLMWrapper.py b/src/vieval/tools/wrapper/VLLMWrapper.py
index 34ed26c..56cdb43 100644
--- a/src/vieval/tools/wrapper/VLLMWrapper.py
+++ b/src/vieval/tools/wrapper/VLLMWrapper.py
@@ -24,6 +24,7 @@ def __call__(self, prompts, return_probs=False):
         generations = []
         generations_probs = []
         num_generated_tokens = []
+        prompts = copy.deepcopy(prompts)
         prompts = apply_chat_template(prompts, self.model_template)
         try:
             outputs = self.model.generate(prompts, self.generation_config)
@@ -45,6 +46,7 @@ def compute_logprob_and_length(self, prompts, completions):
         tokenizer = self.model.get_tokenizer()
         completions_num_tokens = []
         completions_logprobs = []
+        prompts = copy.deepcopy(prompts)
         prompts = apply_chat_template(prompts, self.model_template)
         tokenized_prompts = tokenizer(prompts)["input_ids"]
         len_tokenized_prompts = [len(p) for p in tokenized_prompts]
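A minimal usage sketch of the new `TGIWrapper`, assuming a text-generation-inference server is reachable at a hypothetical local endpoint, that the package is importable as `vieval`, and that plain string prompts are acceptable to `apply_chat_template`; the generation parameters shown are illustrative, not project defaults.

```python
import os

from vieval.tools.wrapper.TGIWrapper import TGIWrapper

# Hypothetical local endpoint; TGIWrapper reads TGI_ENDPOINT when it is constructed.
os.environ["TGI_ENDPOINT"] = "http://localhost:8080"

# generation_config is forwarded into TGI's "parameters" payload; values are illustrative.
wrapper = TGIWrapper(generation_config={"max_new_tokens": 64, "temperature": 0.7})

# __call__ returns parallel lists: generated texts, per-token log-probs
# (when return_probs=True), and generated-token counts.
generations, probs, num_tokens = wrapper(
    ["Hello, how are you?"],  # prompt format ultimately depends on apply_chat_template
    return_probs=True,
)
print(generations[0], num_tokens[0])
```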