-
-
Notifications
You must be signed in to change notification settings - Fork 108
/
dnlnet.py
68 lines (66 loc) · 3.48 KB
/
dnlnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
'''
Function:
Implementation of DNLNet
Author:
Zhenchao Jin
'''
import copy
import torch
import torch.nn as nn
from ..base import BaseSegmentor
from ....utils import SSSegOutputStructure
from .dnlblock import DisentangledNonLocal2d
from ...backbones import BuildActivation, BuildNormalization
'''DNLNet'''
class DNLNet(BaseSegmentor):
    '''DNLNet segmentor.

    Runs the last selected backbone feature map through a 3x3 conv block, a
    DisentangledNonLocal2d block and a second 3x3 conv block, concatenates the
    result with that same backbone feature map along dim 1, and decodes the
    concatenation into cfg['num_classes'] output channels.
    '''
    def __init__(self, cfg, mode):
        '''Build the DNL head, the decoder and the auxiliary decoder.

        Args:
            cfg: segmentor config. Keys read here: cfg['head'] (with 'in_channels',
                'feats_channels', 'reduction', 'use_scale', 'mode', 'temperature',
                'dropout'), cfg['num_classes'], cfg['auxiliary'] and the optional
                cfg['is_freeze_norm'] (treated as False when absent).
            mode: forwarded unchanged to BaseSegmentor.__init__; forward() later
                compares self.mode against 'TRAIN' and 'TRAIN_DEVELOP'.
        '''
        super(DNLNet, self).__init__(cfg, mode)
        norm_cfg, act_cfg, head_cfg = self.norm_cfg, self.act_cfg, cfg['head']
        in_channels, feats_channels = head_cfg['in_channels'], head_cfg['feats_channels']

        def convnormact(src_channels, dst_channels):
            # bias-free 3x3 conv -> normalization -> activation, created in this
            # order; the position in the list becomes the index inside the
            # enclosing nn.Sequential (and therefore part of its state_dict keys)
            return [
                nn.Conv2d(src_channels, dst_channels, kernel_size=3, stride=1, padding=1, bias=False),
                BuildNormalization(placeholder=dst_channels, norm_cfg=norm_cfg),
                BuildActivation(act_cfg),
            ]

        # build disentangled non-local block
        self.conv_before_dnl = nn.Sequential(*convnormact(in_channels, feats_channels))
        self.dnl_block = DisentangledNonLocal2d(
            in_channels=feats_channels,
            reduction=head_cfg['reduction'],
            use_scale=head_cfg['use_scale'],
            mode=head_cfg['mode'],
            temperature=head_cfg['temperature'],
            # the block receives its own copies of the shared config objects
            norm_cfg=copy.deepcopy(norm_cfg),
            act_cfg=copy.deepcopy(act_cfg),
        )
        self.conv_after_dnl = nn.Sequential(*convnormact(feats_channels, feats_channels))
        # build decoder; its input is the concatenation made in forward(), hence
        # feats_channels + in_channels input channels
        self.decoder = nn.Sequential(
            *convnormact(feats_channels + in_channels, feats_channels),
            nn.Dropout2d(head_cfg['dropout']),
            nn.Conv2d(feats_channels, cfg['num_classes'], kernel_size=1, stride=1, padding=0),
        )
        # build auxiliary decoder
        self.setauxiliarydecoder(cfg['auxiliary'])
        # freeze normalization layer if necessary
        if cfg.get('is_freeze_norm', False):
            self.freezenormalization()

    def forward(self, data_meta):
        '''Compute segmentation logits and, in the training modes, the loss.

        Args:
            data_meta: object exposing `.images` (a tensor that is fed to the
                backbone) and `.getannotations()` (only called in the training modes).

        Returns:
            SSSegOutputStructure built with:
              - mode, loss, losses_log_dict               when self.mode == 'TRAIN'
              - mode, loss, losses_log_dict, seg_logits   when self.mode == 'TRAIN_DEVELOP'
              - mode, seg_logits                          for any other mode
        '''
        images = data_meta.images
        # sizes along dims 2 and 3 of the input batch
        # (presumably height and width of an NCHW tensor -- TODO confirm)
        img_size = images.size(2), images.size(3)
        # feed to backbone network
        backbone_outputs = self.transforminputs(
            self.backbone_net(images), selected_indices=self.cfg['backbone'].get('selected_indices'),
        )
        last_feats = backbone_outputs[-1]
        # feed to disentangled non-local block
        feats = self.conv_before_dnl(last_feats)
        feats = self.dnl_block(feats)
        feats = self.conv_after_dnl(feats)
        # feed to decoder: backbone features first, refined features second
        seg_logits = self.decoder(torch.cat([last_feats, feats], dim=1))
        # forward according to the mode
        if self.mode not in ['TRAIN', 'TRAIN_DEVELOP']:
            return SSSegOutputStructure(mode=self.mode, seg_logits=seg_logits)
        loss, losses_log_dict = self.customizepredsandlosses(
            seg_logits=seg_logits,
            annotations=data_meta.getannotations(),
            backbone_outputs=backbone_outputs,
            losses_cfg=self.cfg['losses'],
            img_size=img_size,
        )
        output_kwargs = dict(mode=self.mode, loss=loss, losses_log_dict=losses_log_dict)
        # only 'TRAIN_DEVELOP' also hands the logits back to the caller
        if self.mode != 'TRAIN':
            output_kwargs['seg_logits'] = seg_logits
        return SSSegOutputStructure(**output_kwargs)