fix: small bugs
ledong0110 committed Aug 4, 2024
1 parent a6d8a6f commit b9064ac
Showing 3 changed files with 104 additions and 58 deletions.
Empty file added datasets/.gitkeep
Empty file.
160 changes: 102 additions & 58 deletions src/vieval/tools/wrapper/TGIWrapper.py
@@ -1,79 +1,123 @@
-from vllm import LLM, SamplingParams
 from typing import Dict, List
+import backoff
+import requests
+from transformers import AutoTokenizer
 import warnings
 import os
 import copy
 from .BaseWrapper import BaseWrapper
 from ..utils.chat_template import apply_chat_template
 
 
-class VLLMWrapper(BaseWrapper):
-    def __init__(self, config, generation_config, template: Dict = None):
-        generation_config["max_tokens"] = generation_config.pop("max_new_tokens")
-        generation_config["frequency_penalty"] = generation_config.pop(
-            "repetition_penalty"
-        )
-        self.model = LLM(
-            model=config.model_name,
-            cpu_offload_gb=config.cpu_offload_gb,
-            dtype=config.dtype,
-        )
-        self.generation_config = SamplingParams(
-            **generation_config, logprobs=1, prompt_logprobs=0
-        )
+class TGIWrapper(BaseWrapper):
+    def __init__(self, generation_config, template=""):
+        self.api_endpoint = os.getenv("TGI_ENDPOINT")
+        self.generation_config = generation_config
         self.model_template = template
+        self.model_info = self.get_model_info()
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.model_info["model_id"], trust_remote_code=True
+        )
 
     def __call__(self, prompts, return_probs=False):
         generations = []
         generations_probs = []
         num_generated_tokens = []
         prompts = copy.deepcopy(prompts)
         prompts = apply_chat_template(prompts, self.model_template)
-        try:
-            outputs = self.model.generate(prompts, self.generation_config)
-            for output in outputs:
-                generations.append(output.outputs[0].text)
-                generations_probs.append(
-                    [
-                        list(logprob.values())[0].logprob
-                        for logprob in output.outputs[0].logprobs
-                    ]
-                )
-                num_generated_tokens.append(len(output.outputs[0].logprobs))
-        except Exception as e:
-            print(prompts)
-            raise e
+        for prompt in prompts:
+            try:
+                generate_dict = self.generate_with_backoff(
+                    {
+                        "inputs": prompt,
+                        "parameters": {
+                            "truncate": self.model_info["max_input_tokens"],
+                            "details": True,
+                            **self.generation_config,
+                        },
+                    }
+                )
+            except Exception as e:
+                print(e)
+                print(prompt)
+                raise e
+            (
+                generation,
+                generation_probs,
+                num_generated_token,
+            ) = self.get_text_logprobs_tgi(generate_dict)
+
+            num_generated_tokens.extend(num_generated_token)
+            generations.extend(generation)
+
+            if return_probs:
+                # Include probabilities of '</s>' token
+                generations_probs.extend(generation_probs)
 
         return generations, generations_probs, num_generated_tokens
 
     def compute_logprob_and_length(self, prompts, completions):
-        tokenizer = self.model.get_tokenizer()
         completions_num_tokens = []
         completions_logprobs = []
         prompts = copy.deepcopy(prompts)
         prompts = apply_chat_template(prompts, self.model_template)
-        tokenized_prompts = tokenizer(prompts)["input_ids"]
-        len_tokenized_prompts = [len(p) for p in tokenized_prompts]
-        completed_prompts = [
-            prompt + str(completion) + tokenizer.eos_token
-            for prompt, completion in zip(prompts, completions)
-        ]
-        outputs = self.model.generate(
-            completed_prompts,
-            SamplingParams(
-                max_tokens=1,
-                prompt_logprobs=0,
-                ignore_eos=False,
-                skip_special_tokens=False,
-            ),
-        )
-        for output, len_tokenized_prompt in zip(outputs, len_tokenized_prompts):
-            completions_num_tokens.append(
-                len(output.prompt_logprobs) - len_tokenized_prompt
-            )
-            completions_logprobs.append(
-                [
-                    [
-                        list(logprob.values())[0].logprob
-                        for logprob in output.prompt_logprobs[len_tokenized_prompt:]
-                    ]
-                ]
-            )
+        # tokenized_prompts = self.tokenizer(prompts)["input_ids"]
+        # len_tokenized_prompts = [len(p) for p in tokenized_prompts]
+        for prompt, completion in zip(prompts, completions):
+            try:
+                prompt_tokens = self.generate_with_backoff(
+                    {
+                        "inputs": prompt,
+                        "parameters": {
+                            "truncate": self.model_info["max_input_tokens"],
+                            "decoder_input_details": True,
+                            "max_new_tokens": 1,
+                        },
+                    }
+                )["details"]["prefill"]
+                completion_w_prompt = self.generate_with_backoff(
+                    {
+                        "inputs": prompt + completion + self.tokenizer.eos_token,
+                        "parameters": {
+                            "truncate": self.model_info["max_input_tokens"],
+                            "decoder_input_details": True,
+                            "max_new_tokens": 1,
+                        },
+                    }
+                )["details"]["prefill"]
+            except Exception as e:
+                print(e)
+                print(prompt)
+                raise e
+            logprobs = [
+                list(
+                    map(
+                        lambda x: x["logprob"],
+                        completion_w_prompt[len(prompt_tokens) :],
+                    )
+                )
+            ]
+            completions_logprobs.append(logprobs)
+            completions_num_tokens.append(len(logprobs[0]))
 
         return completions_logprobs, completions_num_tokens
+
+    def get_model_info(self):
+        info = requests.get(self.api_endpoint + "/info", verify=False)
+        return info.json()
+
+    @backoff.on_exception(
+        backoff.expo, requests.exceptions.RequestException, max_tries=10
+    )
+    def generate_with_backoff(self, inputs):
+        generate_obj = requests.post(
+            self.api_endpoint + "/generate", json=inputs, verify=False
+        )
+        return generate_obj.json()
+
+    def get_text_logprobs_tgi(self, res):
+        return (
+            [res["generated_text"]],
+            [list(map(lambda x: x["logprob"], res["details"]["tokens"]))],
+            [res["details"]["generated_tokens"]],
+        )
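
A minimal usage sketch of the new TGIWrapper (not part of the commit; the server address, import path, and generation parameters below are assumptions): it requires a running text-generation-inference server whose base URL is exported as TGI_ENDPOINT before the wrapper is constructed.

import os

# Assumption: a TGI server is reachable at this address.
os.environ["TGI_ENDPOINT"] = "http://localhost:8080"

# Import path inferred from the file location in this repository.
from vieval.tools.wrapper.TGIWrapper import TGIWrapper

# These keys are forwarded verbatim into the TGI /generate "parameters" payload.
generation_config = {
    "max_new_tokens": 128,
    "do_sample": True,
    "temperature": 0.7,
}

# __init__ calls GET {TGI_ENDPOINT}/info to learn model_id and max_input_tokens,
# then loads the matching Hugging Face tokenizer.
wrapper = TGIWrapper(generation_config=generation_config, template="")

# Prompt format ultimately depends on apply_chat_template; plain strings are assumed here.
generations, probs, n_tokens = wrapper(["Xin chào, bạn khỏe không?"], return_probs=True)
print(generations[0], n_tokens[0])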
2 changes: 2 additions & 0 deletions src/vieval/tools/wrapper/VLLMWrapper.py
@@ -24,6 +24,7 @@ def __call__(self, prompts, return_probs=False):
         generations = []
         generations_probs = []
         num_generated_tokens = []
+        prompts = copy.deepcopy(prompts)
         prompts = apply_chat_template(prompts, self.model_template)
         try:
             outputs = self.model.generate(prompts, self.generation_config)
@@ -45,6 +46,7 @@ def compute_logprob_and_length(self, prompts, completions):
         tokenizer = self.model.get_tokenizer()
         completions_num_tokens = []
         completions_logprobs = []
+        prompts = copy.deepcopy(prompts)
         prompts = apply_chat_template(prompts, self.model_template)
         tokenized_prompts = tokenizer(prompts)["input_ids"]
         len_tokenized_prompts = [len(p) for p in tokenized_prompts]
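
The two one-line additions to VLLMWrapper.py both deep-copy the incoming prompt list before templating. A plausible motivation (an assumption, since apply_chat_template is not shown in this diff) is to keep templating from mutating the caller's prompt objects in place, as the sketch below illustrates.

import copy

def template_in_place(prompts):
    # Hypothetical stand-in for a templating step that rewrites items in place.
    for i, p in enumerate(prompts):
        prompts[i] = f"<s>[INST] {p} [/INST]"
    return prompts

original = ["What is the capital of Vietnam?"]

# Without the copy, the caller's list is rewritten as a side effect.
template_in_place(original)            # original is now templated

# With the added line, the wrapper works on a private copy instead.
original = ["What is the capital of Vietnam?"]
prompts = copy.deepcopy(original)
template_in_place(prompts)
print(original)                        # unchanged: ['What is the capital of Vietnam?']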
