run.py
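# Entry point for running an evaluation. Configuration is read from three
# environment variables. The invocation below is a hypothetical sketch; the
# valid values are defined by utils.load_data and utils.model_registry.MODELS:
#
#   ALGO=vanilla DATA_NAME=gsm8k MODEL=llama-7b python run.py
#
#   ALGO       decoding algorithm; 'vanilla' selects plain generation
#   DATA_NAME  dataset to evaluate; determines the prompt template
#   MODEL      a key into utils.model_registry.MODELS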
import os
import random

from utils.load_data import load_data
from utils.model_registry import MODELS
from utils.predict import (
    get_data_type,
    get_mwp_prompt,
    get_mgsm_prompt_base,
    get_aqua_prompt,
    predict_gsm8k,
    predict_mgsm,
    predict_aqua,
)
from utils.setup_models import setup_model_inner
from utils.save_output import save_json

import torch

# Disable reduced-precision accumulation in fp16/bf16 matmuls so that
# generation results are more deterministic across runs.
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False


def main():
    algorithm = os.environ['ALGO']
    data_name = os.environ['DATA_NAME']
    model_name = os.environ['MODEL']

    random.seed(2023)

    output_name = f'{data_name}_{model_name}_{algorithm}'
    print(f'{output_name = }')

    testset = load_data(data_name)
    data_type = get_data_type(data_name)

    # Pick the prompt template for the dataset family.
    if data_type == 'mwp':
        template = get_mwp_prompt()
    elif data_type == 'aqua':
        template = get_aqua_prompt()
    elif data_type == 'mgsm':
        # The dataset name ends with a two-letter language code.
        language = data_name[-2:]
        template = get_mgsm_prompt_base(language)
    else:
        raise ValueError(f'unsupported data type: {data_type}')

    # Everything before the last '\n\n'-separated block of the template
    # serves as the prompt prefix passed to model setup.
    prefix = '\n\n'.join(template.split('\n\n')[:-1])

    tokenizer, model, generate_callback = setup_model_inner(
        algorithm, model_name, prefix=prefix
    )
    model.generation_config.pad_token_id = tokenizer.pad_token_id

    if data_type == 'mwp':
        outputs = predict_gsm8k(
            model, MODELS[model_name], tokenizer, testset,
            template, algorithm == 'vanilla', generate_callback,
        )
    elif data_type == 'aqua':
        outputs = predict_aqua(
            model, tokenizer, testset, template,
            algorithm == 'vanilla', generate_callback,
        )
    elif data_type == 'mgsm':
        outputs = predict_mgsm(
            model, tokenizer, testset, template,
            algorithm == 'vanilla', generate_callback,
        )

    save_json(outputs, output_name)


if __name__ == '__main__':
    main()