-
Notifications
You must be signed in to change notification settings - Fork 13
/
main.py
42 lines (31 loc) · 1.33 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os
import sys
import ray
from data import load_dataset
from trainer import get_trainer
from utils.utils import print_eval_acc, print_train_acc, load_with_default_yaml, save_dict_as_one_line_csv
def main(working_dir, seed, train_epochs, eval_every, use_ray, ray_params, data_params, trainer_params):
    """Train a model, evaluating along the way, and save the final metrics to CSV.

    Args:
        working_dir: Directory in which ``metrics.csv`` is written.
        seed: Seed forwarded to ``get_trainer``.
        train_epochs: Number of calls to ``trainer.train_epoch()``; may be 0.
        eval_every: Evaluate after every ``eval_every``-th epoch. ``None`` or 0
            disables periodic evaluation; the model is still evaluated before
            training and once more at the end if it changed since the last
            evaluation.
        use_ray: If true, ``ray.init(**ray_params)`` is called first and
            ``ray.shutdown()`` is guaranteed to run afterwards, even on error.
        ray_params: Keyword arguments for ``ray.init``.
        data_params: Keyword arguments for ``load_dataset``.
        trainer_params: Extra keyword arguments for ``get_trainer``.

    Returns:
        dict: The last epoch's training metrics merged with the final
        evaluation metrics (evaluation values win on key collisions). The
        training part is empty when ``train_epochs`` is 0.
    """
    if use_ray:
        ray.init(**ray_params)
    try:
        (train_iterator, test_iterator), metadata = load_dataset(**data_params)
        trainer = get_trainer(seed=seed, train_iterator=train_iterator, test_iterator=test_iterator,
                              metadata=metadata, **trainer_params)

        # Baseline evaluation of the untrained model.
        eval_metrics = trainer.evaluate()
        print_eval_acc(eval_metrics)
        eval_is_current = True

        # Stays empty when train_epochs == 0 so the merge below never sees an unbound name.
        train_metrics = {}
        for epoch in range(train_epochs):
            train_metrics = trainer.train_epoch()
            print_train_acc(train_metrics, epoch=epoch)
            eval_is_current = False
            # Truthiness check treats both None and 0 as "no periodic eval" (0 would divide by zero).
            if eval_every and (epoch + 1) % eval_every == 0:
                eval_metrics = trainer.evaluate()
                print_eval_acc(eval_metrics)
                eval_is_current = True

        # Final evaluation, skipped only when the current model was just evaluated.
        if not eval_is_current:
            eval_metrics = trainer.evaluate()
            print_eval_acc(eval_metrics)
    finally:
        # Release the Ray runtime even if loading, training or evaluation raised.
        if use_ray:
            ray.shutdown()

    # Dict display instead of dict(**a, **b): the latter raises TypeError on duplicate keys.
    metrics = {**train_metrics, **eval_metrics}
    save_dict_as_one_line_csv(metrics, filename=os.path.join(working_dir, "metrics.csv"))
    return metrics
if __name__ == "__main__":
    # Expect the path to the YAML parameter file as the first CLI argument;
    # any further arguments are ignored.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} PARAMS_YAML")
    param_path = sys.argv[1]
    # Load the run configuration and expand it into main()'s keyword arguments.
    # Presumably load_with_default_yaml fills unspecified keys from a default
    # YAML file (inferred from its name) — TODO confirm in utils.utils.
    param_dict = load_with_default_yaml(path=param_path)
    main(**param_dict)