-
Notifications
You must be signed in to change notification settings - Fork 32
/
train.py
101 lines (76 loc) · 3.85 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torch
import torchvision
import torch.nn as nn
import numpy as np
import json
import utils
import validate
import argparse
import models.densenet
import models.resnet
import models.inception
import time
import dataloaders.datasetaug
import dataloaders.datasetnormal
from tqdm import tqdm
from tensorboardX import SummaryWriter
# Command-line interface: the single required argument is the path to the
# JSON experiment configuration consumed by utils.Params below.
parser = argparse.ArgumentParser(
    description="Train an audio-classification model using a JSON config file.")
# required=True: previously a missing flag produced config_path=None and a
# confusing failure inside utils.Params; now argparse reports the error up front.
parser.add_argument("--config_path", type=str, required=True,
                    help="Path to the JSON experiment configuration file.")
def train(model, device, data_loader, optimizer, loss_fn):
    """Run one training epoch and return the epoch's running-average loss.

    Args:
        model: network to train; switched to train mode here.
        device: torch device every batch is moved to.
        data_loader: iterable yielding (inputs, labels) batch pairs.
        optimizer: optimizer stepped once per batch.
        loss_fn: criterion applied to (model outputs, labels).

    Returns:
        The average per-batch loss over the whole epoch (float).
    """
    model.train()
    avg = utils.RunningAverage()
    with tqdm(total=len(data_loader)) as bar:
        for batch in data_loader:
            spectrograms = batch[0].to(device)
            # Labels arrive with a singleton dim; CrossEntropyLoss wants 1-D targets.
            labels = batch[1].squeeze(1).to(device)
            optimizer.zero_grad()
            predictions = model(spectrograms)
            batch_loss = loss_fn(predictions, labels)
            batch_loss.backward()
            optimizer.step()
            avg.update(batch_loss.item())
            bar.set_postfix(loss='{:05.3f}'.format(avg()))
            bar.update()
    return avg()
def train_and_evaluate(model, device, train_loader, val_loader, optimizer, loss_fn, writer, params, split, scheduler=None):
    """Train for ``params.epochs`` epochs, checkpointing and logging each epoch.

    Args:
        model: network to train.
        device: torch device used for training and evaluation.
        train_loader: dataloader for the training split.
        val_loader: dataloader for the validation split.
        optimizer: optimizer passed through to ``train``.
        loss_fn: criterion passed through to ``train``.
        writer: tensorboard SummaryWriter; closed when training finishes.
        params: config object providing ``epochs``, ``checkpoint_dir``,
            ``dataset_name``.
        split: fold index, used to tag checkpoints and scalar names.
        scheduler: optional LR scheduler, stepped once per epoch.

    Returns:
        Best validation accuracy seen across all epochs (previously the
        function returned None; callers ignoring the return are unaffected).
    """
    best_acc = 0.0
    for epoch in range(params.epochs):
        avg_loss = train(model, device, train_loader, optimizer, loss_fn)
        acc = validate.evaluate(model, device, val_loader)
        # epoch + 1: display 1-based progress (was "Epoch 0/N" .. "Epoch N-1/N").
        print("Epoch {}/{} Loss:{} Valid Acc:{}".format(epoch + 1, params.epochs, avg_loss, acc))
        is_best = acc > best_acc
        if is_best:
            best_acc = acc
        # Advance the LR schedule once per epoch, independent of improvement.
        if scheduler:
            scheduler.step()
        # Checkpoint every epoch; is_best flags the copy to keep as the best model.
        utils.save_checkpoint({"epoch": epoch + 1,
                               "model": model.state_dict(),
                               "optimizer": optimizer.state_dict()},
                              is_best, split, "{}".format(params.checkpoint_dir))
        writer.add_scalar("data{}/trainingLoss{}".format(params.dataset_name, split), avg_loss, epoch)
        # NOTE(review): the tag says "valLoss" but the value logged is validation
        # accuracy; tag kept unchanged for continuity with existing tensorboard runs.
        writer.add_scalar("data{}/valLoss{}".format(params.dataset_name, split), acc, epoch)
    writer.close()
    return best_acc
if __name__ == "__main__":
    args = parser.parse_args()
    params = utils.Params(args.config_path)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Cross-validation: train one model per fold (folds are 1-indexed in the
    # pickled dataset filenames).
    for fold in range(1, params.num_folds + 1):
        if params.dataaug:
            train_loader = dataloaders.datasetaug.fetch_dataloader(
                "{}training128mel{}.pkl".format(params.data_dir, fold),
                params.dataset_name, params.batch_size, params.num_workers, 'train')
            val_loader = dataloaders.datasetaug.fetch_dataloader(
                "{}validation128mel{}.pkl".format(params.data_dir, fold),
                params.dataset_name, params.batch_size, params.num_workers, 'validation')
        else:
            train_loader = dataloaders.datasetnormal.fetch_dataloader(
                "{}training128mel{}.pkl".format(params.data_dir, fold),
                params.dataset_name, params.batch_size, params.num_workers)
            val_loader = dataloaders.datasetnormal.fetch_dataloader(
                "{}validation128mel{}.pkl".format(params.data_dir, fold),
                params.dataset_name, params.batch_size, params.num_workers)

        writer = SummaryWriter(comment=params.dataset_name)

        if params.model == "densenet":
            model = models.densenet.DenseNet(params.dataset_name, params.pretrained).to(device)
        elif params.model == "resnet":
            model = models.resnet.ResNet(params.dataset_name, params.pretrained).to(device)
        elif params.model == "inception":
            model = models.inception.Inception(params.dataset_name, params.pretrained).to(device)
        else:
            # Fail fast: previously an unrecognised model name left `model`
            # undefined and raised a NameError further down.
            raise ValueError(
                "Unknown model '{}'; expected 'densenet', 'resnet' or 'inception'".format(params.model))

        loss_fn = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=params.lr,
                                     weight_decay=params.weight_decay)
        # Optional step decay: LR * 0.1 every 30 epochs.
        scheduler = (torch.optim.lr_scheduler.StepLR(optimizer, 30, gamma=0.1)
                     if params.scheduler else None)

        train_and_evaluate(model, device, train_loader, val_loader, optimizer,
                           loss_fn, writer, params, fold, scheduler)