-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
103 lines (76 loc) · 3.26 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import numpy as np
import torch
import torch.nn as nn
import yaml
import os
import argparse
from torch.utils.tensorboard import SummaryWriter
# Custom imports
import sys
sys.path.append('models')
import model_factory
import loss_factory
import optimizer_factory
import train_loops.prior_train as prior_train
from train_loops.train_control import train_control
import test
import datasets
import utils
def load_config(config_path):
    """Load a YAML configuration file and return it as a dict.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed configuration dictionary.

    Raises:
        ValueError: If the file cannot be parsed as YAML. The original
            ``yaml.YAMLError`` is chained (``from exc``) so the parse
            location is preserved instead of being lost to a print().
    """
    with open(config_path, 'r') as stream:
        try:
            config = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            # Chain the parser error rather than printing it and raising a
            # bare ValueError — keeps the full traceback for debugging.
            raise ValueError('Error loading config file.') from exc
    return config
def model2device(full_package, config):
    """Place the model on the best available device.

    Picks CUDA when available (CPU otherwise), wraps the model in
    ``nn.DataParallel`` when more than one GPU is visible, and scales the
    train/test batch sizes by the GPU count so the per-GPU batch stays
    constant.

    Args:
        full_package: Shared state dict; ``full_package['model']`` is
            replaced by the device-placed (possibly DataParallel) model and
            ``full_package['device']`` is set.
        config: Run configuration; ``train``/``test`` batch sizes may be
            scaled in place.

    Returns:
        The (mutated) ``full_package`` and ``config``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    n_gpus = torch.cuda.device_count()
    if n_gpus > 1:
        print("Let's use", n_gpus, "GPUs!")
        full_package['model'] = nn.DataParallel(full_package['model'])
        # Keep the per-GPU batch size constant across replicas.
        for split in ('train', 'test'):
            config[split]['batch_size'] = config[split]['batch_size'] * n_gpus
    full_package['model'] = full_package['model'].to(device)
    full_package['device'] = device
    return full_package, config
def train_eval_loop(full_package, config):
    """Alternate training and evaluation until the configured budget is spent.

    NOTE(review): neither ``epoch`` nor ``it_global`` is incremented in this
    function — presumably ``train_control`` mutates them inside
    ``full_package``; verify, otherwise this loop never terminates.
    """
    # Training loop
    print('==> Start training..')
    # Counters shared (via full_package) with the train/test control loops.
    full_package['it_global'] = 0
    full_package['epoch'] = 0
    # Keeps running while EITHER budget is unmet: epoch count OR global
    # iteration count — i.e. training stops only when both are exhausted.
    while full_package['epoch'] < config['train']['epochs'] or full_package['it_global'] < config['train']['global_iteration']:
        # train loop
        train_loss = train_control(full_package, config['train']['loop_type'])
        # test loop
        acc = test.test_control(full_package, config['test']['loop_type'])
        # One LR-scheduler step per outer iteration (per epoch).
        full_package['scheduler'].step()
        print(f"Epoch {full_package['epoch']+ 1}, test Accuracy: {acc:.4f}")
        # Debug mode: cap the run after epoch index 5 regardless of budget.
        if config['train']['debug'] and full_package['epoch'] >= 5: break
    print('Finished Training')
def main(config):
    """Assemble the training pipeline from ``config`` and run it.

    Builds data loaders, model, loss, optimizer and scheduler into a shared
    ``full_package`` dict (each factory mutates it in place), runs the
    train/eval loop, then any post-training hooks.

    Args:
        config: Parsed configuration dict (see ``load_config``).
    """
    full_package = {'config': config}
    # Define the dataset and dataloader.
    # TODO different types of loaders: noise loader
    datasets.get_loaders(config,full_package)
    # Optional pre-training phase before the main loop (project-specific;
    # semantics defined in train_loops.prior_train).
    prior_train.prior_train_control(full_package)
    # Define the model, loss function, and optimizer.
    # TODO: Deal with loading pretrained models
    model_factory.get_model(config['model']['name'], pretrained=True, num_classes=config['data']['num_classes'],full_package=full_package)
    # Move the model to GPU(s); may wrap in DataParallel and scale batch sizes.
    full_package, config = model2device(full_package, config)
    # TODO: optimizer lr boundaries
    # TODO: [Noisy] more loss functions
    loss_factory.get_loss_function(config['train']['loss_type'],full_package)
    optimizer_factory.get_optimizer(full_package)
    # TODO: deal with scheduler boundaries
    optimizer_factory.get_scheduler(full_package)
    train_eval_loop(full_package, config)
    # Post-training processing (e.g. saving / finalizing the prior).
    prior_train.post_train(full_package)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config.yaml', help='Path to the config file.')
    args = parser.parse_args()
    config = utils.start_program(args)
    # BUG FIX: the original called np.random.RandomState(seed=...), which
    # constructs a throwaway generator and discards it — the global NumPy
    # RNG was never seeded. np.random.seed() actually seeds the legacy
    # global state used by np.random.* calls elsewhere in the project.
    np.random.seed(config['general']['np_seed'])
    # Seed PyTorch's global RNG for reproducible model init / shuffling.
    torch.manual_seed(config['general']['torch_seed'])
    main(config)