train.py
import pytorch_lightning as pl
import yaml
from models.model import TyNet
from utils.utils import init_weights
import argparse
from data.coco_dataset import CustomDataset, collater
from data.augmentations import get_augmentations
from torch.utils.data import DataLoader
from pytorch_lightning.loggers import TensorBoardLogger
from models.loss import FocalLoss
from models.utils import get_optimizer, get_scheduler
from models.detector import Detector
import torch
if __name__ == '__main__':
    # Seed the CUDA RNG so weight init and GPU-side sampling are reproducible.
    torch.cuda.manual_seed(42)

    parser = argparse.ArgumentParser()
    parser.add_argument('--train_cfg', type=str, default='training.yaml', help='training config file')
    parser.add_argument('--dataset_cfg', type=str, default='coco.yml', help='dataset config file')
    args = parser.parse_args()

    # Load the training and dataset configs from YAML.
    with open(args.train_cfg, 'r') as config:
        opt = yaml.safe_load(config)
    with open(args.dataset_cfg, 'r') as config:
        dataset_opt = yaml.safe_load(config)
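
    # A rough sketch of the YAML layout this script expects, inferred from the
    # keys accessed below; the concrete values are illustrative assumptions,
    # not the repository's shipped configs:
    #
    # training.yaml:
    #   training:
    #     image_path: <train images dir>
    #     annotation_path: <train annotations json>
    #     image_size: 512
    #     normalize: true
    #     batch_size: 8
    #     shuffle: true
    #     drop_last: true
    #     num_workers: 4
    #     val_frequency: 1
    #     epochs: 100
    #   validation:
    #     image_path: <val images dir>
    #     annotation_path: <val annotations json>
    #     shuffle: false
    #     drop_last: false
    #     num_workers: 4
    #
    # coco.yml:
    #   obj_list: ['person', 'bicycle', 'car', ...]
    #   anchors_ratios: '[(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)]'
    #   anchors_scales: '[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]'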
    # Build TyNet with one output per dataset class; anchor ratios and scales
    # are stored as Python literal strings in the YAML, hence the eval().
    model = TyNet(num_classes=len(dataset_opt['obj_list']),
                  ratios=eval(dataset_opt['anchors_ratios']),
                  scales=eval(dataset_opt['anchors_scales']))
    init_weights(model)
    model = model.cuda()

    augmentations = get_augmentations(opt)

    training_params = {'batch_size': opt['training']['batch_size'],
                       'shuffle': opt['training']['shuffle'],
                       'drop_last': opt['training']['drop_last'],
                       'collate_fn': collater,
                       'num_workers': opt['training']['num_workers']}

    val_params = {'batch_size': 1,
                  'shuffle': opt['validation']['shuffle'],
                  'drop_last': opt['validation']['drop_last'],
                  'collate_fn': collater,
                  'num_workers': opt['validation']['num_workers']}

    # Augmentations are applied to the training split only; validation reuses
    # the training image size and normalization settings.
    train_dataset = CustomDataset(image_path=opt['training']['image_path'],
                                  annotation_path=opt['training']['annotation_path'],
                                  image_size=opt['training']['image_size'],
                                  normalize=opt['training']['normalize'],
                                  augmentations=augmentations)

    val_dataset = CustomDataset(image_path=opt['validation']['image_path'],
                                annotation_path=opt['validation']['annotation_path'],
                                image_size=opt['training']['image_size'],
                                normalize=opt['training']['normalize'],
                                augmentations=None)

    train_loader = DataLoader(train_dataset, **training_params)
    val_loader = DataLoader(val_dataset, **val_params)

    logger = TensorBoardLogger("tb_logs", name="my_model")

    loss_fn = FocalLoss()
    optimizer = get_optimizer(opt['training'], model)
    scheduler = get_scheduler(opt['training'], optimizer, len(train_loader))

    # Detector is the LightningModule that wraps the model, loss, optimizer
    # and scheduler for pl.Trainer.
    detector = Detector(model=model, scheduler=scheduler, optimizer=optimizer, loss=loss_fn)

    # Note: `gpus=1` is the pre-2.0 Lightning Trainer API; Lightning >= 2.0
    # uses accelerator='gpu', devices=1 instead.
    trainer = pl.Trainer(gpus=1, logger=logger, check_val_every_n_epoch=opt['training']['val_frequency'], max_epochs=opt['training']['epochs'])

    trainer.fit(model=detector,
                train_dataloaders=train_loader,
                val_dataloaders=val_loader)
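
# Example invocation, assuming the two config files sit in the working
# directory (filenames follow the argparse defaults above):
#
#   python train.py --train_cfg training.yaml --dataset_cfg coco.yml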