forked from mlabonne/llm-autoeval
main.py
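"""Summarize llm-autoeval benchmark results and upload them as a GitHub Gist.

The script reads one JSON result file per task from a results directory,
renders tables with llm_autoeval.table, and publishes the summary via
llm_autoeval.upload. MODEL_ID, BENCHMARK, and GITHUB_API_TOKEN are read
from the environment.
"""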
import json
import logging
import os
import argparse
import time

from llm_autoeval.table import make_table, make_final_table
from llm_autoeval.upload import upload_to_github_gist

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

MODEL_ID = os.getenv("MODEL_ID")
BENCHMARK = os.getenv("BENCHMARK")
GITHUB_API_TOKEN = os.getenv("GITHUB_API_TOKEN")
def main(directory: str, elapsed_time: float) -> None:
    # Variables
    tables = []
    averages = []

    # Tasks
    if BENCHMARK == "openllm":
        tasks = ["ARC", "HellaSwag", "MMLU", "TruthfulQA", "Winogrande", "GSM8K"]
    elif BENCHMARK == "nous":
        tasks = ["AGIEval", "GPT4All", "TruthfulQA", "Bigbench"]
    elif BENCHMARK == "medqa":
        tasks = ["medqa"]
    elif BENCHMARK == "medmcqa":
        tasks = ["medmcqa"]
    elif BENCHMARK == "pubmedqa":
        tasks = ["pubmedqa"]
    elif BENCHMARK == "legalbench":
        tasks = [
            "legalbench_issue_tasks",
            "legalbench_rule_tasks",
            "legalbench_conclusion_tasks",
            "legalbench_interpretation_tasks",
            "legalbench_rhetoric_tasks",
        ]
    else:
        # Fall back to treating BENCHMARK as a single task name
        tasks = [BENCHMARK]
        # raise NotImplementedError(f"BENCHMARK should be 'openllm' or 'nous' or 'legalbench' (current value = {BENCHMARK})")

    # Load results
    for task in tasks:
        file_path = f"{directory}/{task.lower()}.json"
        if os.path.exists(file_path):
            with open(file_path, "r") as f:
                json_data = f.read()
            data = json.loads(json_data, strict=False)
            table, average = make_table(data, task)
        else:
            table = ""
            average = "Error: File does not exist"
        tables.append(table)
        averages.append(average)

    # Generate tables
    summary = ""
    for index, task in enumerate(tasks):
        summary += f"### {task}\n{tables[index]}\nAverage: {averages[index]}%\n\n"
    result_dict = {k: v for k, v in zip(tasks, averages)}

    # Calculate the final average, excluding strings
    if all(isinstance(e, float) for e in averages):
        final_average = round(sum(averages) / len(averages), 2)
        summary += f"Average score: {final_average}%"
        result_dict.update({"Average": final_average})
    else:
        summary += "Average score: Not available due to errors"

    # Add elapsed time
    convert = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
    summary += f"\n\nElapsed time: {convert}"

    # Generate final table
    final_table = make_final_table(result_dict, MODEL_ID)
    summary = final_table + "\n" + summary

    # Upload to GitHub Gist
    upload_to_github_gist(
        summary, f"{MODEL_ID.split('/')[-1]}-{BENCHMARK.capitalize()}.md", GITHUB_API_TOKEN
    )
if __name__ == "__main__":
    # Create the parser
    parser = argparse.ArgumentParser(description="Summarize results and upload them.")
    parser.add_argument("directory", type=str, help="The path to the directory with the JSON results")
    parser.add_argument("elapsed_time", type=float, help="Elapsed time since the start of the evaluation")

    # Parse the arguments
    args = parser.parse_args()

    # Check if the directory exists
    if not os.path.isdir(args.directory):
        raise ValueError(f"The directory {args.directory} does not exist.")

    # Call the main function with the parsed directory and elapsed time
    main(args.directory, args.elapsed_time)
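
# Example invocation (a sketch with hypothetical values; MODEL_ID, BENCHMARK, and
# GITHUB_API_TOKEN must be set in the environment, and ./results must contain one
# <task>.json file per task):
#
#   MODEL_ID=org/model BENCHMARK=openllm GITHUB_API_TOKEN=ghp_xxx \
#       python main.py ./results 1234.5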