-
Notifications
You must be signed in to change notification settings - Fork 0
/
trainer.py
151 lines (123 loc) · 4.9 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""Trainer module"""
import numpy as np
from typing import Dict, Any
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
def _build_input(left, right, device):
    """Concatenate ``left`` with the channel-averaged ``|left - right|`` map and move it to ``device``."""
    diff = torch.abs(left - right).mean(dim=1, keepdim=True)
    return torch.cat((left, diff), dim=1).to(device)


def _record_loss(history, epoch, value):
    """Store ``value`` at ``history[epoch]``, appending when ``history`` is too short."""
    if epoch < len(history):
        history[epoch] = value
    else:
        history.append(value)


def _run_epoch(model, criterion, loader, device, desc, optimizer=None):
    """Run one pass over ``loader`` and return the mean per-batch loss.

    Weights are updated only when ``optimizer`` is given; the caller is
    responsible for ``model.train()`` / ``model.eval()`` and ``no_grad``.
    """
    total = 0.0
    for left, right, ldepth, _ in tqdm(loader, desc=desc):
        x = _build_input(left, right, device)
        ldepth = ldepth.to(device)
        loss = criterion(model(x), ldepth)
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        total += loss.item()
    return total / len(loader)


def _show_sample(model, validloader, device):
    """Plot the first sample of one validation batch next to its prediction."""
    left, right, ldepth, _ = next(iter(validloader))
    # Inference only: no autograd graph is needed for a preview.
    with torch.no_grad():
        pred_depth = model(_build_input(left, right, device))
    fig, ax = plt.subplots(1, 4, figsize=(15, 15))
    ax[0].imshow(left[0].permute(1, 2, 0).detach().cpu().numpy())
    ax[0].set_title("Left Image")
    ax[1].imshow(right[0].permute(1, 2, 0).detach().cpu().numpy())
    ax[1].set_title("Right Image")
    ax[2].imshow(pred_depth[0].detach().cpu().squeeze().numpy())
    ax[2].set_title("Depth Image")
    ax[3].imshow(ldepth[0].detach().cpu().squeeze().numpy())
    ax[3].set_title("Ground Truth Depth Image")
    plt.show()
    # Release the figure so one-per-epoch previews do not accumulate.
    plt.close(fig)


def train(**inputs: Any) -> None:
    """Train the model, validate every epoch and keep the best checkpoint.

    All arguments are passed as keyword arguments.

    Parameters:
    -----------
    model: torch.nn.Module
        the model to train (required, must not be None)
    criterion: torch.nn.Loss
        the loss function, called as ``criterion(prediction, ldepth)``
    optimizer: torch.optim.Optimizer
        the optimizer
    trainloader: torch.utils.data.DataLoader
        the training data loader; must yield 4-tuples
        ``(left, right, ldepth, rdepth)``
    validloader: torch.utils.data.DataLoader
        the validation data loader, same batch layout
    epochs: int
        number of epochs to train
    device: str
        device to train on
    scheduler: torch.optim.lr_scheduler, optional
        the learning rate scheduler, stepped once per epoch;
        ``ReduceLROnPlateau`` is stepped with the validation loss
    training_losses: List[float]
        list to store training losses; written at index ``epoch`` when
        long enough, otherwise appended to
    validation_losses: List[float]
        list to store validation losses, same storage rule
    verbose: bool, optional
        when true (default) show a sample prediction after every epoch
    pretrained_model: str, optional
        path of a state dict to load into the model before training
    checkpoint_path: str, optional
        where the best state dict is saved (default ``"best-model.pth"``)

    Raises:
    -------
    ValueError
        if ``model`` is missing or None
    KeyError
        if any other required argument is missing
    """
    model = inputs.get("model")
    if model is None:
        raise ValueError("model cannot be None")
    criterion = inputs["criterion"]
    optimizer = inputs["optimizer"]
    trainloader = inputs["trainloader"]
    validloader = inputs["validloader"]
    epochs = inputs["epochs"]
    device = inputs["device"]
    scheduler = inputs.get("scheduler")
    training_losses = inputs["training_losses"]
    validation_losses = inputs["validation_losses"]
    verbose = inputs.get("verbose", True)
    checkpoint_path = inputs.get("checkpoint_path", "best-model.pth")

    print("Training on {}...".format(device))

    # load pretrained model if available
    if inputs.get("pretrained_model"):
        print("Loading pretrained model...")
        # map_location lets a checkpoint saved on another device (e.g. GPU)
        # be loaded on the requested one.
        model.load_state_dict(
            torch.load(inputs["pretrained_model"], map_location=device)
        )

    # move model to device
    model.to(device)

    # best validation loss seen so far; decides when to save a checkpoint
    best_loss = np.inf

    for epoch in range(epochs):
        model.train()
        train_loss = _run_epoch(
            model, criterion, trainloader, device, "Training", optimizer
        )

        model.eval()
        with torch.no_grad():
            valid_loss = _run_epoch(
                model, criterion, validloader, device, "Validating"
            )

        _record_loss(training_losses, epoch, train_loss)
        _record_loss(validation_losses, epoch, valid_loss)
        print(
            f"Epoch: {epoch+1}/{epochs}.. Training Loss: {train_loss:.6f}.. Validation Loss: {valid_loss:.6f}"
        )

        if valid_loss < best_loss:
            print(
                "loss decreased from {:.6f} to {:.6f} -> saving best model.".format(
                    best_loss, valid_loss
                )
            )
            best_loss = valid_loss
            torch.save(model.state_dict(), checkpoint_path)

        if scheduler is not None:
            # ReduceLROnPlateau needs the monitored metric; other schedulers
            # take no argument.
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(valid_loss)
            else:
                scheduler.step()

        if verbose:
            # display left, right, predicted and ground-truth depth images
            _show_sample(model, validloader, device)

    plt.plot(training_losses, label="Training loss")
    plt.plot(validation_losses, label="Validation loss")
    plt.legend(frameon=False)
    plt.show()