From 8783c312d201ffca93b0ea62c25f897545ab0465 Mon Sep 17 00:00:00 2001
From: "Daniel J. Hofmann"
Date: Thu, 30 May 2019 23:35:59 +0200
Subject: [PATCH] Adds feature pyramid attention (FPA) module, resolves #167

---
 robosat/fpa.py  | 59 +++++++++++++++++++++++++++++++++++++++++++++++++
 robosat/unet.py |  8 ++++++-
 2 files changed, 66 insertions(+), 1 deletion(-)
 create mode 100644 robosat/fpa.py

diff --git a/robosat/fpa.py b/robosat/fpa.py
new file mode 100644
index 00000000..63c39b6a
--- /dev/null
+++ b/robosat/fpa.py
@@ -0,0 +1,59 @@
+"""Feature Pyramid Attention blocks
+
+See:
+- https://arxiv.org/abs/1805.10180 - Pyramid Attention Network for Semantic Segmentation
+
+"""
+
+import torch.nn as nn
+
+
+class FeaturePyramidAttention(nn.Module):
+    """Feature Pyramid Attention (FPA) block
+    See https://arxiv.org/abs/1805.10180 Figure 3 b
+    """
+
+    def __init__(self, num_in, num_out):
+        super().__init__()
+
+        # no batch norm for tensors of shape NxCx1x1
+        self.top1x1 = nn.Sequential(nn.Conv2d(num_in, num_out, 1, bias=False), nn.ReLU(inplace=True))
+
+        self.mid1x1 = ConvBnRelu(num_in, num_out, 1)
+
+        self.bot5x5 = ConvBnRelu(num_in, num_in, 5, stride=2, padding=2)
+        self.bot3x3 = ConvBnRelu(num_in, num_in, 3, stride=2, padding=1)
+
+        self.lat5x5 = ConvBnRelu(num_in, num_out, 5, stride=1, padding=2)
+        self.lat3x3 = ConvBnRelu(num_in, num_out, 3, stride=1, padding=1)
+
+    def forward(self, x):
+        assert x.size()[-1] % 8 == 0 and x.size()[-2] % 8 == 0, "size has to be divisible by 8 for fpa"
+
+        # global pooling top pathway
+        top = self.top1x1(nn.functional.adaptive_avg_pool2d(x, 1))
+        top = nn.functional.interpolate(top, size=x.size()[-2:], mode="bilinear")
+
+        # conv middle pathway
+        mid = self.mid1x1(x)
+
+        # multi-scale bottom and lateral pathways
+        bot0 = self.bot5x5(x)
+        bot1 = self.bot3x3(bot0)
+
+        lat0 = self.lat5x5(bot0)
+        lat1 = self.lat3x3(bot1)
+
+        # upward accumulation pathways
+        up = lat0 + nn.functional.interpolate(lat1, scale_factor=2, mode="bilinear")
+        up = nn.functional.interpolate(up, scale_factor=2, mode="bilinear")
+
+        return up * mid + top
+
+
+def ConvBnRelu(num_in, num_out, kernel_size, stride=1, padding=0, bias=False):
+    return nn.Sequential(
+        nn.Conv2d(num_in, num_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
+        nn.BatchNorm2d(num_out),
+        nn.ReLU(inplace=True),
+    )
diff --git a/robosat/unet.py b/robosat/unet.py
index 1acc4ac7..1d259985 100644
--- a/robosat/unet.py
+++ b/robosat/unet.py
@@ -14,6 +14,8 @@
 
 from torchvision.models import resnet50
 
+from robosat.fpa import FeaturePyramidAttention
+
 
 class ConvRelu(nn.Module):
     """3x3 convolution followed by ReLU activation building block.
@@ -96,6 +98,8 @@ def __init__(self, num_classes, num_filters=32, pretrained=True):
         # Access resnet directly in forward pass; do not store refs here due to
         # https://github.com/pytorch/pytorch/issues/8392
 
+        self.fpa = FeaturePyramidAttention(2048, 2048)
+
         self.center = DecoderBlock(2048, num_filters * 8)
 
         self.dec0 = DecoderBlock(2048 + num_filters * 8, num_filters * 8)
@@ -129,7 +133,9 @@ def forward(self, x):
         enc3 = self.resnet.layer3(enc2)
         enc4 = self.resnet.layer4(enc3)
 
-        center = self.center(nn.functional.max_pool2d(enc4, kernel_size=2, stride=2))
+        fpa = self.fpa(enc4)
+
+        center = self.center(nn.functional.max_pool2d(fpa, kernel_size=2, stride=2))
 
         dec0 = self.dec0(torch.cat([enc4, center], dim=1))
         dec1 = self.dec1(torch.cat([enc3, dec0], dim=1))
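
Note (not part of the patch): a minimal smoke-test sketch for the new block in isolation, assuming a working PyTorch install; the batch size and the 32x32 spatial extent are chosen arbitrarily here to satisfy the divisible-by-8 assert in forward():

    import torch

    from robosat.fpa import FeaturePyramidAttention

    # NxCxHxW feature map; 2048 channels matches what ResNet-50 layer4 produces,
    # H and W must be divisible by 8 per the assert in FeaturePyramidAttention.forward
    x = torch.rand(1, 2048, 32, 32)

    fpa = FeaturePyramidAttention(2048, 2048)
    y = fpa(x)

    # FPA re-weights the features but keeps channel count and spatial resolution,
    # which is why it can be dropped in front of the existing max_pool2d in unet.py
    assert y.size() == x.size()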