From 897fbac42f381a13a5dc3f0ad1e93de466ce3d42 Mon Sep 17 00:00:00 2001
From: ledong0110 <74060032+ledong0110@users.noreply.github.com>
Date: Wed, 31 Jul 2024 02:03:41 +0700
Subject: [PATCH 1/3] feat: add VLLM wrapper
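
Add a vLLM-backed inference wrapper (selected with wtype "vllm"),
rename AzureGPTWrapper to OpenAIWrapper, read the TGI endpoint from
the TGI_ENDPOINT environment variable instead of the removed tgi
argument, and add a llama-3 chat template. New ScriptArguments fields
dtype and cpu_offload_gb control the vLLM weight dtype and CPU offload
budget, and all wrappers now return logprobs as plain Python lists.

Both OpenAIWrapper and VLLMWrapper remap the HF-style generation keys
to their OpenAI/vLLM equivalents. Minimal sketch of that remapping
(illustrative values only):

    gen = {"max_new_tokens": 128, "repetition_penalty": 1.1, "temperature": 0.0}
    gen["max_tokens"] = gen.pop("max_new_tokens")
    gen["frequency_penalty"] = gen.pop("repetition_penalty")
    # gen == {"temperature": 0.0, "max_tokens": 128, "frequency_penalty": 1.1}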
---
config/llm_template.json | 4 +
env.template | 12 ++-
pyproject.toml | 1 +
src/vieval/script_arguments.py | 8 +-
src/vieval/tools/pipelines/pipelines.py | 47 ++++++-----
src/vieval/tools/wrapper/HFWrapper.py | 14 ++--
.../{AzureGPTWrapper.py => OpenAIWrapper.py} | 17 ++--
src/vieval/tools/wrapper/TGIWrapper.py | 78 +++++++++++--------
src/vieval/tools/wrapper/VLLMWrapper.py | 77 ++++++++++++++++++
src/vieval/tools/wrapper/__init__.py | 5 +-
10 files changed, 182 insertions(+), 81 deletions(-)
rename src/vieval/tools/wrapper/{AzureGPTWrapper.py => OpenAIWrapper.py} (73%)
create mode 100644 src/vieval/tools/wrapper/VLLMWrapper.py
diff --git a/config/llm_template.json b/config/llm_template.json
index ab39672..726cb7f 100644
--- a/config/llm_template.json
+++ b/config/llm_template.json
@@ -7,6 +7,10 @@
"template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '' + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + '' }}{% endif %}{% endfor %}",
"system_prompt": true
},
+ "llama-3": {
+ "template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = '<|begin_of_text|>' + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
+ "system_prompt": false
+ },
"mistral": {
"template": "''+{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content']}}+''+{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
"system_prompt": true
diff --git a/env.template b/env.template
index 42d1700..6768f04 100644
--- a/env.template
+++ b/env.template
@@ -1,7 +1,11 @@
#GPT
-AZURE_KEY=xxxx123
-AZURE_VERSION=2023-07-01-preview
-AZURE_ENDPOINT=http://gpt-4.com
+OPENAI_API_TYPE="azure"
+OPENAI_API_BASE="https://"
+OPENAI_API_KEY="your AzureOpenAI key"
+OPENAI_API_VERSION="2023-05-15"
+
+#TGI
+TGI_ENDPOINT=""
#GEMINI
-GEMINI_KEY=abcdefghtyuidkg
\ No newline at end of file
+GEMINI_KEY="abcdefghtyuidkg"
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index edef8bd..dca667d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ classifiers = [
license = { text = "Apache 2.0 License" }
dependencies = [
+ "vllm>=0.5.2",
"accelerate>=0.30.1",
"peft>=0.11.1",
"bitsandbytes>=0.40.2",
diff --git a/src/vieval/script_arguments.py b/src/vieval/script_arguments.py
index 98ab7ca..475559e 100644
--- a/src/vieval/script_arguments.py
+++ b/src/vieval/script_arguments.py
@@ -90,7 +90,9 @@ class ScriptArguments:
default=True,
metadata={"help": "Enable the TF32 mode (available in Ampere and newer GPUs)"},
)
-
+ cpu_offload_gb: Optional[float] = field(
+ default=0, metadata={"help": "GiB of model weights to offload to CPU (0 disables offloading)"}
+ )
auto_find_batch_size: Optional[bool] = field(
default=True, metadata={"help": "Enable auto batch size"}
)
@@ -167,6 +169,9 @@ class ScriptArguments:
)
# Inference parameters
+ dtype: Optional[str] = field(
+ default="float16", metadata={"help": "Data type for model weights"}
+ )
smoke_test: Optional[bool] = field(
default=False, metadata={"help": "Run a smoke test on a small dataset"}
)
@@ -182,7 +187,6 @@ class ScriptArguments:
cot: Optional[bool] = field(
default=False, metadata={"help": "Enable chain of thought when prompting MATH"}
)
- tgi: Optional[str] = field(default="", metadata={"help": "Embed TGI endpoint"})
seed: Optional[int] = field(default=42, metadata={"help": "Random seed"})
continue_infer: Optional[bool] = field(
default=False,
diff --git a/src/vieval/tools/pipelines/pipelines.py b/src/vieval/tools/pipelines/pipelines.py
index 4d725e2..e19e6f4 100644
--- a/src/vieval/tools/pipelines/pipelines.py
+++ b/src/vieval/tools/pipelines/pipelines.py
@@ -4,7 +4,7 @@
import json
from tqdm import tqdm
from ..utils.model import get_model
-from ..wrapper import AzureGPTWrapper, TGIWrapper, GeminiWrapper, HFWrapper
+from ..wrapper import OpenAIWrapper, TGIWrapper, GeminiWrapper, HFWrapper, VLLMWrapper
from ..utils.utils import *
from ..utils.metric_utils import info_from_filename
from .metric_pipelines import MetricPipeline
@@ -28,24 +28,23 @@ def __init__(self, task, config):
# print(config.tgi)
if config.wtype == "tgi":
self.infer_pipeline = TGIWrapper(
- api_endpoint=config.tgi,
generation_config=GenerationConfig[extract_task],
template=LLM_TEMPLATE[config.ptemplate],
)
elif config.wtype == "hf":
- # Load model
- self.model, self.tokenizer = get_model(config=config)
- self.model.eval()
-
self.infer_pipeline = HFWrapper(
- model=self.model,
- tokenizer=self.tokenizer,
+ config=config,
generation_config=GenerationConfig[extract_task],
template=LLM_TEMPLATE[config.ptemplate],
)
-
- elif config.wtype == "azuregpt":
- self.infer_pipeline = AzureGPTWrapper(
+ elif config.wtype == "vllm":
+ self.infer_pipeline = VLLMWrapper(
+ config=config,
+ generation_config=GenerationConfig[extract_task],
+ template=LLM_TEMPLATE[config.ptemplate],
+ )
+ elif config.wtype == "openai":
+ self.infer_pipeline = OpenAIWrapper(
engine=config.model_name,
generation_config=GenerationConfig[extract_task],
)
@@ -146,7 +145,7 @@ def __question_answering(self, ds_wrapper, ds_loader, saving_fn, start_idx=0):
results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True)
predictions.extend(results)
references.extend([x[0] for x in batch[ds_wrapper.answer]["text"]])
- generation_probs.extend([x.tolist() for x in logprobs])
+ generation_probs.extend(logprobs)
idx += 1
if idx % 100 == 0:
@@ -185,7 +184,7 @@ def __question_answering_without_context(
calib_probs = []
idx = 0
original_few_shot = []
- calibration_few_shot = []
+ calib_few_shot = []
selected_sample = []
if self.continue_infer_data is not None:
predictions.extend(self.continue_infer_data["predictions"])
@@ -258,8 +257,8 @@ def preprocessing_a_record(rec):
)
predictions.extend(results)
references.extend([x for x in batch[ds_wrapper.answer]])
- generation_probs.extend([x.tolist() for x in logprobs])
- calib_probs.extend([x.tolist() for x in calibprob_batch])
+ generation_probs.extend(logprobs)
+ calib_probs.extend(calibprob_batch)
idx += 1
if idx % 100 == 0:
print(f"Saving results of {idx} batches")
@@ -327,7 +326,7 @@ def __summarization(self, ds_wrapper, ds_loader, saving_fn, start_idx=0):
results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True)
predictions.extend(results)
references.extend([x for x in batch[ds_wrapper.summarized_text]])
- generation_probs.extend([x.tolist() for x in logprobs])
+ generation_probs.extend(logprobs)
idx += 1
if idx % 100 == 0:
@@ -451,7 +450,7 @@ def preprocessing_a_record(rec):
)
predictions.extend(results)
references.extend([x.item() for x in batch[ds_wrapper.label]])
- generation_probs.extend([x.tolist() for x in logprobs])
+ generation_probs.extend(logprobs)
option_probs.extend(
[
[
@@ -623,7 +622,7 @@ def format_calib_fewshot(rec):
for x in batch[ds_wrapper.label]
]
)
- generation_probs.extend([x.tolist() for x in logprobs])
+ generation_probs.extend(logprobs)
option_probs.extend(
[
[
@@ -756,7 +755,7 @@ def preprocessing_a_record(rec):
)
predictions.extend(results)
references.extend([x.item() for x in batch[ds_wrapper.label]])
- generation_probs.extend([x.tolist() for x in logprobs])
+ generation_probs.extend(logprobs)
option_probs.extend(
[
[
@@ -932,7 +931,7 @@ def preprocessing_a_record(rec):
]
)
- generation_probs.extend([x.tolist() for x in logprobs])
+ generation_probs.extend(logprobs)
option_probs.extend(opt_calib_out)
idx += 1
if idx % 100 == 0:
@@ -1017,7 +1016,7 @@ def preprocessing_a_record(rec):
results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True)
predictions.extend(results)
references.extend([x for x in batch[ds_wrapper.target]])
- generation_probs.extend([x.tolist() for x in logprobs])
+ generation_probs.extend(logprobs)
idx += 1
if idx % 100 == 0:
@@ -1322,8 +1321,8 @@ def preprocessing_a_record(rec):
)
predictions.extend(results)
references.extend([x for x in batch[target]])
- generation_probs.extend([x.tolist() for x in logprobs])
- calib_probs.extend([x.tolist() for x in calibprob_batch])
+ generation_probs.extend(logprobs)
+ calib_probs.extend(calibprob_batch)
if sub_task == "math":
math_problem_type.extend([x for x in batch[ds_wrapper.type]])
idx += 1
@@ -1417,7 +1416,7 @@ def preprocessing_a_record(rec):
results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True)
predictions.extend(results)
references.extend([x for x in batch[ds_wrapper.target_language]])
- generation_probs.extend([x.tolist() for x in logprobs])
+ generation_probs.extend(logprobs)
idx += 1
if idx % 100 == 0:
diff --git a/src/vieval/tools/wrapper/HFWrapper.py b/src/vieval/tools/wrapper/HFWrapper.py
index 4aeb27b..581be80 100644
--- a/src/vieval/tools/wrapper/HFWrapper.py
+++ b/src/vieval/tools/wrapper/HFWrapper.py
@@ -1,12 +1,14 @@
import torch
from .BaseWrapper import BaseWrapper
from ..utils.chat_template import apply_chat_template
+from ..utils.model import get_model
class HFWrapper(BaseWrapper):
- def __init__(self, model, tokenizer, generation_config, template=""):
- self.model = model
- self.tokenizer = tokenizer
+ def __init__(self, config, generation_config, template=None):
+ self.model, self.tokenizer = get_model(config=config)
+ self.model.eval()
+
self.generation_config = generation_config
self.model_template = template
@@ -51,7 +53,7 @@ def __call__(self, prompts, return_probs=False):
scores=generate_dict.scores,
normalize_logits=True,
)
- generations_probs.extend(generation_probs.cpu().numpy())
+ generations_probs.extend(generation_probs.cpu().numpy().tolist())
return generations, generations_probs, num_generated_tokens
@@ -67,7 +69,7 @@ def compute_logprob_and_length(self, prompts, completions):
prompt_num_tokens = prompt_tokens.input_ids.shape[1] - 1
completion_tokens = self.tokenizer(
- f"{completion} {self.tokenizer.eos_token}", return_tensors="pt"
+ f"{completion}{self.tokenizer.eos_token}", return_tensors="pt"
).to(
self.model.device
) # SPIECE_UNDERLINE [tokens] SPIECE_UNDERLINE
@@ -101,5 +103,5 @@ def compute_logprob_and_length(self, prompts, completions):
),
).squeeze(-1)
# >>> batch_size, sequence_length
- completions_logprobs.append(logprobs.cpu().numpy())
+ completions_logprobs.append(logprobs.cpu().numpy().tolist())
return completions_logprobs, completions_num_tokens
diff --git a/src/vieval/tools/wrapper/AzureGPTWrapper.py b/src/vieval/tools/wrapper/OpenAIWrapper.py
similarity index 73%
rename from src/vieval/tools/wrapper/AzureGPTWrapper.py
rename to src/vieval/tools/wrapper/OpenAIWrapper.py
index 6824959..a3f3597 100644
--- a/src/vieval/tools/wrapper/AzureGPTWrapper.py
+++ b/src/vieval/tools/wrapper/OpenAIWrapper.py
@@ -5,14 +5,14 @@
from .BaseWrapper import BaseWrapper
-class AzureGPTWrapper(BaseWrapper):
+class OpenAIWrapper(BaseWrapper):
def __init__(self, engine=None, generation_config=None):
- self.generation_config = generation_config
- self.model = openai.AzureOpenAI(
- azure_endpoint=os.getenv("AZURE_ENDPOINT"),
- api_key=os.getenv("AZURE_KEY"),
- api_version=os.getenv("AZURE_VERSION"),
+ generation_config["max_tokens"] = generation_config.pop("max_new_tokens")
+ generation_config["frequency_penalty"] = generation_config.pop(
+ "repetition_penalty"
)
+ self.generation_config = generation_config
+ self.model = openai.OpenAI()
self.engine = engine
def __call__(self, prompts, return_probs=False):
@@ -24,10 +24,7 @@ def __call__(self, prompts, return_probs=False):
response = self.chat_completions_with_backoff(
model=self.engine,
messages=prompt,
- temperature=self.generation_config["temperature"],
- max_tokens=self.generation_config["max_new_tokens"],
- top_p=0.95,
- frequency_penalty=self.generation_config["repetition_penalty"],
+ **self.generation_config,
)
generations.append(response.choices[0].message.content)
diff --git a/src/vieval/tools/wrapper/TGIWrapper.py b/src/vieval/tools/wrapper/TGIWrapper.py
index 13adb87..d20de97 100644
--- a/src/vieval/tools/wrapper/TGIWrapper.py
+++ b/src/vieval/tools/wrapper/TGIWrapper.py
@@ -1,15 +1,22 @@
import torch
import backoff
import requests
+from transformers import AutoTokenizer
+import warnings
+import os
from .BaseWrapper import BaseWrapper
from ..utils.chat_template import apply_chat_template
class TGIWrapper(BaseWrapper):
- def __init__(self, api_endpoint, generation_config, template=""):
- self.api_endpoint = api_endpoint
+ def __init__(self, generation_config, template=""):
+ self.api_endpoint = os.getenv("TGI_ENDPOINT")
self.generation_config = generation_config
self.model_template = template
+ self.model_info = self.get_model_info()
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ self.model_info["model_id"], trust_remote_code=True
+ )
def __call__(self, prompts, return_probs=False):
generations = []
@@ -22,7 +29,7 @@ def __call__(self, prompts, return_probs=False):
{
"inputs": prompt,
"parameters": {
- "truncate": 1500,
+ "truncate": self.model_info["max_input_tokens"],
"details": True,
**self.generation_config,
},
@@ -51,47 +58,52 @@ def compute_logprob_and_length(self, prompts, completions):
completions_num_tokens = []
completions_logprobs = []
prompts = apply_chat_template(prompts, self.model_template)
+ # tokenized_prompts = self.tokenizer(prompts)["input_ids"]
+ # len_tokenized_prompts = [len(p) for p in tokenized_prompts]
for prompt, completion in zip(prompts, completions):
try:
- prompt_tokens = self.generate_with_backoff(
- {
- "inputs": prompt,
- "parameters": {
- "truncate": 1500,
- "decoder_input_details": True,
- "max_new_tokens": 1,
- },
- }
- )["details"]["prefill"]
- completion_w_prompt = self.generate_with_backoff(
- {
- "inputs": prompt + completion + "",
- "parameters": {
- "truncate": 1500,
- "decoder_input_details": True,
- "max_new_tokens": 1,
- },
- }
- )["details"]["prefill"]
+ prompt_tokens = self.generate_with_backoff(
+ {
+ "inputs": prompt,
+ "parameters": {
+ "truncate": self.model_info["max_input_tokens"],
+ "decoder_input_details": True,
+ "max_new_tokens": 1,
+ },
+ }
+ )["details"]["prefill"]
+ completion_w_prompt = self.generate_with_backoff(
+ {
+ "inputs": prompt + completion + self.tokenizer.eos_token,
+ "parameters": {
+ "truncate": self.model_info["max_input_tokens"],
+ "decoder_input_details": True,
+ "max_new_tokens": 1,
+ },
+ }
+ )["details"]["prefill"]
except Exception as e:
print(e)
print(prompt)
raise e
- logprobs = torch.tensor(
- [
- list(
- map(
- lambda x: x["logprob"],
- completion_w_prompt[len(prompt_tokens) :],
- )
+ logprobs = [
+ list(
+ map(
+ lambda x: x["logprob"],
+ completion_w_prompt[len(prompt_tokens) :],
)
- ]
- )
+ )
+ ]
completions_logprobs.append(logprobs)
completions_num_tokens.append(len(logprobs[0]))
return completions_logprobs, completions_num_tokens
+ def get_model_info(self):
+ info = requests.get(self.api_endpoint + "/info", verify=False)
+ return info.json()
+
@backoff.on_exception(
backoff.expo, requests.exceptions.RequestException, max_tries=10
)
@@ -104,6 +116,6 @@ def generate_with_backoff(self, inputs):
def get_text_logprobs_tgi(self, res):
return (
[res["generated_text"]],
- [torch.tensor(list(map(lambda x: x["logprob"], res["details"]["tokens"])))],
+ [list(map(lambda x: x["logprob"], res["details"]["tokens"]))],
[res["details"]["generated_tokens"]],
)
diff --git a/src/vieval/tools/wrapper/VLLMWrapper.py b/src/vieval/tools/wrapper/VLLMWrapper.py
new file mode 100644
index 0000000..be5836a
--- /dev/null
+++ b/src/vieval/tools/wrapper/VLLMWrapper.py
@@ -0,0 +1,77 @@
+from vllm import LLM, SamplingParams
+from typing import Dict, List
+from .BaseWrapper import BaseWrapper
+from ..utils.chat_template import apply_chat_template
+from ..utils.model import get_model
+
+
+class VLLMWrapper(BaseWrapper):
+ def __init__(self, config, generation_config, template: Dict = None):
+ generation_config["max_tokens"] = generation_config.pop("max_new_tokens")
+ generation_config["frequency_penalty"] = generation_config.pop(
+ "repetition_penalty"
+ )
+ self.model = LLM(
+ model=config.model_name,
+ cpu_offload_gb=config.cpu_offload_gb,
+ dtype=config.dtype,
+ )
+ self.generation_config = SamplingParams(
+ **generation_config, logprobs=1, prompt_logprobs=0
+ )
+ self.model_template = template
+
+ def __call__(self, prompts, return_probs=False):
+ generations = []
+ generations_probs = []
+ num_generated_tokens = []
+ prompts = apply_chat_template(prompts, self.model_template)
+ try:
+ outputs = self.model.generate(prompts, self.generation_config)
+ for output in outputs:
+ generations.append(output.outputs[0].text)
+ generations_probs.append(
+ [
+ list(logprob.values())[0].logprob
+ for logprob in output.outputs[0].logprobs
+ ]
+ )
+ num_generated_tokens.append(len(output.outputs[0].logprobs))
+ except Exception as e:
+ print(prompts)
+ raise e
+ return generations, generations_probs, num_generated_tokens
+
+ def compute_logprob_and_length(self, prompts, completions):
+ tokenizer = self.model.get_tokenizer()
+ completions_num_tokens = []
+ completions_logprobs = []
+ prompts = apply_chat_template(prompts, self.model_template)
+ tokenized_prompts = tokenizer(prompts)["input_ids"]
+ len_tokenized_prompts = [len(p) for p in tokenized_prompts]
+ completed_prompts = [
+ prompt + completion + tokenizer.eos_token
+ for prompt, completion in zip(prompts, completions)
+ ]
+ outputs = self.model.generate(
+ completed_prompts,
+ SamplingParams(
+ max_tokens=1,
+ prompt_logprobs=0,
+ ignore_eos=False,
+ skip_special_tokens=False,
+ ),
+ )
+ for output, len_tokenized_prompt in zip(outputs, len_tokenized_prompts):
+ completions_num_tokens.append(
+ len(output.prompt_logprobs) - len_tokenized_prompt
+ )
+ completions_logprobs.append(
+ [
+ [
+ list(logprob.values())[0].logprob
+ for logprob in output.prompt_logprobs[len_tokenized_prompt:]
+ ]
+ ]
+ )
+ return completions_logprobs, completions_num_tokens
diff --git a/src/vieval/tools/wrapper/__init__.py b/src/vieval/tools/wrapper/__init__.py
index 10c39b2..ffe8892 100644
--- a/src/vieval/tools/wrapper/__init__.py
+++ b/src/vieval/tools/wrapper/__init__.py
@@ -1,6 +1,7 @@
-from .AzureGPTWrapper import AzureGPTWrapper
+from .OpenAIWrapper import OpenAIWrapper
from .GeminiWrapper import GeminiWrapper
from .TGIWrapper import TGIWrapper
from .HFWrapper import HFWrapper
+from .VLLMWrapper import VLLMWrapper
-__all__ = ["AzureGPTWrapper", "GeminiWrapper", "TGIWrapper", "HFWrapper"]
+__all__ = ["OpenAIWrapper", "GeminiWrapper", "TGIWrapper", "HFWrapper"]
From 6b52e19b05bd24c3bd5b6ad95f1c98877b359810 Mon Sep 17 00:00:00 2001
From: ledong0110 <74060032+ledong0110@users.noreply.github.com>
Date: Wed, 31 Jul 2024 14:19:57 +0700
Subject: [PATCH 2/3] fix: remove torch tensor in gemini wrapper
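
The Gemini wrapper only returns placeholder logprobs, so the torch
import and empty tensors are unnecessary; plain empty lists keep the
return type consistent with the other wrappers now that pipelines.py
no longer calls .tolist() before saving. Sketch with hypothetical
values:

    import json
    generation_probs = [[-0.12, -3.4], []]
    json.dumps({"generation_probs": generation_probs})  # lists serialize directly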
---
src/vieval/tools/wrapper/GeminiWrapper.py | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/src/vieval/tools/wrapper/GeminiWrapper.py b/src/vieval/tools/wrapper/GeminiWrapper.py
index c288b68..c71a625 100644
--- a/src/vieval/tools/wrapper/GeminiWrapper.py
+++ b/src/vieval/tools/wrapper/GeminiWrapper.py
@@ -1,4 +1,3 @@
-import torch
import json
import os
import openai
@@ -49,7 +48,7 @@ def __init__(self, model_name=None, generation_config=None):
def __call__(self, prompts, return_probs=False):
generations = []
- generations_probs = [torch.tensor([])] * len(prompts)
+ generations_probs = [[]] * len(prompts)
num_generated_tokens = []
for prompt in prompts:
processed_prompt = [list(p.values())[1] for p in prompt]
@@ -74,7 +73,7 @@ def __call__(self, prompts, return_probs=False):
def compute_logprob_and_length(self, prompts, completions):
completions_num_tokens = [0] * len(prompts)
- completions_logprobs = [torch.tensor([])] * len(prompts)
+ completions_logprobs = [[]] * len(prompts)
# Not Implement
return completions_logprobs, completions_num_tokens
From 995689ac2ba22794c21a73971827b4b84fa3d03f Mon Sep 17 00:00:00 2001
From: ledong0110 <74060032+ledong0110@users.noreply.github.com>
Date: Wed, 31 Jul 2024 14:21:13 +0700
Subject: [PATCH 3/3] fix: remove torch tensor in openai wrapper
---
src/vieval/tools/wrapper/OpenAIWrapper.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/vieval/tools/wrapper/OpenAIWrapper.py b/src/vieval/tools/wrapper/OpenAIWrapper.py
index a3f3597..d272c42 100644
--- a/src/vieval/tools/wrapper/OpenAIWrapper.py
+++ b/src/vieval/tools/wrapper/OpenAIWrapper.py
@@ -17,7 +17,7 @@ def __init__(self, engine=None, generation_config=None):
def __call__(self, prompts, return_probs=False):
generations = []
- generations_probs = [torch.tensor([])] * len(prompts)
+ generations_probs = [[]] * len(prompts)
num_generated_tokens = []
for prompt in prompts:
@@ -34,7 +34,7 @@ def __call__(self, prompts, return_probs=False):
def compute_logprob_and_length(self, prompts, completions):
completions_num_tokens = [0] * len(prompts)
- completions_logprobs = [torch.tensor([])] * len(prompts)
+ completions_logprobs = [[]] * len(prompts)
# TODO: Implement when OpenAI support logprobs of sentence
return completions_logprobs, completions_num_tokens