'''
Function:
    Implementation of NonLocalNet
Author:
    Zhenchao Jin
'''
import copy
import torch
import torch.nn as nn
from ..base import BaseSegmentor
from .nonlocalblock import NonLocal2d
from ....utils import SSSegOutputStructure
from ...backbones import BuildActivation, BuildNormalization
'''NonLocalNet'''
class NonLocalNet(BaseSegmentor):
    def __init__(self, cfg, mode):
        super(NonLocalNet, self).__init__(cfg, mode)
        align_corners, norm_cfg, act_cfg, head_cfg = self.align_corners, self.norm_cfg, self.act_cfg, cfg['head']
        # build non-local block
        self.conv_before_nl = nn.Sequential(
            nn.Conv2d(head_cfg['in_channels'], head_cfg['feats_channels'], kernel_size=3, stride=1, padding=1, bias=False),
            BuildNormalization(placeholder=head_cfg['feats_channels'], norm_cfg=norm_cfg),
            BuildActivation(act_cfg),
        )
        self.nl_block = NonLocal2d(
            in_channels=head_cfg['feats_channels'], reduction=head_cfg['reduction'], use_scale=head_cfg['use_scale'], mode=head_cfg['mode'],
            norm_cfg=copy.deepcopy(norm_cfg), act_cfg=copy.deepcopy(act_cfg),
        )
        self.conv_after_nl = nn.Sequential(
            nn.Conv2d(head_cfg['feats_channels'], head_cfg['feats_channels'], kernel_size=3, stride=1, padding=1, bias=False),
            BuildNormalization(placeholder=head_cfg['feats_channels'], norm_cfg=norm_cfg),
            BuildActivation(act_cfg),
        )
        # build decoder
        self.decoder = nn.Sequential(
            nn.Conv2d(head_cfg['in_channels'] + head_cfg['feats_channels'], head_cfg['feats_channels'], kernel_size=3, stride=1, padding=1, bias=False),
            BuildNormalization(placeholder=head_cfg['feats_channels'], norm_cfg=norm_cfg),
            BuildActivation(act_cfg),
            nn.Dropout2d(head_cfg['dropout']),
            nn.Conv2d(head_cfg['feats_channels'], cfg['num_classes'], kernel_size=1, stride=1, padding=0)
        )
        # build auxiliary decoder
        self.setauxiliarydecoder(cfg['auxiliary'])
        # freeze normalization layer if necessary
        if cfg.get('is_freeze_norm', False): self.freezenormalization()
    '''forward'''
    def forward(self, data_meta):
        img_size = data_meta.images.size(2), data_meta.images.size(3)
        # feed to backbone network
        backbone_outputs = self.transforminputs(self.backbone_net(data_meta.images), selected_indices=self.cfg['backbone'].get('selected_indices'))
        # feed to non-local block
        feats = self.conv_before_nl(backbone_outputs[-1])
        feats = self.nl_block(feats)
        feats = self.conv_after_nl(feats)
        # feed to decoder
        feats = torch.cat([backbone_outputs[-1], feats], dim=1)
        seg_logits = self.decoder(feats)
        # forward according to the mode
        if self.mode in ['TRAIN', 'TRAIN_DEVELOP']:
            loss, losses_log_dict = self.customizepredsandlosses(
                seg_logits=seg_logits, annotations=data_meta.getannotations(), backbone_outputs=backbone_outputs, losses_cfg=self.cfg['losses'], img_size=img_size,
            )
            ssseg_outputs = SSSegOutputStructure(mode=self.mode, loss=loss, losses_log_dict=losses_log_dict) if self.mode == 'TRAIN' else SSSegOutputStructure(mode=self.mode, loss=loss, losses_log_dict=losses_log_dict, seg_logits=seg_logits)
        else:
            ssseg_outputs = SSSegOutputStructure(mode=self.mode, seg_logits=seg_logits)
        return ssseg_outputs
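

# Minimal, self-contained sketch (an assumption for illustration, not the repository's actual
# NonLocal2d from .nonlocalblock) of what a 2d non-local block in 'embedded_gaussian' mode computes:
# pairwise affinities between all spatial positions via softmax(theta(x) @ phi(x)^T), aggregation of
# the value embedding g(x), and a residual add back onto the input. The class name NaiveNonLocal2d
# and its exact layer layout are hypothetical; the repository's block may differ (e.g. in how
# norm_cfg/act_cfg are applied).
class NaiveNonLocal2d(nn.Module):
    def __init__(self, in_channels, reduction=2, use_scale=True):
        super(NaiveNonLocal2d, self).__init__()
        inter_channels = max(in_channels // reduction, 1)
        self.use_scale = use_scale
        # 1x1 convs produce the query (theta), key (phi) and value (g) embeddings
        self.theta = nn.Conv2d(in_channels, inter_channels, kernel_size=1)
        self.phi = nn.Conv2d(in_channels, inter_channels, kernel_size=1)
        self.g = nn.Conv2d(in_channels, inter_channels, kernel_size=1)
        # project the aggregated context back to in_channels for the residual connection
        self.conv_out = nn.Conv2d(inter_channels, in_channels, kernel_size=1)
    def forward(self, x):
        n, _, h, w = x.size()
        # flatten spatial dims: theta/g become (n, h*w, c'), phi stays (n, c', h*w)
        theta = self.theta(x).view(n, -1, h * w).permute(0, 2, 1)
        phi = self.phi(x).view(n, -1, h * w)
        g = self.g(x).view(n, -1, h * w).permute(0, 2, 1)
        # pairwise affinity between every pair of positions, optionally scaled as in dot-product attention
        pairwise = torch.matmul(theta, phi)
        if self.use_scale: pairwise = pairwise / (theta.size(-1) ** 0.5)
        pairwise = pairwise.softmax(dim=-1)
        # aggregate values, restore the spatial layout and add the residual
        y = torch.matmul(pairwise, g).permute(0, 2, 1).contiguous().view(n, -1, h, w)
        return x + self.conv_out(y)
# e.g. NaiveNonLocal2d(in_channels=512, reduction=2)(torch.randn(2, 512, 32, 32)) keeps the input shape,
# mirroring how self.nl_block above maps feats_channels -> feats_channels between conv_before_nl and conv_after_nl.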