main.py
import os
from pathlib import Path
from typing import TextIO

from fire import Fire
from tqdm import tqdm

from data_loading import ArgumentSample, ArgumentData
from modeling import select_model, EvalModel
from prompting import select_prompter, Prompter
from scoring import select_scorer


def inference(
    model: EvalModel,
    data_train: ArgumentData,
    data_test: ArgumentData,
    prompter: Prompter,
    file: TextIO,
):
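    """Run few-shot inference over the test split, writing one JSON line per sample.

    If a prompt exceeds the model's context window, in-context examples are
    dropped one at a time, and the input is truncated as a last resort,
    before querying the model.
    """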
    progress = tqdm(data_test.samples)
    sample: ArgumentSample
    targets = []
    predictions = []
    for sample in progress:
        k = len(data_train.samples)
        prompt = prompter.run(data_train, sample)
        # Drop in-context examples one at a time until the prompt fits.
        # Note: the shrunken data_train.samples persists for later samples.
        while not model.check_valid_length(prompt) and k > 0:
            k -= 1
            data_train.samples = data_train.samples[:k]
            prompt = prompter.run(data_train, sample)
        if not model.check_valid_length(prompt):
            prompt = model.truncate_input(prompt)
        # Predict and record one JSON object per line.
        sample.prompt = prompt
        sample.raw = model.run(prompt)
        sample.pred = prompter.get_answer(sample.raw)
        print(sample.model_dump_json(), file=file)
        targets.append(sample.tgt)
        predictions.append(sample.pred)
    return predictions, targets
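

# ArgumentSample is defined in data_loading; judging from the usage above it
# is a pydantic v2 model (it exposes model_dump_json) with at least the
# tgt/prompt/raw/pred fields. A minimal sketch of the assumed shape, for
# reference only (field names other than those four are hypothetical):
#
#   class ArgumentSample(BaseModel):
#       src: str = ""      # hypothetical input field
#       tgt: str = ""      # gold target
#       prompt: str = ""   # filled in during inference
#       raw: str = ""      # raw model output
#       pred: str = ""     # parsed answer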


def main(
    task: str = "conclugen",
    data_name: str = "base",
    num_train: int = 5,
    seed: int = 0,
    **kwargs,
):
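    """Evaluate a model on one argument-mining task and print its scores.

    Extra keyword arguments are forwarded to select_model; the accepted
    keys depend on how modeling.py defines the available models.
    """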
    # load model
    model = select_model(**kwargs)
    print(locals())  # log the run configuration
    # select prompter
    prompter = select_prompter(task, data_name)
    # load data
    data_train, data_test = ArgumentData.load(task, data_name, num_train, seed)
    # set output path (the file holds one JSON object per line)
    output_folder = Path(f"output/{task}/{data_name}/{num_train}_shot/seed_{seed}")
    output_folder.mkdir(exist_ok=True, parents=True)
    model_name = Path(model.path_model).stem
    output_path = output_folder / f"{model_name}.json"
    # infer
    with open(output_path, "w") as file:
        predictions, targets = inference(model, data_train, data_test, prompter, file)
    # score
    scorer = select_scorer(task)
    scores = scorer.run(predictions, targets)
    print(scores)


if __name__ == "__main__":
    Fire(main)
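
# Example invocation via the Fire CLI. Flags beyond --task/--data_name/
# --num_train/--seed are forwarded to select_model, so the exact model flags
# depend on modeling.py; the model_name value below is a hypothetical example:
#
#   python main.py --task conclugen --data_name base --num_train 5 --seed 0 \
#       --model_name flan_t5
#
# Scores are printed to stdout, and per-sample records are written to
# output/<task>/<data_name>/<num_train>_shot/seed_<seed>/<model>.json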