-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
121 lines (102 loc) · 3.44 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
from utils import get_mean_iou, get_pixel_accuracy
from tqdm import tqdm
def train(
    epochs,
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    scheduler,
    device,
    patch=False,
    patience=6,
):
    """Train a segmentation model with per-epoch validation and early stopping.

    Args:
        epochs: maximum number of epochs to run.
        model: the torch model to train (moved to ``device`` here).
        train_loader: DataLoader yielding ``(images, masks)`` batches.
        val_loader: DataLoader used by ``validate`` each epoch.
        criterion: loss function ``criterion(outputs, masks)``.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per epoch.
        device: target device for model and batches.
        patch: if True, batches carry an extra patch dimension that is
            folded into the batch dimension before the forward pass.
        patience: number of consecutive epochs without val-loss improvement
            before early stopping (previously hard-coded to 6).

    Returns:
        dict of metric histories: ``train_loss``, ``val_loss``, ``train_iou``,
        ``val_iou``, ``train_acc``, ``val_acc`` — one entry per completed epoch.
    """
    model.to(device)
    history = {
        "train_loss": [],
        "val_loss": [],
        "train_iou": [],
        "val_iou": [],
        "train_acc": [],
        "val_acc": [],
    }
    # Bug fix: the original used np.inf without importing numpy (NameError).
    best_val_loss = float("inf")
    patience_counter = 0
    for epoch in range(epochs):
        model.train()
        train_loss, train_iou_score, train_accuracy = 0.0, 0.0, 0.0
        for images, masks in tqdm(train_loader):
            if patch:
                # Fold the patch dimension into the batch dimension:
                # (B, P, ...) -> (B*P, ...)
                images = images.view(-1, *images.size()[2:])
                masks = masks.view(-1, *masks.size()[2:])
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_iou_score += get_mean_iou(outputs, masks)
            train_accuracy += get_pixel_accuracy(outputs, masks)
        n_batches = len(train_loader)
        history["train_loss"].append(train_loss / n_batches)
        history["train_iou"].append(train_iou_score / n_batches)
        history["train_acc"].append(train_accuracy / n_batches)
        val_loss, val_iou_score, val_accuracy = validate(
            model, val_loader, criterion, device, patch
        )
        history["val_loss"].append(val_loss)
        history["val_iou"].append(val_iou_score)
        history["val_acc"].append(val_accuracy)
        scheduler.step()
        # Print stats before the early-stopping check so the final epoch's
        # metrics are reported (the original broke out before printing).
        print_epoch_stats(
            epoch,
            epochs,
            history["train_loss"][-1],
            history["val_loss"][-1],
            history["train_iou"][-1],
            history["val_iou"][-1],
            history["train_acc"][-1],
            history["val_acc"][-1],
        )
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # Checkpoint every improvement, tagged with the epoch number.
            torch.save(model.state_dict(), f"best_model_{epoch+1}.pth")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs.")
                break
    return history
def print_epoch_stats(
    epoch, epochs, train_loss, val_loss, train_iou, val_iou, train_acc, val_acc
):
    """Print a one-line summary of train/val loss, IoU and pixel accuracy.

    ``epoch`` is zero-based; it is shown as ``epoch + 1`` of ``epochs``.
    """
    summary = (
        f"Epoch {epoch+1}/{epochs}: "
        f"Train Loss: {train_loss:.3f}, Val Loss: {val_loss:.3f}, "
        f"Train IoU: {train_iou:.3f}, Val IoU: {val_iou:.3f}, "
        f"Train Acc: {train_acc:.3f}, Val Acc: {val_acc:.3f}"
    )
    print(summary)
def validate(model, loader, criterion, device, patch):
    """Evaluate the model over ``loader`` without gradient tracking.

    Args:
        model: model to evaluate (switched to eval mode).
        loader: DataLoader yielding ``(images, masks)`` batches.
        criterion: loss function ``criterion(outputs, masks)``.
        device: device to move batches onto.
        patch: if True, fold the extra patch dimension into the batch
            dimension before the forward pass.

    Returns:
        Tuple of ``(mean_loss, mean_iou, mean_pixel_accuracy)`` averaged
        over the number of batches in ``loader``.
    """
    model.eval()
    loss_sum = 0.0
    iou_sum = 0.0
    acc_sum = 0.0
    with torch.no_grad():
        for images, masks in tqdm(loader):
            if patch:
                # (B, P, ...) -> (B*P, ...)
                images = images.view(-1, *images.size()[2:])
                masks = masks.view(-1, *masks.size()[2:])
            images = images.to(device)
            masks = masks.to(device)
            predictions = model(images)
            loss_sum += criterion(predictions, masks).item()
            iou_sum += get_mean_iou(predictions, masks)
            acc_sum += get_pixel_accuracy(predictions, masks)
    n_batches = len(loader)
    return loss_sum / n_batches, iou_sum / n_batches, acc_sum / n_batches