feat: add VLLM wrapper
ledong0110 committed Jul 30, 2024
1 parent f4ee6d9 commit 897fbac
Showing 10 changed files with 182 additions and 81 deletions.
4 changes: 4 additions & 0 deletions config/llm_template.json
@@ -7,6 +7,10 @@
"template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '<s>' + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + '</s>' }}{% endif %}{% endfor %}",
"system_prompt": true
},
"llama-3": {
"template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = '<|begin_of_text|>' + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
"system_prompt": false
},
"mistral": {
"template": "'<s>'+{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content']}}+'</s>'+{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
"system_prompt": true
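The new "llama-3" entry can be checked by rendering it directly with Jinja2. A minimal sketch, not part of the commit; the message list is made up:

```python
# Render the "llama-3" chat template from config/llm_template.json.
# Illustrative only: assumes jinja2 is installed.
import json

from jinja2 import Template

with open("config/llm_template.json") as f:
    llm_templates = json.load(f)

messages = [{"role": "user", "content": "Hello!"}]
rendered = Template(llm_templates["llama-3"]["template"]).render(
    messages=messages, add_generation_prompt=True
)
print(rendered)
# <|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|>
# <|start_header_id|>assistant<|end_header_id|>\n\n
```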
12 changes: 8 additions & 4 deletions env.template
@@ -1,7 +1,11 @@
#GPT
-AZURE_KEY=xxxx123
-AZURE_VERSION=2023-07-01-preview
-AZURE_ENDPOINT=http://gpt-4.com
+OPENAI_API_TYPE="azure"
+OPENAI_API_BASE="https://<your-endpoint.openai.azure.com/>"
+OPENAI_API_KEY="your AzureOpenAI key"
+OPENAI_API_VERSION="2023-05-15"

+#TGI
+TGI_ENDPOINT=""
+
#GEMINI
-GEMINI_KEY=abcdefghtyuidkg
+GEMINI_KEY="abcdefghtyuidkg"
1 change: 1 addition & 0 deletions pyproject.toml
@@ -29,6 +29,7 @@ classifiers = [
license = { text = "Apache 2.0 License" }

dependencies = [
"vllm>=0.5.2",
"accelerate>=0.30.1",
"peft>=0.11.1",
"bitsandbytes>=0.40.2",
8 changes: 6 additions & 2 deletions src/vieval/script_arguments.py
@@ -90,7 +90,9 @@ class ScriptArguments:
default=True,
metadata={"help": "Enable the TF32 mode (available in Ampere and newer GPUs)"},
)

+cpu_offload_gb: Optional[float] = field(
+default=0, metadata={"help": "Space in GiB to offload model weights to CPU (vLLM cpu_offload_gb)"}
+)
auto_find_batch_size: Optional[bool] = field(
default=True, metadata={"help": "Enable auto batch size"}
)
@@ -167,6 +169,9 @@ class ScriptArguments:
)

# Inference parameters
+dtype: Optional[str] = field(
+default="float16", metadata={"help": "Data type for model weights"}
+)
smoke_test: Optional[bool] = field(
default=False, metadata={"help": "Run a smoke test on a small dataset"}
)
@@ -182,7 +187,6 @@
cot: Optional[bool] = field(
default=False, metadata={"help": "Enable chain of thought when prompting MATH"}
)
-tgi: Optional[str] = field(default="", metadata={"help": "Embed TGI endpoint"})
seed: Optional[int] = field(default=42, metadata={"help": "Random seed"})
continue_infer: Optional[bool] = field(
default=False,
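The new fields surface as CLI flags through transformers' HfArgumentParser, the pattern this dataclass follows. A sketch; the entry-point script name and flag values are illustrative:

```python
# Parse ScriptArguments from the command line, e.g.:
#   python run.py --wtype vllm --dtype float16 --cpu_offload_gb 4
# Sketch only: the entry-point name above is made up.
from transformers import HfArgumentParser

from vieval.script_arguments import ScriptArguments

parser = HfArgumentParser(ScriptArguments)
(args,) = parser.parse_args_into_dataclasses()
print(args.wtype, args.dtype, args.cpu_offload_gb)
```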
47 changes: 23 additions & 24 deletions src/vieval/tools/pipelines/pipelines.py
@@ -4,7 +4,7 @@
import json
from tqdm import tqdm
from ..utils.model import get_model
-from ..wrapper import AzureGPTWrapper, TGIWrapper, GeminiWrapper, HFWrapper
+from ..wrapper import OpenAIWrapper, TGIWrapper, GeminiWrapper, HFWrapper, VLLMWrapper
from ..utils.utils import *
from ..utils.metric_utils import info_from_filename
from .metric_pipelines import MetricPipeline
@@ -28,24 +28,23 @@ def __init__(self, task, config):
# print(config.tgi)
if config.wtype == "tgi":
self.infer_pipeline = TGIWrapper(
-api_endpoint=config.tgi,
generation_config=GenerationConfig[extract_task],
template=LLM_TEMPLATE[config.ptemplate],
)
elif config.wtype == "hf":
-# Load model
-self.model, self.tokenizer = get_model(config=config)
-self.model.eval()
-
self.infer_pipeline = HFWrapper(
-model=self.model,
-tokenizer=self.tokenizer,
+config=config,
generation_config=GenerationConfig[extract_task],
template=LLM_TEMPLATE[config.ptemplate],
)

elif config.wtype == "azuregpt":
self.infer_pipeline = AzureGPTWrapper(
elif config.wtype == "vllm":
self.infer_pipeline = VLLMWrapper(
config=config,
generation_config=GenerationConfig[extract_task],
template=LLM_TEMPLATE[config.ptemplate],
)
elif config.wtype == "openai":
self.infer_pipeline = OpenAIWrapper(
engine=config.model_name,
generation_config=GenerationConfig[extract_task],
)
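The VLLMWrapper source itself is not shown in this excerpt. Below is a minimal sketch of a wrapper matching the constructor call above, using vLLM's offline LLM API; everything beyond the visible signature (sampling defaults, logprob handling) is an assumption, not the committed code:

```python
# Hypothetical sketch of a VLLMWrapper compatible with the call site above.
# Assumes vLLM's offline API (vllm>=0.5.2, per pyproject.toml); defaults are
# illustrative.
from vllm import LLM, SamplingParams


class VLLMWrapper:
    def __init__(self, config, generation_config, template=None):
        self.model = LLM(
            model=config.model_name,
            dtype=config.dtype,
            cpu_offload_gb=config.cpu_offload_gb,
        )
        self.generation_config = generation_config
        self.model_template = template

    def __call__(self, prompts, return_probs=False):
        params = SamplingParams(
            temperature=self.generation_config.get("temperature", 0.0),
            max_tokens=self.generation_config.get("max_new_tokens", 256),
            logprobs=1 if return_probs else None,
        )
        outputs = self.model.generate(prompts, params)
        generations = [out.outputs[0].text for out in outputs]
        num_tokens = [len(out.outputs[0].token_ids) for out in outputs]
        probs = []
        if return_probs:
            # look up the sampled token's logprob at every generated position
            probs = [
                [
                    pos[tok].logprob
                    for tok, pos in zip(
                        out.outputs[0].token_ids, out.outputs[0].logprobs
                    )
                ]
                for out in outputs
            ]
        return generations, probs, num_tokens
```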
@@ -146,7 +145,7 @@ def __question_answering(self, ds_wrapper, ds_loader, saving_fn, start_idx=0):
results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True)
predictions.extend(results)
references.extend([x[0] for x in batch[ds_wrapper.answer]["text"]])
-generation_probs.extend([x.tolist() for x in logprobs])
+generation_probs.extend(logprobs)

idx += 1
if idx % 100 == 0:
@@ -185,7 +184,7 @@ def __question_answering_without_context(
calib_probs = []
idx = 0
original_few_shot = []
-calibration_few_shot = []
+calib_few_shot = []
selected_sample = []
if self.continue_infer_data is not None:
predictions.extend(self.continue_infer_data["predictions"])
@@ -258,8 +257,8 @@ def preprocessing_a_record(rec):
)
predictions.extend(results)
references.extend([x for x in batch[ds_wrapper.answer]])
-generation_probs.extend([x.tolist() for x in logprobs])
-calib_probs.extend([x.tolist() for x in calibprob_batch])
+generation_probs.extend(logprobs)
+calib_probs.extend(calibprob_batch)
idx += 1
if idx % 100 == 0:
print(f"Saving results of {idx} batches")
@@ -327,7 +326,7 @@ def __summarization(self, ds_wrapper, ds_loader, saving_fn, start_idx=0):
results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True)
predictions.extend(results)
references.extend([x for x in batch[ds_wrapper.summarized_text]])
-generation_probs.extend([x.tolist() for x in logprobs])
+generation_probs.extend(logprobs)

idx += 1
if idx % 100 == 0:
@@ -451,7 +450,7 @@ def preprocessing_a_record(rec):
)
predictions.extend(results)
references.extend([x.item() for x in batch[ds_wrapper.label]])
-generation_probs.extend([x.tolist() for x in logprobs])
+generation_probs.extend(logprobs)
option_probs.extend(
[
[
@@ -623,7 +622,7 @@ def format_calib_fewshot(rec):
for x in batch[ds_wrapper.label]
]
)
-generation_probs.extend([x.tolist() for x in logprobs])
+generation_probs.extend(logprobs)
option_probs.extend(
[
[
@@ -756,7 +755,7 @@ def preprocessing_a_record(rec):
)
predictions.extend(results)
references.extend([x.item() for x in batch[ds_wrapper.label]])
-generation_probs.extend([x.tolist() for x in logprobs])
+generation_probs.extend(logprobs)
option_probs.extend(
[
[
@@ -932,7 +931,7 @@ def preprocessing_a_record(rec):
]
)

-generation_probs.extend([x.tolist() for x in logprobs])
+generation_probs.extend(logprobs)
option_probs.extend(opt_calib_out)
idx += 1
if idx % 100 == 0:
@@ -1017,7 +1016,7 @@ def preprocessing_a_record(rec):
results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True)
predictions.extend(results)
references.extend([x for x in batch[ds_wrapper.target]])
-generation_probs.extend([x.tolist() for x in logprobs])
+generation_probs.extend(logprobs)

idx += 1
if idx % 100 == 0:
@@ -1322,8 +1321,8 @@ def preprocessing_a_record(rec):
)
predictions.extend(results)
references.extend([x for x in batch[target]])
-generation_probs.extend([x.tolist() for x in logprobs])
-calib_probs.extend([x.tolist() for x in calibprob_batch])
+generation_probs.extend(logprobs)
+calib_probs.extend(calibprob_batch)
if sub_task == "math":
math_problem_type.extend([x for x in batch[ds_wrapper.type]])
idx += 1
@@ -1417,7 +1416,7 @@ def preprocessing_a_record(rec):
results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True)
predictions.extend(results)
references.extend([x for x in batch[ds_wrapper.target_language]])
-generation_probs.extend([x.tolist() for x in logprobs])
+generation_probs.extend(logprobs)

idx += 1
if idx % 100 == 0:
14 changes: 8 additions & 6 deletions src/vieval/tools/wrapper/HFWrapper.py
@@ -1,12 +1,14 @@
import torch
from .BaseWrapper import BaseWrapper
from ..utils.chat_template import apply_chat_template
+from ..utils.model import get_model


class HFWrapper(BaseWrapper):
-def __init__(self, model, tokenizer, generation_config, template=""):
-self.model = model
-self.tokenizer = tokenizer
+def __init__(self, config, generation_config, template=None):
+self.model, self.tokenizer = get_model(config=config)
+self.model.eval()
+
self.generation_config = generation_config
self.model_template = template

@@ -51,7 +53,7 @@ def __call__(self, prompts, return_probs=False):
scores=generate_dict.scores,
normalize_logits=True,
)
-generations_probs.extend(generation_probs.cpu().numpy())
+generations_probs.extend(generation_probs.cpu().numpy().tolist())

return generations, generations_probs, num_generated_tokens

@@ -67,7 +69,7 @@ def compute_logprob_and_length(self, prompts, completions):
prompt_num_tokens = prompt_tokens.input_ids.shape[1] - 1

completion_tokens = self.tokenizer(
f"{completion} {self.tokenizer.eos_token}", return_tensors="pt"
f"{completion}{self.tokenizer.eos_token}", return_tensors="pt"
).to(
self.model.device
) # <s> SPIECE_UNDERLINE [tokens] SPIECE_UNDERLINE </s>
@@ -101,5 +103,5 @@ def compute_logprob_and_length(self, prompts, completions):
),
).squeeze(-1)
# >>> batch_size, sequence_length
-completions_logprobs.append(logprobs.cpu().numpy())
+completions_logprobs.append(logprobs.cpu().numpy().tolist())
return completions_logprobs, completions_num_tokens
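The surrounding hunk only converts tensors to plain lists before returning them; for reference, the gather pattern it operates on picks one log-probability per target token. A standalone toy example, not from the commit:

```python
# Toy illustration of the torch.gather pattern above: select each target
# token's log-probability from the per-position distributions.
import torch

logits = torch.randn(1, 4, 10)             # (batch, seq_len, vocab)
target_ids = torch.randint(0, 10, (1, 4))  # (batch, seq_len)

logprobs = torch.log_softmax(logits, dim=-1)
token_logprobs = torch.gather(
    logprobs, dim=2, index=target_ids.unsqueeze(-1)
).squeeze(-1)                               # (batch, seq_len)
print(token_logprobs.cpu().numpy().tolist())  # JSON-friendly plain lists
```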
@@ -5,14 +5,14 @@
from .BaseWrapper import BaseWrapper


-class AzureGPTWrapper(BaseWrapper):
+class OpenAIWrapper(BaseWrapper):
def __init__(self, engine=None, generation_config=None):
-self.generation_config = generation_config
-self.model = openai.AzureOpenAI(
-azure_endpoint=os.getenv("AZURE_ENDPOINT"),
-api_key=os.getenv("AZURE_KEY"),
-api_version=os.getenv("AZURE_VERSION"),
+generation_config["max_tokens"] = generation_config.pop("max_new_tokens")
+generation_config["frequency_penalty"] = generation_config.pop(
+"repetition_penalty"
)
+self.generation_config = generation_config
+self.model = openai.OpenAI()
self.engine = engine

def __call__(self, prompts, return_probs=False):
@@ -24,10 +24,7 @@ def __call__(self, prompts, return_probs=False):
response = self.chat_completions_with_backoff(
model=self.engine,
messages=prompt,
-temperature=self.generation_config["temperature"],
-max_tokens=self.generation_config["max_new_tokens"],
-top_p=0.95,
-frequency_penalty=self.generation_config["repetition_penalty"],
+**self.generation_config,
)

generations.append(response.choices[0].message.content)
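Two things worth noting about the rewritten constructor: openai.OpenAI() reads OPENAI_API_KEY (and optionally OPENAI_BASE_URL) from the environment, matching the env.template change above, and the pops rename the HF-style keys to their OpenAI equivalents so the dict can be splatted straight into the API call. A standalone illustration with made-up values:

```python
# Illustration (values are made up) of the key renaming performed above:
# the OpenAI chat API expects max_tokens / frequency_penalty instead of the
# HF-style max_new_tokens / repetition_penalty.
generation_config = {
    "temperature": 0.0,
    "max_new_tokens": 256,
    "repetition_penalty": 1.0,
}
generation_config["max_tokens"] = generation_config.pop("max_new_tokens")
generation_config["frequency_penalty"] = generation_config.pop("repetition_penalty")

# Caveat: the scales differ. HF repetition_penalty is neutral at 1.0, while
# OpenAI frequency_penalty is neutral at 0.0, so a value of 1.0 here already
# penalizes repetition fairly strongly.
print(generation_config)
# {'temperature': 0.0, 'max_tokens': 256, 'frequency_penalty': 1.0}
```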