forked from THUDM/AlignBench
-
Notifications
You must be signed in to change notification settings - Fork 0
/
get_answers.py
105 lines (94 loc) · 3.6 KB
/
get_answers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import argparse
import jsonlines
import os
import json
from inference.models import get_model_api
from inference.utils import test_api_alive
if __name__ == '__main__':
    """
    Single-round inference driver.

    Input question file (jsonlines, one doc per line):
        {
            "question_id": int,
            "category": str,
            "subcategory": str,
            "question": str,
        }

    Output answer file (jsonlines, one doc per line):
        {
            "question_id": int,
            "category": str,
            "subcategory": str,
            "model_id": str,
            "question": str,
            "answer": str,
        }
    """

    def _str2bool(value):
        # Parse a textual boolean from the command line. argparse's
        # `type=bool` treats ANY non-empty string — including "False" —
        # as True, so a real parser is required.
        return value.strip().lower() in ("1", "true", "yes", "y", "t")

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, help="Evaluated Model Name")
    # BUG FIX: was `type=bool`, under which `--repredict False` still
    # evaluated to True (bool of a non-empty string is always True).
    parser.add_argument("--repredict", type=_str2bool, default=True,
                        help="Repredict When Encounter Empty Answers")
    parser.add_argument("--workers", type=int)
    parser.add_argument("--question-file", type=str)
    parser.add_argument("--save-dir", type=str)
    parser.add_argument("--temperature-config-file", type=str,
                        default="config/temperature.json",
                        help="Temperature Config")
    parser.add_argument("--first-n", type=int, help="Debug Option")
    args = parser.parse_args()

    ## test api status before committing to a full run
    print(">>> testing whether the api is alive ...")
    res = test_api_alive(args.model)
    if res:
        print(f">>> api {args.model} is alive")
    model = get_model_api(args.model, args.workers)
    print(">>> inference model: ", args.model)

    ## load questions (the `with` block closes the reader; no manual close needed)
    with jsonlines.open(args.question_file, "r") as f:
        docs = list(f)
    if args.first_n:
        docs = docs[: args.first_n]
    print(f">>> running {len(docs)} docs")

    ## load temperature configs. NOTE: the hard-coded 0.7 fallback
    ## overrides any "default" entry present in the config file (original
    ## behavior, preserved deliberately).
    with open(args.temperature_config_file, "r") as f:
        temp_config = json.load(f)
    temp_config["default"] = 0.7

    ## pick a temperature for each sample by its category
    # NOTE(review): the original comment said "subcategory" but the lookup
    # uses doc["category"] — confirm which key the temperature config is
    # actually keyed on.
    samples = []
    for doc in docs:
        if doc.get("category", None) not in temp_config:
            print(">>> warning: category not in temp_config, category: ", doc.get("category", None))
        samples.append({
            "question": doc["question"],
            "temperature": temp_config.get(doc["category"], temp_config["default"])
        })

    outputs = model.generate_text(samples)

    ## up to 3 retry rounds for answers that came back None or empty
    if args.repredict:
        for process in range(3):
            empty_ids = [i for i, out in enumerate(outputs)
                         if out is None or out == ""]
            if not empty_ids:
                print(">>> no empty outputs, inference finished")
                break
            empty_samples = [samples[i] for i in empty_ids]
            print(f">>> repredicting empty {len(empty_ids)} docs in repredict process {process} / 3")
            empty_outputs = model.generate_text(empty_samples)
            # Only overwrite a slot when the retry actually produced text.
            for index, output in zip(empty_ids, empty_outputs):
                if output is not None and output != "":
                    outputs[index] = output

    ## persist answers as <save_dir>/<model>.jsonl
    os.makedirs(args.save_dir, exist_ok=True)
    save_path = os.path.join(args.save_dir, f"{args.model}.jsonl")
    with jsonlines.open(save_path, 'w') as f:
        for doc, output in zip(docs, outputs):
            doc["model_id"] = args.model
            doc["answer_id"] = str(doc["question_id"]) + "_" + args.model
            # `output` may still be None if every repredict round failed;
            # write it through unchanged (same as the original branch pair).
            doc["answer"] = output
            f.write(doc)