Skip to content

Commit

Permalink
update icnet
Browse files Browse the repository at this point in the history
  • Loading branch information
liminn committed Nov 13, 2019
1 parent 7fbe397 commit 28f5317
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions core/models/icnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def __init__(self, nclass, backbone='resnet50', aux=False, jpu=False, pretrained
_ConvBNReLU(32, 64, 3, 2, **kwargs)
)

self.ppm = PyramidPoolingModule()

self.head = _ICHead(nclass, **kwargs)

self.__setattr__('exclusive', ['conv_sub1', 'head'])
Expand All @@ -35,16 +37,31 @@ def forward(self, x):
# sub 4
x_sub4 = F.interpolate(x, scale_factor=0.25, mode='bilinear', align_corners=True)
_, _, _, x_sub4 = self.base_forward(x_sub4)

# add PyramidPoolingModule
x_sub4 = self.ppm(x_sub4)
outputs = self.head(x_sub1, x_sub2, x_sub4)

return tuple(outputs)

class PyramidPoolingModule(nn.Module):
def __init__(self, pyramids=[1,2,3,6]):
super(PyramidPoolingModule, self).__init__()
self.pyramids = pyramids

def forward(self, input):
feat = input
height, width = input.shape[2:]
for bin_size in self.pyramids:
x = F.adaptive_avg_pool2d(input, output_size=bin_size)
x = F.interpolate(x, size=(height, width), mode='bilinear', align_corners=True)
feat = feat + x
return feat

class _ICHead(nn.Module):
def __init__(self, nclass, norm_layer=nn.BatchNorm2d, **kwargs):
super(_ICHead, self).__init__()
self.cff_12 = CascadeFeatureFusion(512, 64, 128, nclass, norm_layer, **kwargs)
#self.cff_12 = CascadeFeatureFusion(512, 64, 128, nclass, norm_layer, **kwargs)
self.cff_12 = CascadeFeatureFusion(128, 64, 128, nclass, norm_layer, **kwargs)
self.cff_24 = CascadeFeatureFusion(2048, 512, 128, nclass, norm_layer, **kwargs)

self.conv_cls = nn.Conv2d(128, nclass, 1, bias=False)
Expand All @@ -53,9 +70,10 @@ def forward(self, x_sub1, x_sub2, x_sub4):
outputs = list()
x_cff_24, x_24_cls = self.cff_24(x_sub4, x_sub2)
outputs.append(x_24_cls)
x_cff_12, x_12_cls = self.cff_12(x_sub2, x_sub1)
#x_cff_12, x_12_cls = self.cff_12(x_sub2, x_sub1)
x_cff_12, x_12_cls = self.cff_12(x_cff_24, x_sub1)
outputs.append(x_12_cls)

up_x2 = F.interpolate(x_cff_12, scale_factor=2, mode='bilinear', align_corners=True)
up_x2 = self.conv_cls(up_x2)
outputs.append(up_x2)
Expand Down

0 comments on commit 28f5317

Please sign in to comment.