-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
113 lines (94 loc) · 3.69 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import time
import torch
from torcheval.metrics.functional import peak_signal_noise_ratio
from generate import generate_images
def train_epoch(model, optimizer, criterion, train_dataloader, device, epoch=0,
                log_interval=50):
    """Run one training epoch and return its mean PSNR and mean loss.

    Args:
        model: module to train; it is switched to train mode here.
        optimizer: optimizer stepping ``model``'s parameters once per batch.
        criterion: loss function called as ``criterion(predictions, labels)``.
        train_dataloader: iterable of ``(inputs, labels)`` batches (presumably
            image tensors -- shapes are not checked here).
        device: device the batches are moved to before the forward pass.
        epoch: epoch number, used only in progress messages.
        log_interval: print the windowed PSNR every ``log_interval`` batches.

    Returns:
        ``(epoch_psnr, epoch_loss)``. ``epoch_psnr`` is the per-batch
        ``peak_signal_noise_ratio`` result averaged over every batch of the
        epoch (a tensor). ``epoch_loss`` is the mean of ``loss.item()`` over
        every batch (a Python float).

    Raises:
        ValueError: if ``train_dataloader`` yields no batches.
    """
    model.train()
    # Whole-epoch accumulators: these produce the returned averages.
    epoch_psnr_sum = 0
    losses = []
    # Accumulators since the last progress print. They are kept separate so
    # resetting them after each log does not discard the epoch totals.
    window_psnr_sum, window_count = 0, 0

    for idx, (inputs, labels) in enumerate(train_dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        predictions = model(inputs)
        # compute loss
        loss = criterion(predictions, labels)
        losses.append(loss.item())
        # backward
        loss.backward()
        optimizer.step()

        # PSNR is only a monitoring metric. Detach so the accumulated sums do
        # not hold a reference to every batch's autograd graph.
        batch_psnr = peak_signal_noise_ratio(predictions.detach(), labels)
        epoch_psnr_sum += batch_psnr
        window_psnr_sum += batch_psnr
        window_count += 1

        if idx % log_interval == 0 and idx > 0:
            print(
                "| epoch {:3d} | {:5d}/{:5d} batches "
                "| psnr {:8.3f}".format(
                    epoch, idx, len(train_dataloader), window_psnr_sum / window_count
                )
            )
            window_psnr_sum, window_count = 0, 0

    if not losses:
        raise ValueError("train_dataloader yielded no batches")
    num_batches = len(losses)
    epoch_psnr = epoch_psnr_sum / num_batches
    epoch_loss = sum(losses) / num_batches
    return epoch_psnr, epoch_loss
def evaluate_epoch(model, criterion, valid_dataloader, device):
    """Evaluate ``model`` on ``valid_dataloader`` without gradient tracking.

    Args:
        model: module to evaluate; it is switched to eval mode here.
        criterion: loss function called as ``criterion(predictions, labels)``.
        valid_dataloader: iterable of ``(inputs, labels)`` batches (presumably
            image tensors -- shapes are not checked here).
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(epoch_psnr, epoch_loss)``. ``epoch_psnr`` is the per-batch
        ``peak_signal_noise_ratio`` result averaged over batches (a tensor).
        ``epoch_loss`` is the mean of ``loss.item()`` over batches (a Python
        float).

    Raises:
        ValueError: if ``valid_dataloader`` yields no batches.
    """
    model.eval()
    psnr_sum = 0
    losses = []
    with torch.no_grad():
        for inputs, labels in valid_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            predictions = model(inputs)
            losses.append(criterion(predictions, labels).item())
            psnr_sum += peak_signal_noise_ratio(predictions, labels)

    if not losses:
        raise ValueError("valid_dataloader yielded no batches")
    num_batches = len(losses)
    epoch_psnr = psnr_sum / num_batches
    epoch_loss = sum(losses) / num_batches
    return epoch_psnr, epoch_loss
def train_model(model, model_name, save_model, optimizer, criterion, train_dataloader, valid_dataloader, num_epochs, device):
    """Train for ``num_epochs`` epochs, keeping the weights with the best validation PSNR.

    After each epoch the model is evaluated on ``valid_dataloader``. Whenever
    the validation PSNR improves, three things happen:
      * the state dict is saved to ``{save_model}/{model_name}.pt``;
      * ``generate_images`` is called on the first validation batch;
      * the new PSNR becomes the best so far.
    At the end, the best checkpoint saved during this run is loaded back into
    ``model``, which is then put in eval mode.

    Args:
        model: module to train.
        model_name: file stem of the checkpoint.
        save_model: directory the checkpoint is written to.
        optimizer: optimizer passed to ``train_epoch``.
        criterion: loss function passed to both epoch helpers.
        train_dataloader: training batches.
        valid_dataloader: validation batches.
        num_epochs: number of epochs to run.
        device: device batches are moved to; also the ``map_location`` used
            when reloading the checkpoint.

    Returns:
        ``(model, metrics)``. ``metrics`` maps ``'train_psnr'``,
        ``'train_loss'``, ``'valid_psnr'``, ``'valid_loss'`` and ``'time'`` to
        per-epoch lists. PSNR entries are CPU tensors; the rest are floats.
    """
    checkpoint_path = f"{save_model}/{model_name}.pt"
    train_psnrs, train_losses = [], []
    eval_psnrs, eval_losses = [], []
    times = []
    # Start at -inf so the first real validation PSNR always counts as an
    # improvement. A NaN PSNR never does.
    best_psnr_eval = float("-inf")
    # Tracks whether this run wrote a checkpoint. Without it, a stale file from
    # an earlier run could be loaded, or torch.load could fail on a missing file.
    saved_checkpoint = False

    for epoch in range(1, num_epochs + 1):
        epoch_start_time = time.time()

        # Training
        train_psnr, train_loss = train_epoch(model, optimizer, criterion, train_dataloader, device, epoch)
        train_psnrs.append(train_psnr.cpu())
        train_losses.append(train_loss)

        # Evaluation
        eval_psnr, eval_loss = evaluate_epoch(model, criterion, valid_dataloader, device)
        eval_psnrs.append(eval_psnr.cpu())
        eval_losses.append(eval_loss)

        # Save best model. Compare as a Python float so the running best is
        # not kept as a device tensor.
        eval_psnr_value = float(eval_psnr)
        if best_psnr_eval < eval_psnr_value:
            torch.save(model.state_dict(), checkpoint_path)
            saved_checkpoint = True
            inputs_t, targets_t = next(iter(valid_dataloader))
            generate_images(model, inputs_t, targets_t)
            best_psnr_eval = eval_psnr_value

        epoch_time = time.time() - epoch_start_time
        times.append(epoch_time)

        # Print loss, psnr end epoch
        print("-" * 59)
        print(
            "| End of epoch {:3d} | Time: {:5.2f}s | Train psnr {:8.3f} | Train Loss {:8.3f} "
            "| Valid psnr {:8.3f} | Valid Loss {:8.3f} ".format(
                epoch, epoch_time, train_psnr, train_loss, eval_psnr, eval_loss
            )
        )
        print("-" * 59)

    # Load best model (only one written during this run).
    if saved_checkpoint:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        print(f"No checkpoint was saved for {model_name}; keeping the final weights.")
    model.eval()

    metrics = {
        'train_psnr': train_psnrs,
        'train_loss': train_losses,
        'valid_psnr': eval_psnrs,
        'valid_loss': eval_losses,
        'time': times
    }
    return model, metrics