from typing import Callable, Dict, List, Tuple

import pytorch_lightning as pl
import torch
import torch.nn as nn

from models.loss import FocalLoss


class Detector(pl.LightningModule):

    def __init__(self, model: nn.Module, scheduler: Callable, optimizer: Callable, loss: Callable = None) -> None:
        """The detector module used by PyTorch Lightning for training.

        Written against the PyTorch Lightning 1.x API (it relies on the
        ``training_epoch_end`` / ``validation_epoch_end`` hooks).

        Args:
            model (nn.Module): The object detection model to train.
            scheduler (Callable): Learning rate scheduler instance.
            optimizer (Callable): Optimizer instance.
            loss (Callable, optional): Loss function. Defaults to None, in
                which case FocalLoss is used.
        """
        super().__init__()
        self.model = model
        self.scheduler = scheduler
        self.optimizer = optimizer
        self.best_val_loss = float('inf')
        if loss is None:
            self.loss = FocalLoss()
        else:
            self.loss = loss

    def forward(self, images: torch.Tensor) -> Tuple[List, torch.Tensor, torch.Tensor, torch.Tensor]:
        return self.model(images)

    def training_step(self, train_batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[str, torch.Tensor]:
        images = train_batch['img']
        labels = train_batch['labels']
        _, regression, classification, anchors = self.forward(images)
        cls_loss, reg_loss = self.loss(classification, regression, anchors, labels)
        reg_loss = reg_loss.mean()
        cls_loss = cls_loss.mean()
        total_loss = cls_loss + reg_loss
        # get_last_lr() reports the rate set by the most recent scheduler step
        # without the warning that get_lr() raises when called outside step().
        self.log('learning rate', self.scheduler.get_last_lr()[0])
        return {'loss': total_loss, 'cls_loss': cls_loss, 'reg_loss': reg_loss}

    def training_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]) -> None:
        # Aggregate the per-batch losses returned by training_step over the epoch.
        cls_losses = [x['cls_loss'] for x in outputs]
        reg_losses = [x['reg_loss'] for x in outputs]
        total_losses = [x['loss'] for x in outputs]
        avg_train_cls_loss = sum(cls_losses) / len(cls_losses)
        avg_train_reg_loss = sum(reg_losses) / len(reg_losses)
        avg_train_loss = sum(total_losses) / len(total_losses)
        self.log('train cls_loss', avg_train_cls_loss)
        self.log('train reg_loss', avg_train_reg_loss)
        self.log('train total_loss', avg_train_loss)

    def validation_step(self, val_batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[str, torch.Tensor]:
        images = val_batch['img']
        labels = val_batch['labels']
        _, regression, classification, anchors = self.forward(images)
        cls_loss, reg_loss = self.loss(classification, regression, anchors, labels)
        reg_loss = reg_loss.mean()
        cls_loss = cls_loss.mean()
        total_loss = cls_loss + reg_loss
        return {'loss': total_loss, 'cls_loss': cls_loss, 'reg_loss': reg_loss}

    def validation_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]) -> None:
        cls_losses = [x['cls_loss'] for x in outputs]
        reg_losses = [x['reg_loss'] for x in outputs]
        total_losses = [x['loss'] for x in outputs]
        avg_val_cls_loss = sum(cls_losses) / len(cls_losses)
        avg_val_reg_loss = sum(reg_losses) / len(reg_losses)
        avg_val_loss = sum(total_losses) / len(total_losses)
        # Keep a checkpoint of the best model seen so far, measured by validation loss.
        if avg_val_loss < self.best_val_loss:
            self.best_val_loss = avg_val_loss
            torch.save({
                'epoch': self.current_epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': self.loss,
            }, './best.ckpt')
        self.log('validation cls_loss', avg_val_cls_loss)
        self.log('validation reg_loss', avg_val_reg_loss)
        self.log('validation total_loss', avg_val_loss)

    def test_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self):
        """Return the optimizer and, if one was given, the learning rate scheduler.

        Returns:
            A list containing the optimizer, optionally paired with a list
            containing the scheduler.
        """
        optimizer = self.optimizer
        scheduler = self.scheduler
        if scheduler:
            return [optimizer], [scheduler]
        return [optimizer]
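

if __name__ == '__main__':
    # Minimal wiring sketch: shows how this LightningModule might be assembled
    # with an optimizer, a scheduler, and a Trainer. `_StubModel` is a
    # hypothetical stand-in whose forward() returns the
    # (features, regression, classification, anchors) tuple that
    # training_step/validation_step expect; swap in the repo's real detector.

    class _StubModel(nn.Module):
        def __init__(self, num_anchors: int = 100, num_classes: int = 80):
            super().__init__()
            self.backbone = nn.Conv2d(3, 8, kernel_size=3, padding=1)
            self.num_anchors = num_anchors
            self.num_classes = num_classes

        def forward(self, images: torch.Tensor):
            features = self.backbone(images)
            batch = images.shape[0]
            regression = torch.zeros(batch, self.num_anchors, 4)
            classification = torch.zeros(batch, self.num_anchors, self.num_classes)
            anchors = torch.zeros(1, self.num_anchors, 4)
            return [features], regression, classification, anchors

    stub = _StubModel()
    optimizer = torch.optim.SGD(stub.parameters(), lr=1e-3, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    detector = Detector(model=stub, scheduler=scheduler, optimizer=optimizer)
    trainer = pl.Trainer(max_epochs=1)
    # trainer.fit(detector, train_dataloader, val_dataloader)
    # The dataloaders are defined elsewhere in the repo and must yield dicts
    # with 'img' and 'labels' keys, as training_step/validation_step expect.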