"""
The Main module
"""
import cv2
import sys
import argparse
import logging
import numpy as np
from pathlib import Path
import utils.utils as utils
import utils.img_utils as imutils
from models.reknetm1 import RekNetM1
from models.reknetm2 import RekNetM2
from models.lidcamnet_fcn import LidCamNet
from data_processing.road_dataset import RoadDataset, RoadDataset2
from data_processing.data_processing import (
droped_valid_image_2_dir,
train_masks_dir,
crossval_split,
image_2_dir
)
from misc.losses import BCEJaccardLoss, CCEJaccardLoss
from misc.polylr_scheduler import PolyLR
from misc.transforms import (
train_transformations,
valid_tranformations,
)
import torch
import torch.nn as nn
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR, MultiStepLR
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import torch.backends.cudnn
#For reproducibility
torch.manual_seed(111)
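# Note: manual_seed alone does not make GPU training fully deterministic; one
# could additionally set torch.backends.cudnn.deterministic = True (and skip
# enabling cudnn.benchmark below) at some cost in speed.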


def main(*args, **kwargs):
    parser = argparse.ArgumentParser(description="Argument parser for the main module, which implements the training procedure.")
    parser.add_argument("--root-dir", type=str, required=True, help="Path to the root dir where models will be stored.")
    parser.add_argument("--dataset-path", type=str, required=True, help="Path to the KITTI dataset, which contains the 'testing' and 'training' subdirs.")
    parser.add_argument("--fold", type=int, default=1, help="Number of the validation fold.")

    # Optimizer options
    parser.add_argument("--optim", type=str, default="SGD", help="Type of optimizer: 'SGD' or 'Adam'.")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate for the optimizer.")
    parser.add_argument("--momentum", type=float, default=0.9, help="Momentum for the SGD optimizer.")

    # Scheduler options
    parser.add_argument("--scheduler", type=str, default="multi-step", help="Type of LR scheduler: 'step', 'multi-step', 'rlr-plat' or 'poly'.")
    parser.add_argument("--step-st", type=int, default=5, help="Step size for the StepLR schedule.")
    parser.add_argument("--milestones", type=str, default="30,70,90", help="Comma-separated milestones for the MultiStepLR schedule.")
    parser.add_argument("--gamma", type=float, default=0.1, help="Gamma parameter for the StepLR and MultiStepLR schedules.")
    parser.add_argument("--patience", type=int, default=5, help="Patience parameter for the ReduceLROnPlateau schedule.")

    # Model params
    parser.add_argument("--model-type", type=str, default="reknetm1", help="Type of model: 'reknetm1', 'reknetm2' or 'lcn'.")
    parser.add_argument("--decoder-type", type=str, default="up", help="Type of decoder module: 'up' (Upsample) or 'convTr' (ConvTranspose2d).")
    parser.add_argument("--init-type", type=str, default="He", help="Initialization type: 'He' or 'Xavier'.")
    parser.add_argument("--act-type", type=str, default="relu", help="Activation type: ReLU, CELU or FTSwish+.")
    parser.add_argument("--enc-bn-enable", type=int, default=1, help="Enable (1) or disable (0) batch normalization in the encoder module.")
    parser.add_argument("--dec-bn-enable", type=int, default=1, help="Enable (1) or disable (0) batch normalization in the decoder module.")
    parser.add_argument("--skip-conn", type=int, default=0, help="Enable (1) the skip connection in the context module.")
    parser.add_argument("--attention", type=int, default=0, help="Enable (1) the attention mechanism in the context module.")

    # Other options
    parser.add_argument("--n-epochs", type=int, default=100, help="Number of training epochs.")
    parser.add_argument("--batch-size", type=int, default=4, help="Number of examples per batch.")
    parser.add_argument("--num-workers", type=int, default=8, help="Number of data-loading workers.")
    parser.add_argument("--device-ids", type=str, default="0", help="Comma-separated device IDs for multi-GPU training.")
    parser.add_argument("--alpha", type=float, default=0, help="Modulation factor for the custom loss.")
    parser.add_argument("--status-every", type=int, default=1, help="Report training status every N epochs.")
    args = parser.parse_args()

    # Console logger definition
    console_logger = logging.getLogger("console-logger")
    console_logger.setLevel(logging.INFO)
    ch = logging.StreamHandler(stream=sys.stdout)
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    ch.setFormatter(formatter)
    console_logger.addHandler(ch)

    console_logger.info(args)
    # Number of classes (binary road / non-road segmentation)
    num_classes = 1

    if args.decoder_type == "up":
        upsample_enable = True
        console_logger.info("Decoder type is Upsample.")
    elif args.decoder_type == "convTr":
        upsample_enable = False
        console_logger.info("Decoder type is ConvTranspose2D.")
    else:
        raise ValueError("Unknown decoder type: {}".format(args.decoder_type))

    # Model definition
    if args.model_type == "reknetm1":
        model = RekNetM1(num_classes=num_classes,
                         ebn_enable=bool(args.enc_bn_enable),
                         dbn_enable=bool(args.dec_bn_enable),
                         upsample_enable=upsample_enable,
                         act_type=args.act_type,
                         init_type=args.init_type)
        console_logger.info("Uses RekNetM1 as the model.")
    elif args.model_type == "reknetm2":
        model = RekNetM2(num_classes=num_classes,
                         ebn_enable=bool(args.enc_bn_enable),
                         dbn_enable=bool(args.dec_bn_enable),
                         upsample_enable=upsample_enable,
                         act_type=args.act_type,
                         init_type=args.init_type,
                         attention=bool(args.attention),
                         use_skip=bool(args.skip_conn))
        console_logger.info("Uses RekNetM2 as the model.")
    elif args.model_type == "lcn":
        model = LidCamNet(num_classes=num_classes,
                          bn_enable=False)
        console_logger.info("Uses LidCamNet as the model.")
    else:
        raise ValueError("Unknown model type: {}".format(args.model_type))
console_logger.info("Number of trainable parameters: {}".format(utils.count_params(model)[1]))
#Move model to devices
if torch.cuda.is_available():
if args.device_ids:
device_ids = list(map(int, args.device_ids.split(',')))
else:
device_ids = None
model = nn.DataParallel(model, device_ids=device_ids).cuda()
cudnn.benchmark = True
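        # Note: cudnn.benchmark lets cuDNN autotune convolution algorithms,
        # which typically speeds up training when input shapes stay fixed
        # across batches (assuming the transforms yield a fixed crop size);
        # it does, however, trade away determinism.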
    # Loss definition
    loss = BCEJaccardLoss(alpha=args.alpha)

    dataset_path = Path(args.dataset_path)
    images = str(dataset_path / "training" / droped_valid_image_2_dir)
    masks = str(dataset_path / "training" / train_masks_dir)

    # Train/valid split for cross-validation by a fold (KITTI specific)
    ((train_imgs, train_masks),
     (valid_imgs, valid_masks)) = crossval_split(images_paths=images, masks_paths=masks, fold=args.fold)

    # Define the training and validation datasets
    train_dataset = RoadDataset2(img_paths=train_imgs, mask_paths=train_masks, transforms=train_transformations())
    valid_dataset = RoadDataset2(img_paths=valid_imgs, mask_paths=valid_masks, transforms=valid_tranformations())
    valid_fmeasure_dataset = RoadDataset2(img_paths=valid_imgs, mask_paths=valid_masks, transforms=valid_tranformations(), fmeasure_eval=True)

    # Create the data loaders; the validation batch size must be >= 1 even on
    # CPU-only machines, hence the max(1, ...) around the device count
    train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=torch.cuda.is_available())
    valid_loader = DataLoader(dataset=valid_dataset, batch_size=max(1, torch.cuda.device_count()), num_workers=args.num_workers, pin_memory=torch.cuda.is_available())

    console_logger.info("Train dataset length: {}".format(len(train_dataset)))
    console_logger.info("Validation dataset length: {}".format(len(valid_dataset)))
    # Optimizer definition
    if args.optim == "SGD":
        optim = SGD(params=model.parameters(), lr=args.lr, momentum=args.momentum)
        console_logger.info("Uses the SGD optimizer with initial lr={0} and momentum={1}".format(args.lr, args.momentum))
    else:
        optim = Adam(params=model.parameters(), lr=args.lr)
        console_logger.info("Uses the Adam optimizer with initial lr={0}".format(args.lr))

    # Scheduler definition
    if args.scheduler == "step":
        lr_scheduler = StepLR(optimizer=optim, step_size=args.step_st, gamma=args.gamma)
        console_logger.info("Uses the StepLR scheduler with step={} and gamma={}.".format(args.step_st, args.gamma))
    elif args.scheduler == "multi-step":
        lr_scheduler = MultiStepLR(optimizer=optim, milestones=[int(m) for m in args.milestones.split(",")], gamma=args.gamma)
        console_logger.info("Uses the MultiStepLR scheduler with milestones=[{}] and gamma={}.".format(args.milestones, args.gamma))
    elif args.scheduler == "rlr-plat":
        lr_scheduler = ReduceLROnPlateau(optimizer=optim, patience=args.patience, verbose=True)
        console_logger.info("Uses the ReduceLROnPlateau scheduler.")
    elif args.scheduler == "poly":
        lr_scheduler = PolyLR(optimizer=optim, num_epochs=args.n_epochs, alpha=args.gamma)
        console_logger.info("Uses the PolyLR scheduler.")
    else:
        raise ValueError("Unknown type of scheduler: {}".format(args.scheduler))
    valid = utils.binary_validation_routine
    utils.train_routine(
        args=args,
        console_logger=console_logger,
        root=args.root_dir,
        model=model,
        criterion=loss,
        optimizer=optim,
        scheduler=lr_scheduler,
        train_loader=train_loader,
        valid_loader=valid_loader,
        fm_eval_dataset=valid_fmeasure_dataset,
        validation=valid,
        fold=args.fold,
        num_classes=num_classes,
        n_epochs=args.n_epochs,
        status_every=args.status_every
    )


if __name__ == "__main__":
    main()
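
# Example invocation (paths are illustrative, not part of the repository):
#
#   python main.py \
#       --root-dir ./experiments/reknetm2_fold1 \
#       --dataset-path /data/kitti/data_road \
#       --model-type reknetm2 \
#       --fold 1 \
#       --optim Adam --lr 1e-3 \
#       --scheduler multi-step --milestones 30,70,90 --gamma 0.1 \
#       --batch-size 4 --n-epochs 100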