Update Wrapper Type #13

Merged: 3 commits, Aug 3, 2024
4 changes: 4 additions & 0 deletions config/llm_template.json
@@ -7,6 +7,10 @@
"template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '<s>' + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + '</s>' }}{% endif %}{% endfor %}",
"system_prompt": true
},
"llama-3": {
"template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = '<|begin_of_text|>' + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
"system_prompt": false
},
"mistral": {
"template": "'<s>'+{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content']}}+'</s>'+{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
"system_prompt": true
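For reference, a template entry like the new llama-3 one can be exercised directly with jinja2 before it is wired into the evaluation code. A minimal sketch, assuming config/llm_template.json is a flat mapping from model name to {"template", "system_prompt"} as the surrounding entries suggest; the conversation and the rendering code are illustrative, not part of the PR:

import json
from jinja2 import Environment

# Load the template added in this PR (assumes a flat mapping of
# model name -> {"template": ..., "system_prompt": ...}).
with open("config/llm_template.json") as f:
    llm_templates = json.load(f)

env = Environment()
template = env.from_string(llm_templates["llama-3"]["template"])

# Illustrative conversation; add_generation_prompt appends the assistant header.
prompt = template.render(
    messages=[{"role": "user", "content": "Hello!"}],
    add_generation_prompt=True,
)
print(prompt)
# <|begin_of_text|><|start_header_id|>user<|end_header_id|> ... <|eot_id|><|start_header_id|>assistant<|end_header_id|>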
12 changes: 8 additions & 4 deletions env.template
@@ -1,7 +1,11 @@
#GPT
AZURE_KEY=xxxx123
AZURE_VERSION=2023-07-01-preview
AZURE_ENDPOINT=http://gpt-4.com
OPENAI_API_TYPE="azure"
OPENAI_API_BASE="https://<your-endpoint.openai.azure.com/>"
OPENAI_API_KEY="your AzureOpenAI key"
OPENAI_API_VERSION="2023-05-15"

#TGI
TGI_ENDPOINT=""

#GEMINI
GEMINI_KEY=abcdefghtyuidkg
GEMINI_KEY="abcdefghtyuidkg"
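The new OPENAI_* variables mirror what the OpenAI SDK reads at client construction time. A minimal sketch of loading a filled-in copy of this template, assuming python-dotenv; the project itself may load the environment differently:

import openai
from dotenv import load_dotenv  # assumption: python-dotenv is installed

load_dotenv(".env")  # copy env.template to .env and fill in real values first

# openai>=1.0 picks up OPENAI_API_KEY from the environment automatically,
# which is what the new OpenAIWrapper's openai.OpenAI() call relies on.
client = openai.OpenAI()
print(client.models.list().data[0].id)  # quick connectivity check (illustrative)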
1 change: 1 addition & 0 deletions pyproject.toml
@@ -29,6 +29,7 @@ classifiers = [
license = { text = "Apache 2.0 License" }

dependencies = [
"vllm>=0.5.2",
"accelerate>=0.30.1",
"peft>=0.11.1",
"bitsandbytes>=0.40.2",
8 changes: 6 additions & 2 deletions src/vieval/script_arguments.py
@@ -90,7 +90,9 @@ class ScriptArguments:
default=True,
metadata={"help": "Enable the TF32 mode (available in Ampere and newer GPUs)"},
)

cpu_offload_gb: Optional[float] = field(
default=0, metadata={"help": "GiB of CPU memory used to offload model weights"}
)
auto_find_batch_size: Optional[bool] = field(
default=True, metadata={"help": "Enable auto batch size"}
)
@@ -167,6 +169,9 @@ class ScriptArguments:
)

# Inference parameters
dtype: Optional[str] = field(
default="float16", metadata={"help": "Data type for model weights"} # float16
)
smoke_test: Optional[bool] = field(
default=False, metadata={"help": "Run a smoke test on a small dataset"}
)
@@ -182,7 +187,6 @@
cot: Optional[bool] = field(
default=False, metadata={"help": "Enable chain of thought when prompting MATH"}
)
tgi: Optional[str] = field(default="", metadata={"help": "Embed TGI endpoint"})
seed: Optional[int] = field(default=42, metadata={"help": "Random seed"})
continue_infer: Optional[bool] = field(
default=False,
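The new dtype and cpu_offload_gb fields follow the same field(...) pattern as the rest of ScriptArguments, so they should be parseable the usual way. A minimal sketch, assuming the dataclass is consumed through transformers' HfArgumentParser and that the import path follows the src layout; the actual entry point and flag names used by the project may differ:

from transformers import HfArgumentParser
from vieval.script_arguments import ScriptArguments  # import path assumed from the src/ layout

parser = HfArgumentParser(ScriptArguments)
(args,) = parser.parse_args_into_dataclasses()

# e.g. python run.py --wtype vllm --dtype float16 --cpu_offload_gb 4   (illustrative invocation)
print(args.dtype, args.cpu_offload_gb)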
47 changes: 23 additions & 24 deletions src/vieval/tools/pipelines/pipelines.py
@@ -4,7 +4,7 @@
import json
from tqdm import tqdm
from ..utils.model import get_model
from ..wrapper import AzureGPTWrapper, TGIWrapper, GeminiWrapper, HFWrapper
from ..wrapper import OpenAIWrapper, TGIWrapper, GeminiWrapper, HFWrapper
from ..utils.utils import *
from ..utils.metric_utils import info_from_filename
from .metric_pipelines import MetricPipeline
@@ -28,24 +28,23 @@ def __init__(self, task, config):
# print(config.tgi)
if config.wtype == "tgi":
self.infer_pipeline = TGIWrapper(
api_endpoint=config.tgi,
generation_config=GenerationConfig[extract_task],
template=LLM_TEMPLATE[config.ptemplate],
)
elif config.wtype == "hf":
# Load model
self.model, self.tokenizer = get_model(config=config)
self.model.eval()

self.infer_pipeline = HFWrapper(
model=self.model,
tokenizer=self.tokenizer,
config=config,
generation_config=GenerationConfig[extract_task],
template=LLM_TEMPLATE[config.ptemplate],
)

elif config.wtype == "azuregpt":
self.infer_pipeline = AzureGPTWrapper(
elif config.wtype == "vllm":
self.infer_pipeline = VLLMWrapper(
config=config,
generation_config=GenerationConfig[extract_task],
template=LLM_TEMPLATE[config.ptemplate],
)
elif config.wtype == "openai":
self.infer_pipeline = OpenAIWrapper(
engine=config.model_name,
generation_config=GenerationConfig[extract_task],
)
@@ -146,7 +145,7 @@ def __question_answering(self, ds_wrapper, ds_loader, saving_fn, start_idx=0):
results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True)
predictions.extend(results)
references.extend([x[0] for x in batch[ds_wrapper.answer]["text"]])
generation_probs.extend([x.tolist() for x in logprobs])
generation_probs.extend(logprobs)

idx += 1
if idx % 100 == 0:
@@ -185,7 +184,7 @@ def __question_answering_without_context(
calib_probs = []
idx = 0
original_few_shot = []
calibration_few_shot = []
calib_few_shot = []
selected_sample = []
if self.continue_infer_data is not None:
predictions.extend(self.continue_infer_data["predictions"])
@@ -258,8 +257,8 @@ def preprocessing_a_record(rec):
)
predictions.extend(results)
references.extend([x for x in batch[ds_wrapper.answer]])
generation_probs.extend([x.tolist() for x in logprobs])
calib_probs.extend([x.tolist() for x in calibprob_batch])
generation_probs.extend(logprobs)
calib_probs.extend(calibprob_batch)
idx += 1
if idx % 100 == 0:
print(f"Saving results of {idx} batches")
@@ -327,7 +326,7 @@ def __summarization(self, ds_wrapper, ds_loader, saving_fn, start_idx=0):
results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True)
predictions.extend(results)
references.extend([x for x in batch[ds_wrapper.summarized_text]])
generation_probs.extend([x.tolist() for x in logprobs])
generation_probs.extend(logprobs)

idx += 1
if idx % 100 == 0:
@@ -451,7 +450,7 @@ def preprocessing_a_record(rec):
)
predictions.extend(results)
references.extend([x.item() for x in batch[ds_wrapper.label]])
generation_probs.extend([x.tolist() for x in logprobs])
generation_probs.extend(logprobs)
option_probs.extend(
[
[
@@ -623,7 +622,7 @@ def format_calib_fewshot(rec):
for x in batch[ds_wrapper.label]
]
)
generation_probs.extend([x.tolist() for x in logprobs])
generation_probs.extend(logprobs)
option_probs.extend(
[
[
@@ -756,7 +755,7 @@ def preprocessing_a_record(rec):
)
predictions.extend(results)
references.extend([x.item() for x in batch[ds_wrapper.label]])
generation_probs.extend([x.tolist() for x in logprobs])
generation_probs.extend(logprobs)
option_probs.extend(
[
[
@@ -932,7 +931,7 @@ def preprocessing_a_record(rec):
]
)

generation_probs.extend([x.tolist() for x in logprobs])
generation_probs.extend(logprobs)
option_probs.extend(opt_calib_out)
idx += 1
if idx % 100 == 0:
@@ -1017,7 +1016,7 @@ def preprocessing_a_record(rec):
results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True)
predictions.extend(results)
references.extend([x for x in batch[ds_wrapper.target]])
generation_probs.extend([x.tolist() for x in logprobs])
generation_probs.extend(logprobs)

idx += 1
if idx % 100 == 0:
@@ -1322,8 +1321,8 @@ def preprocessing_a_record(rec):
)
predictions.extend(results)
references.extend([x for x in batch[target]])
generation_probs.extend([x.tolist() for x in logprobs])
calib_probs.extend([x.tolist() for x in calibprob_batch])
generation_probs.extend(logprobs)
calib_probs.extend(calibprob_batch)
if sub_task == "math":
math_problem_type.extend([x for x in batch[ds_wrapper.type]])
idx += 1
@@ -1417,7 +1416,7 @@ def preprocessing_a_record(rec):
results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True)
predictions.extend(results)
references.extend([x for x in batch[ds_wrapper.target_language]])
generation_probs.extend([x.tolist() for x in logprobs])
generation_probs.extend(logprobs)

idx += 1
if idx % 100 == 0:
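The "vllm" branch above instantiates a VLLMWrapper that is not included in this diff (and does not appear in the updated import line, so it presumably lives in the same wrapper package). A minimal, hypothetical sketch of what such a wrapper could look like, built on vLLM's public LLM/SamplingParams API and the config fields added in this PR; it is not the PR's actual implementation, and chat-template rendering and log-prob extraction are omitted:

from vllm import LLM, SamplingParams


class VLLMWrapper:
    """Hypothetical sketch; not the implementation referenced by this PR."""

    def __init__(self, config, generation_config, template=None):
        self.model = LLM(
            model=config.model_name,
            dtype=config.dtype,                    # new ScriptArguments field
            cpu_offload_gb=config.cpu_offload_gb,  # new ScriptArguments field
        )
        self.generation_config = generation_config
        self.model_template = template  # chat-template rendering omitted in this sketch

    def __call__(self, prompts, return_probs=False):
        # `prompts` are assumed to be already-rendered strings here.
        params = SamplingParams(
            temperature=self.generation_config["temperature"],
            max_tokens=self.generation_config["max_new_tokens"],
        )
        outputs = self.model.generate(prompts, params)
        generations = [o.outputs[0].text for o in outputs]
        num_generated_tokens = [len(o.outputs[0].token_ids) for o in outputs]
        generations_probs = [[] for _ in outputs]  # log-prob extraction omitted
        return generations, generations_probs, num_generated_tokens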
5 changes: 2 additions & 3 deletions src/vieval/tools/wrapper/GeminiWrapper.py
@@ -1,4 +1,3 @@
import torch
import json
import os
import openai
@@ -49,7 +48,7 @@ def __init__(self, model_name=None, generation_config=None):

def __call__(self, prompts, return_probs=False):
generations = []
generations_probs = [torch.tensor([])] * len(prompts)
generations_probs = [[]] * len(prompts)
num_generated_tokens = []
for prompt in prompts:
processed_prompt = [list(p.values())[1] for p in prompt]
@@ -74,7 +73,7 @@ def __call__(self, prompts, return_probs=False):

def compute_logprob_and_length(self, prompts, completions):
completions_num_tokens = [0] * len(prompts)
completions_logprobs = [torch.tensor([])] * len(prompts)
completions_logprobs = [[]] * len(prompts)
# Not Implement
return completions_logprobs, completions_num_tokens

14 changes: 8 additions & 6 deletions src/vieval/tools/wrapper/HFWrapper.py
@@ -1,12 +1,14 @@
import torch
from .BaseWrapper import BaseWrapper
from ..utils.chat_template import apply_chat_template
from ..utils.model import get_model


class HFWrapper(BaseWrapper):
def __init__(self, model, tokenizer, generation_config, template=""):
self.model = model
self.tokenizer = tokenizer
def __init__(self, config, generation_config, template=None):
self.model, self.tokenizer = get_model(config=config)
self.model.eval()

self.generation_config = generation_config
self.model_template = template

@@ -51,7 +53,7 @@ def __call__(self, prompts, return_probs=False):
scores=generate_dict.scores,
normalize_logits=True,
)
generations_probs.extend(generation_probs.cpu().numpy())
generations_probs.extend(generation_probs.cpu().numpy().tolist())

return generations, generations_probs, num_generated_tokens

@@ -67,7 +69,7 @@ def compute_logprob_and_length(self, prompts, completions):
prompt_num_tokens = prompt_tokens.input_ids.shape[1] - 1

completion_tokens = self.tokenizer(
f"{completion} {self.tokenizer.eos_token}", return_tensors="pt"
f"{completion}{self.tokenizer.eos_token}", return_tensors="pt"
).to(
self.model.device
) # <s> SPIECE_UNDERLINE [tokens] SPIECE_UNDERLINE </s>
@@ -101,5 +103,5 @@ def compute_logprob_and_length(self, prompts, completions):
),
).squeeze(-1)
# >>> batch_size, sequence_length
completions_logprobs.append(logprobs.cpu().numpy())
completions_logprobs.append(logprobs.cpu().numpy().tolist())
return completions_logprobs, completions_num_tokens
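The .tolist() calls above turn per-token scores into plain Python lists, so the pipeline can extend its result buffers and serialize them without carrying torch or numpy types. A minimal sketch of the same conversion, using a small illustrative model rather than anything from this repository:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello", return_tensors="pt")
with torch.no_grad():
    generate_dict = model.generate(
        **inputs,
        max_new_tokens=5,
        return_dict_in_generate=True,
        output_scores=True,
    )

transition_scores = model.compute_transition_scores(
    generate_dict.sequences, generate_dict.scores, normalize_logits=True
)
# Same data as before, but now JSON-serializable plain lists,
# matching the generation_probs.extend(logprobs) change in pipelines.py.
generation_probs = transition_scores.cpu().numpy().tolist()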
@@ -5,29 +5,26 @@
from .BaseWrapper import BaseWrapper


class AzureGPTWrapper(BaseWrapper):
class OpenAIWrapper(BaseWrapper):
def __init__(self, engine=None, generation_config=None):
self.generation_config = generation_config
self.model = openai.AzureOpenAI(
azure_endpoint=os.getenv("AZURE_ENDPOINT"),
api_key=os.getenv("AZURE_KEY"),
api_version=os.getenv("AZURE_VERSION"),
generation_config["max_tokens"] = generation_config.pop("max_new_tokens")
generation_config["frequency_penalty"] = generation_config.pop(
"repetition_penalty"
)
self.generation_config = generation_config
self.model = openai.OpenAI()
self.engine = engine

def __call__(self, prompts, return_probs=False):
generations = []
generations_probs = [torch.tensor([])] * len(prompts)
generations_probs = [[]] * len(prompts)
num_generated_tokens = []
for prompt in prompts:

response = self.chat_completions_with_backoff(
model=self.engine,
messages=prompt,
temperature=self.generation_config["temperature"],
max_tokens=self.generation_config["max_new_tokens"],
top_p=0.95,
frequency_penalty=self.generation_config["repetition_penalty"],
**self.generation_config,
)

generations.append(response.choices[0].message.content)
@@ -37,7 +34,7 @@ def __call__(self, prompts, return_probs=False):

def compute_logprob_and_length(self, prompts, completions):
completions_num_tokens = [0] * len(prompts)
completions_logprobs = [torch.tensor([])] * len(prompts)
completions_logprobs = [[]] * len(prompts)
# TODO: Implement when OpenAI support logprobs of sentence
return completions_logprobs, completions_num_tokens

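Because the remapped generation config is now splatted straight into the chat-completions call, the wrapper only works if every remaining key is a valid OpenAI parameter. A minimal sketch of the renaming and the resulting request, with illustrative values and an illustrative model name:

import openai

# Illustrative config in the repo's HF-style naming.
generation_config = {"max_new_tokens": 256, "repetition_penalty": 1.1, "temperature": 0.0}

# Same renaming the new OpenAIWrapper performs in __init__.
generation_config["max_tokens"] = generation_config.pop("max_new_tokens")
generation_config["frequency_penalty"] = generation_config.pop("repetition_penalty")

client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment
response = client.chat.completions.create(
    model="gpt-4o-mini",  # illustrative engine name
    messages=[{"role": "user", "content": "Hello!"}],
    **generation_config,
)
print(response.choices[0].message.content)

Worth noting: OpenAI's frequency_penalty is an additive penalty in the range -2 to 2, while HF's repetition_penalty is multiplicative, so the renamed value carries a different meaning even though the request is accepted.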