Why does inference take close to 1 minute per example on 8×80G A100?
import argparse
import os
import time

import pandas as pd
import torch
from tqdm import tqdm
from vllm import LLM, SamplingParams

from constants_prompt import build_autoj_input
from zh_constants_prompt import zh_build_autoj_input


def extract_pariwise_result(raw_output):
    raw_output = raw_output.strip()
    pos = raw_output.rfind('final decision is ')
    pred_label = -1
    if pos != -1:
        pred_rest = raw_output[pos + len('final decision is '):].strip().lower()
        if pred_rest.startswith('response 1'):
            pred_label = 0
        elif pred_rest.startswith('response 2'):
            pred_label = 1
        elif pred_rest.startswith('tie'):
            pred_label = 2
    return pred_label


def zh_extract_pariwise_result(raw_output):
    raw_output = raw_output.strip()
    pos = raw_output.rfind('最终决定是')
    pred_label = -1
    if pos != -1:
        pred_rest = raw_output[pos + len('最终决定是'):].strip()
        if pred_rest.startswith('回应1'):
            pred_label = 0
        elif pred_rest.startswith('回应2'):
            pred_label = 1
        elif pred_rest.startswith('平局'):
            pred_label = 2
    return pred_label


def extract_single_rating(score_output):
    if "Rating: [[" in score_output:
        pos = score_output.rfind("Rating: [[")
        pos2 = score_output.find("]]", pos)
        assert pos != -1 and pos2 != -1
        return float(score_output[pos + len("Rating: [["):pos2].strip())
    else:
        return 0.0


def zh_extract_single_rating(score_output):
    if "评分:[[" in score_output:
        pos = score_output.rfind("评分:[[")
        pos2 = score_output.find("]]", pos)
        assert pos != -1 and pos2 != -1
        return float(score_output[pos + len("评分:[["):pos2].strip())
    elif "打分:[[" in score_output:
        pos = score_output.rfind("打分:[[")
        pos2 = score_output.find("]]", pos)
        assert pos != -1 and pos2 != -1
        return float(score_output[pos + len("打分:[["):pos2].strip())
    else:
        return 0.0


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--language", type=str, default="Chinese", help="Choose Chinese or English evaluation")
    parser.add_argument("--data_dir", type=str, default="./", help="data path")
    parser.add_argument("--output_dir", type=str, default="./", help="output data path")
    parser.add_argument("--model", type=str, default="MODELS/autoj-bilingual-6b", help="model path")
    parser.add_argument("--columns", nargs=2, type=str, default=["query", "ans"], help="columns of q and a")
    args = parser.parse_args()
    assert args.language == "Chinese" or args.language == "English"

    num_gpus = torch.cuda.device_count()
    model_name_or_dir = args.model  # or "local path to auto-j"
    llm = LLM(model=model_name_or_dir, tensor_parallel_size=num_gpus)
    sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=4096)

    # Load the dataset.
    df = pd.read_csv(args.data_dir)
    # Check that the query/answer columns exist.
    assert all([col in df.columns for col in args.columns]), \
        f"Please specify the QA column names; the given columns {args.columns} are not in the file"
    l = df[args.columns].values.tolist()
    l2 = []

    # Build one prompt per (query, answer) pair.
    inputs = [zh_build_autoj_input(prompt=i, resp1=j, protocol="zh_single") for i, j in l]
    # Generate for the whole batch at once.
    outputs = llm.generate(inputs, sampling_params)
    for input, output in zip(l, outputs):
        query, ans = input
        judgment = output.outputs[0].text
        evaluation_result = zh_extract_single_rating(judgment)
        l2.append([query, ans, judgment, evaluation_result])

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    output_df = pd.DataFrame(l2, columns=['query', 'ans', 'judgment', 'evaluation_result'])
    _, file_name = os.path.split(args.output_dir)
    output_df.to_csv(os.path.join(args.output_dir, file_name), index=None)
    print(f'Results saved to {os.path.join(args.output_dir, file_name)}')
Processed prompts:   0%|          | 0/366 [00:00<?, ?it/s]
Processed prompts:   0%|          | 1/366 [04:58<30:14:45, 298.32s/it]
Processed prompts:   0%|          | 2/366 [05:10<13:08:12, 129.93s/it]
Processed prompts:   1%|          | 3/366 [10:06<20:44:28, 205.70s/it]
Processed prompts:   1%|          | 4/366 [10:38<13:47:31, 137.16s/it]
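To narrow down where the time goes, it may help to separate model load / warm-up from actual generation and to check how many tokens are being produced per prompt: with temperature=0.0 and max_tokens=4096, a model that never emits a stop token will decode the full 4096 tokens for every prompt. The snippet below is only a minimal sketch along those lines; the model path and the test prompt are placeholders, not part of the original script.

import time

from vllm import LLM, SamplingParams

# Assumed local path, same as the default --model in the script above.
llm = LLM(model="MODELS/autoj-bilingual-6b", tensor_parallel_size=8)


def timed_generate(prompts, max_tokens):
    """Run one batch and report wall-clock time and generated-token counts."""
    params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=max_tokens)
    start = time.time()
    outputs = llm.generate(prompts, params)
    elapsed = time.time() - start
    n_tokens = sum(len(o.outputs[0].token_ids) for o in outputs)
    print(f"max_tokens={max_tokens}: {elapsed:.1f}s for {len(prompts)} prompts, "
          f"{n_tokens} generated tokens ({n_tokens / elapsed:.1f} tok/s)")
    return outputs


test_prompt = "你好,请简单介绍一下你自己。"  # any short placeholder prompt
timed_generate([test_prompt], max_tokens=64)    # should be quick once the model is loaded
timed_generate([test_prompt], max_tokens=4096)  # slow if decoding runs to the 4096-token cap

If the second call is the one that takes close to a minute, the latency is probably dominated by output length rather than by the 8-GPU setup, and a smaller max_tokens (or a working EOS / stop condition) would be the first thing to try.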
The inference results are all "!":
! ! �! ! ! ! �! �! �! �! �! ! ! �! ! ! ! ! ! ! �! �! ! �! �! �! ! ! �! ! ! ! ! ! ! ! ! !
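For what it's worth, output that collapses into repeated "!" (and replacement characters) often points to the weights or the load dtype rather than the prompt. A minimal check, again with an assumed model path, is to load with an explicit dtype and inspect one raw completion:

from vllm import LLM, SamplingParams

# "bfloat16" here is an assumption to try; "auto" follows the checkpoint's config,
# and "float16" is the other common choice on A100.
llm = LLM(model="MODELS/autoj-bilingual-6b", tensor_parallel_size=8, dtype="bfloat16")
params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=128)

out = llm.generate(["你好,请简单介绍一下你自己。"], params)[0]
print(repr(out.outputs[0].text))  # expect readable text, not a run of '!' tokens

If the completion is still garbage with either dtype, it may be worth re-downloading the checkpoint and confirming that the tokenizer matches the model.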