From 897fbac42f381a13a5dc3f0ad1e93de466ce3d42 Mon Sep 17 00:00:00 2001 From: ledong0110 <74060032+ledong0110@users.noreply.github.com> Date: Wed, 31 Jul 2024 02:03:41 +0700 Subject: [PATCH 1/3] feat: add VLLM wrapper --- config/llm_template.json | 4 + env.template | 12 ++- pyproject.toml | 1 + src/vieval/script_arguments.py | 8 +- src/vieval/tools/pipelines/pipelines.py | 47 ++++++----- src/vieval/tools/wrapper/HFWrapper.py | 14 ++-- .../{AzureGPTWrapper.py => OpenAIWrapper.py} | 17 ++-- src/vieval/tools/wrapper/TGIWrapper.py | 78 +++++++++++-------- src/vieval/tools/wrapper/VLLMWrapper.py | 77 ++++++++++++++++++ src/vieval/tools/wrapper/__init__.py | 5 +- 10 files changed, 182 insertions(+), 81 deletions(-) rename src/vieval/tools/wrapper/{AzureGPTWrapper.py => OpenAIWrapper.py} (73%) create mode 100644 src/vieval/tools/wrapper/VLLMWrapper.py diff --git a/config/llm_template.json b/config/llm_template.json index ab39672..726cb7f 100644 --- a/config/llm_template.json +++ b/config/llm_template.json @@ -7,6 +7,10 @@ "template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '' + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + '' }}{% endif %}{% endfor %}", "system_prompt": true }, + "llama-3": { + "template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = '<|begin_of_text|>' + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", + "system_prompt": false + }, "mistral": { "template": "''+{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content']}}+''+{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", "system_prompt": true diff --git a/env.template b/env.template index 42d1700..6768f04 100644 --- a/env.template +++ b/env.template @@ -1,7 +1,11 @@ #GPT -AZURE_KEY=xxxx123 -AZURE_VERSION=2023-07-01-preview -AZURE_ENDPOINT=http://gpt-4.com +OPENAI_API_TYPE="azure" +OPENAI_API_BASE="https://" +OPENAI_API_KEY="your AzureOpenAI key" +OPENAI_API_VERSION="2023-05-15" + +#TGI +TGI_ENDPOINT="" #GEMINI -GEMINI_KEY=abcdefghtyuidkg \ No newline at end of file +GEMINI_KEY="abcdefghtyuidkg" \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index edef8bd..dca667d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ classifiers = [ license = { text = "Apache 2.0 License" } dependencies = [ + "vllm>=0.5.2", "accelerate>=0.30.1", "peft>=0.11.1", "bitsandbytes>=0.40.2", diff --git a/src/vieval/script_arguments.py b/src/vieval/script_arguments.py index 98ab7ca..475559e 100644 --- a/src/vieval/script_arguments.py +++ b/src/vieval/script_arguments.py @@ -90,7 +90,9 @@ class ScriptArguments: default=True, metadata={"help": "Enable the TF32 mode (available in Ampere and newer GPUs)"}, ) - + cpu_offload_gb: Optional[bool] = field( + default=0, metadata={"help": "Offload computation to CPU"} + ) auto_find_batch_size: Optional[bool] = field( default=True, metadata={"help": "Enable auto batch size"} ) @@ -167,6 +169,9 @@ class ScriptArguments: ) # Inference parameters + dtype: Optional[str] = field( + default="float16", metadata={"help": "Data type for model weights"} # float16 + ) smoke_test: Optional[bool] = field( default=False, metadata={"help": "Run a smoke test on a small dataset"} ) @@ -182,7 +187,6 @@ class ScriptArguments: cot: Optional[bool] = field( default=False, metadata={"help": "Enable chain of thought when prompting MATH"} ) - tgi: Optional[str] = field(default="", metadata={"help": "Embed TGI endpoint"}) seed: Optional[int] = field(default=42, metadata={"help": "Random seed"}) continue_infer: Optional[bool] = field( default=False, diff --git a/src/vieval/tools/pipelines/pipelines.py b/src/vieval/tools/pipelines/pipelines.py index 4d725e2..e19e6f4 100644 --- a/src/vieval/tools/pipelines/pipelines.py +++ b/src/vieval/tools/pipelines/pipelines.py @@ -4,7 +4,7 @@ import json from tqdm import tqdm from ..utils.model import get_model -from ..wrapper import AzureGPTWrapper, TGIWrapper, GeminiWrapper, HFWrapper +from ..wrapper import OpenAIWrapper, TGIWrapper, GeminiWrapper, HFWrapper from ..utils.utils import * from ..utils.metric_utils import info_from_filename from .metric_pipelines import MetricPipeline @@ -28,24 +28,23 @@ def __init__(self, task, config): # print(config.tgi) if config.wtype == "tgi": self.infer_pipeline = TGIWrapper( - api_endpoint=config.tgi, generation_config=GenerationConfig[extract_task], template=LLM_TEMPLATE[config.ptemplate], ) elif config.wtype == "hf": - # Load model - self.model, self.tokenizer = get_model(config=config) - self.model.eval() - self.infer_pipeline = HFWrapper( - model=self.model, - tokenizer=self.tokenizer, + config=config, generation_config=GenerationConfig[extract_task], template=LLM_TEMPLATE[config.ptemplate], ) - - elif config.wtype == "azuregpt": - self.infer_pipeline = AzureGPTWrapper( + elif config.wtype == "vllm": + self.infer_pipeline = VLLMWrapper( + config=config, + generation_config=GenerationConfig[extract_task], + template=LLM_TEMPLATE[config.ptemplate], + ) + elif config.wtype == "openai": + self.infer_pipeline = OpenAIWrapper( engine=config.model_name, generation_config=GenerationConfig[extract_task], ) @@ -146,7 +145,7 @@ def __question_answering(self, ds_wrapper, ds_loader, saving_fn, start_idx=0): results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True) predictions.extend(results) references.extend([x[0] for x in batch[ds_wrapper.answer]["text"]]) - generation_probs.extend([x.tolist() for x in logprobs]) + generation_probs.extend(logprobs) idx += 1 if idx % 100 == 0: @@ -185,7 +184,7 @@ def __question_answering_without_context( calib_probs = [] idx = 0 original_few_shot = [] - calibration_few_shot = [] + calib_few_shot = [] selected_sample = [] if self.continue_infer_data is not None: predictions.extend(self.continue_infer_data["predictions"]) @@ -258,8 +257,8 @@ def preprocessing_a_record(rec): ) predictions.extend(results) references.extend([x for x in batch[ds_wrapper.answer]]) - generation_probs.extend([x.tolist() for x in logprobs]) - calib_probs.extend([x.tolist() for x in calibprob_batch]) + generation_probs.extend(logprobs) + calib_probs.extend(calibprob_batch) idx += 1 if idx % 100 == 0: print(f"Saving results of {idx} batches") @@ -327,7 +326,7 @@ def __summarization(self, ds_wrapper, ds_loader, saving_fn, start_idx=0): results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True) predictions.extend(results) references.extend([x for x in batch[ds_wrapper.summarized_text]]) - generation_probs.extend([x.tolist() for x in logprobs]) + generation_probs.extend(logprobs) idx += 1 if idx % 100 == 0: @@ -451,7 +450,7 @@ def preprocessing_a_record(rec): ) predictions.extend(results) references.extend([x.item() for x in batch[ds_wrapper.label]]) - generation_probs.extend([x.tolist() for x in logprobs]) + generation_probs.extend(logprobs) option_probs.extend( [ [ @@ -623,7 +622,7 @@ def format_calib_fewshot(rec): for x in batch[ds_wrapper.label] ] ) - generation_probs.extend([x.tolist() for x in logprobs]) + generation_probs.extend(logprobs) option_probs.extend( [ [ @@ -756,7 +755,7 @@ def preprocessing_a_record(rec): ) predictions.extend(results) references.extend([x.item() for x in batch[ds_wrapper.label]]) - generation_probs.extend([x.tolist() for x in logprobs]) + generation_probs.extend(logprobs) option_probs.extend( [ [ @@ -932,7 +931,7 @@ def preprocessing_a_record(rec): ] ) - generation_probs.extend([x.tolist() for x in logprobs]) + generation_probs.extend(logprobs) option_probs.extend(opt_calib_out) idx += 1 if idx % 100 == 0: @@ -1017,7 +1016,7 @@ def preprocessing_a_record(rec): results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True) predictions.extend(results) references.extend([x for x in batch[ds_wrapper.target]]) - generation_probs.extend([x.tolist() for x in logprobs]) + generation_probs.extend(logprobs) idx += 1 if idx % 100 == 0: @@ -1322,8 +1321,8 @@ def preprocessing_a_record(rec): ) predictions.extend(results) references.extend([x for x in batch[target]]) - generation_probs.extend([x.tolist() for x in logprobs]) - calib_probs.extend([x.tolist() for x in calibprob_batch]) + generation_probs.extend(logprobs) + calib_probs.extend(calibprob_batch) if sub_task == "math": math_problem_type.extend([x for x in batch[ds_wrapper.type]]) idx += 1 @@ -1417,7 +1416,7 @@ def preprocessing_a_record(rec): results, logprobs, _ = self.infer_pipeline(prompts, return_probs=True) predictions.extend(results) references.extend([x for x in batch[ds_wrapper.target_language]]) - generation_probs.extend([x.tolist() for x in logprobs]) + generation_probs.extend(logprobs) idx += 1 if idx % 100 == 0: diff --git a/src/vieval/tools/wrapper/HFWrapper.py b/src/vieval/tools/wrapper/HFWrapper.py index 4aeb27b..581be80 100644 --- a/src/vieval/tools/wrapper/HFWrapper.py +++ b/src/vieval/tools/wrapper/HFWrapper.py @@ -1,12 +1,14 @@ import torch from .BaseWrapper import BaseWrapper from ..utils.chat_template import apply_chat_template +from ..utils.model import get_model class HFWrapper(BaseWrapper): - def __init__(self, model, tokenizer, generation_config, template=""): - self.model = model - self.tokenizer = tokenizer + def __init__(self, config, generation_config, template=None): + self.model, self.tokenizer = get_model(config=config) + self.model.eval() + self.generation_config = generation_config self.model_template = template @@ -51,7 +53,7 @@ def __call__(self, prompts, return_probs=False): scores=generate_dict.scores, normalize_logits=True, ) - generations_probs.extend(generation_probs.cpu().numpy()) + generations_probs.extend(generation_probs.cpu().numpy().tolist()) return generations, generations_probs, num_generated_tokens @@ -67,7 +69,7 @@ def compute_logprob_and_length(self, prompts, completions): prompt_num_tokens = prompt_tokens.input_ids.shape[1] - 1 completion_tokens = self.tokenizer( - f"{completion} {self.tokenizer.eos_token}", return_tensors="pt" + f"{completion}{self.tokenizer.eos_token}", return_tensors="pt" ).to( self.model.device ) # SPIECE_UNDERLINE [tokens] SPIECE_UNDERLINE @@ -101,5 +103,5 @@ def compute_logprob_and_length(self, prompts, completions): ), ).squeeze(-1) # >>> batch_size, sequence_length - completions_logprobs.append(logprobs.cpu().numpy()) + completions_logprobs.append(logprobs.cpu().numpy().tolist()) return completions_logprobs, completions_num_tokens diff --git a/src/vieval/tools/wrapper/AzureGPTWrapper.py b/src/vieval/tools/wrapper/OpenAIWrapper.py similarity index 73% rename from src/vieval/tools/wrapper/AzureGPTWrapper.py rename to src/vieval/tools/wrapper/OpenAIWrapper.py index 6824959..a3f3597 100644 --- a/src/vieval/tools/wrapper/AzureGPTWrapper.py +++ b/src/vieval/tools/wrapper/OpenAIWrapper.py @@ -5,14 +5,14 @@ from .BaseWrapper import BaseWrapper -class AzureGPTWrapper(BaseWrapper): +class OpenAIWrapper(BaseWrapper): def __init__(self, engine=None, generation_config=None): - self.generation_config = generation_config - self.model = openai.AzureOpenAI( - azure_endpoint=os.getenv("AZURE_ENDPOINT"), - api_key=os.getenv("AZURE_KEY"), - api_version=os.getenv("AZURE_VERSION"), + generation_config["max_tokens"] = generation_config.pop("max_new_tokens") + generation_config["frequency_penalty"] = generation_config.pop( + "repetition_penalty" ) + self.generation_config = generation_config + self.model = openai.OpenAI() self.engine = engine def __call__(self, prompts, return_probs=False): @@ -24,10 +24,7 @@ def __call__(self, prompts, return_probs=False): response = self.chat_completions_with_backoff( model=self.engine, messages=prompt, - temperature=self.generation_config["temperature"], - max_tokens=self.generation_config["max_new_tokens"], - top_p=0.95, - frequency_penalty=self.generation_config["repetition_penalty"], + **self.generation_config, ) generations.append(response.choices[0].message.content) diff --git a/src/vieval/tools/wrapper/TGIWrapper.py b/src/vieval/tools/wrapper/TGIWrapper.py index 13adb87..d20de97 100644 --- a/src/vieval/tools/wrapper/TGIWrapper.py +++ b/src/vieval/tools/wrapper/TGIWrapper.py @@ -1,15 +1,22 @@ import torch import backoff import requests +from transformers import AutoTokenizer +import warnings +import os from .BaseWrapper import BaseWrapper from ..utils.chat_template import apply_chat_template class TGIWrapper(BaseWrapper): - def __init__(self, api_endpoint, generation_config, template=""): - self.api_endpoint = api_endpoint + def __init__(self, generation_config, template=""): + self.api_endpoint = os.getenv("TGI_ENDPOINT") self.generation_config = generation_config self.model_template = template + self.model_info = self.get_model_info() + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_info["model_id"], trust_remote_code=True + ) def __call__(self, prompts, return_probs=False): generations = [] @@ -22,7 +29,7 @@ def __call__(self, prompts, return_probs=False): { "inputs": prompt, "parameters": { - "truncate": 1500, + "truncate": self.model_info["max_input_tokens"], "details": True, **self.generation_config, }, @@ -51,47 +58,52 @@ def compute_logprob_and_length(self, prompts, completions): completions_num_tokens = [] completions_logprobs = [] prompts = apply_chat_template(prompts, self.model_template) + # tokenized_prompts = self.tokenizer(prompts)["input_ids"] + # len_tokenized_prompts = [len(p) for p in tokenized_prompts] for prompt, completion in zip(prompts, completions): try: - prompt_tokens = self.generate_with_backoff( - { - "inputs": prompt, - "parameters": { - "truncate": 1500, - "decoder_input_details": True, - "max_new_tokens": 1, - }, - } - )["details"]["prefill"] - completion_w_prompt = self.generate_with_backoff( - { - "inputs": prompt + completion + "", - "parameters": { - "truncate": 1500, - "decoder_input_details": True, - "max_new_tokens": 1, - }, - } - )["details"]["prefill"] + for prompt, completion in zip(prompts, completions): + prompt_tokens = self.generate_with_backoff( + { + "inputs": prompt, + "parameters": { + "truncate": self.model_info["max_input_tokens"], + "decoder_input_details": True, + "max_new_tokens": 1, + }, + } + )["details"]["prefill"] + completion_w_prompt = self.generate_with_backoff( + { + "inputs": prompt + completion + self.tokenizer.eos_token, + "parameters": { + "truncate": self.model_info["max_input_tokens"], + "decoder_input_details": True, + "max_new_tokens": 1, + }, + } + )["details"]["prefill"] except Exception as e: print(e) print(prompt) raise e - logprobs = torch.tensor( - [ - list( - map( - lambda x: x["logprob"], - completion_w_prompt[len(prompt_tokens) :], - ) + logprobs = [ + list( + map( + lambda x: x["logprob"], + completion_w_prompt[len(prompt_tokens) :], ) - ] - ) + ) + ] completions_logprobs.append(logprobs) completions_num_tokens.append(len(logprobs[0])) return completions_logprobs, completions_num_tokens + def get_model_info(self): + info = requests.get(self.api_endpoint + "/info", verify=False) + return info.json() + @backoff.on_exception( backoff.expo, requests.exceptions.RequestException, max_tries=10 ) @@ -104,6 +116,6 @@ def generate_with_backoff(self, inputs): def get_text_logprobs_tgi(self, res): return ( [res["generated_text"]], - [torch.tensor(list(map(lambda x: x["logprob"], res["details"]["tokens"])))], + [list(map(lambda x: x["logprob"], res["details"]["tokens"]))], [res["details"]["generated_tokens"]], ) diff --git a/src/vieval/tools/wrapper/VLLMWrapper.py b/src/vieval/tools/wrapper/VLLMWrapper.py new file mode 100644 index 0000000..be5836a --- /dev/null +++ b/src/vieval/tools/wrapper/VLLMWrapper.py @@ -0,0 +1,77 @@ +from vllm import LLM, SamplingParams +from typing import Dict, List +from .BaseWrapper import BaseWrapper +from ..utils.chat_template import apply_chat_template +from ..utils.model import get_model + + +class VLLMWrapper(BaseWrapper): + def __init__(self, config, generation_config, template: Dict = None): + generation_config["max_tokens"] = generation_config.pop("max_new_tokens") + generation_config["frequency_penalty"] = generation_config.pop( + "repetition_penalty" + ) + self.model = LLM( + model=config.model_name, + cpu_offload_gb=config.cpu_offload_gb, + dtype=config.dtype, + ) + self.generation_config = SamplingParams( + **generation_config, logprobs=1, prompt_logprobs=0 + ) + self.model_template = template + + def __call__(self, prompts, return_probs=False): + generations = [] + generations_probs = [] + num_generated_tokens = [] + prompts = apply_chat_template(prompts, self.model_template) + try: + outputs = self.model.generate(prompts, self.generation_config) + for output in outputs: + generations.append(output.outputs[0].text) + generations_probs.append( + [ + list(logprob.values())[0].logprob + for logprob in output.outputs[0].logprobs + ] + ) + num_generated_tokens.append(len(output.outputs[0].logprobs)) + except Exception as e: + print(prompts) + raise e + return generations, generations_probs, num_generated_tokens + + def compute_logprob_and_length(self, prompts, completions): + tokenizer = self.model.get_tokenizer() + completions_num_tokens = [] + completions_logprobs = [] + prompts = apply_chat_template(prompts, self.model_template) + tokenized_prompts = tokenizer(prompts)["input_ids"] + len_tokenized_prompts = [len(p) for p in tokenized_prompts] + completed_prompts = [ + prompt + completion + tokenizer.eos_token + for prompt, completion in zip(prompts, completions) + ] + outputs = self.model.generate( + completed_prompts, + SamplingParams( + max_tokens=1, + prompt_logprobs=0, + ignore_eos=False, + skip_special_tokens=False, + ), + ) + for output, len_tokenized_prompt in zip(outputs, len_tokenized_prompts): + completions_num_tokens.append( + len(output.prompt_logprobs) - len_tokenized_prompt + ) + completions_logprobs.append( + [ + [ + list(logprob.values())[0].logprob + for logprob in output.prompt_logprobs[len_tokenized_prompt:] + ] + ] + ) + return completions_logprobs, completions_num_tokens diff --git a/src/vieval/tools/wrapper/__init__.py b/src/vieval/tools/wrapper/__init__.py index 10c39b2..ffe8892 100644 --- a/src/vieval/tools/wrapper/__init__.py +++ b/src/vieval/tools/wrapper/__init__.py @@ -1,6 +1,7 @@ -from .AzureGPTWrapper import AzureGPTWrapper +from .OpenAIWrapper import OpenAIWrapper from .GeminiWrapper import GeminiWrapper from .TGIWrapper import TGIWrapper from .HFWrapper import HFWrapper +from .VLLMWrapper import VLLMWrapper -__all__ = ["AzureGPTWrapper", "GeminiWrapper", "TGIWrapper", "HFWrapper"] +__all__ = ["OpenAIWrapper", "GeminiWrapper", "TGIWrapper", "HFWrapper"] From 6b52e19b05bd24c3bd5b6ad95f1c98877b359810 Mon Sep 17 00:00:00 2001 From: ledong0110 <74060032+ledong0110@users.noreply.github.com> Date: Wed, 31 Jul 2024 14:19:57 +0700 Subject: [PATCH 2/3] fix: remove torch tensor in gemini wrapper --- src/vieval/tools/wrapper/GeminiWrapper.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/vieval/tools/wrapper/GeminiWrapper.py b/src/vieval/tools/wrapper/GeminiWrapper.py index c288b68..c71a625 100644 --- a/src/vieval/tools/wrapper/GeminiWrapper.py +++ b/src/vieval/tools/wrapper/GeminiWrapper.py @@ -1,4 +1,3 @@ -import torch import json import os import openai @@ -49,7 +48,7 @@ def __init__(self, model_name=None, generation_config=None): def __call__(self, prompts, return_probs=False): generations = [] - generations_probs = [torch.tensor([])] * len(prompts) + generations_probs = [[]] * len(prompts) num_generated_tokens = [] for prompt in prompts: processed_prompt = [list(p.values())[1] for p in prompt] @@ -74,7 +73,7 @@ def __call__(self, prompts, return_probs=False): def compute_logprob_and_length(self, prompts, completions): completions_num_tokens = [0] * len(prompts) - completions_logprobs = [torch.tensor([])] * len(prompts) + completions_logprobs = [[]] * len(prompts) # Not Implement return completions_logprobs, completions_num_tokens From 995689ac2ba22794c21a73971827b4b84fa3d03f Mon Sep 17 00:00:00 2001 From: ledong0110 <74060032+ledong0110@users.noreply.github.com> Date: Wed, 31 Jul 2024 14:21:13 +0700 Subject: [PATCH 3/3] fix: remove torch tensor in openai wrapper --- src/vieval/tools/wrapper/OpenAIWrapper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/vieval/tools/wrapper/OpenAIWrapper.py b/src/vieval/tools/wrapper/OpenAIWrapper.py index a3f3597..d272c42 100644 --- a/src/vieval/tools/wrapper/OpenAIWrapper.py +++ b/src/vieval/tools/wrapper/OpenAIWrapper.py @@ -17,7 +17,7 @@ def __init__(self, engine=None, generation_config=None): def __call__(self, prompts, return_probs=False): generations = [] - generations_probs = [torch.tensor([])] * len(prompts) + generations_probs = [[]] * len(prompts) num_generated_tokens = [] for prompt in prompts: @@ -34,7 +34,7 @@ def __call__(self, prompts, return_probs=False): def compute_logprob_and_length(self, prompts, completions): completions_num_tokens = [0] * len(prompts) - completions_logprobs = [torch.tensor([])] * len(prompts) + completions_logprobs = [[]] * len(prompts) # TODO: Implement when OpenAI support logprobs of sentence return completions_logprobs, completions_num_tokens