From 79094068d36613a2860dd3684b8df5807f27d0a0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BF=86=E6=89=AC?=
Date: Fri, 2 Nov 2018 17:10:25 +0800
Subject: [PATCH] Add ShuffleNetV2

---
 main.py                |   3 +-
 models/__init__.py     |   1 +
 models/shufflenetv2.py | 158 +++++++++++++++++++++++++++++++++++++++++
 3 files changed, 161 insertions(+), 1 deletion(-)
 create mode 100644 models/shufflenetv2.py

diff --git a/main.py b/main.py
index 26a4e988d..88686dc00 100644
--- a/main.py
+++ b/main.py
@@ -51,7 +51,7 @@
 # Model
 print('==> Building model..')
 # net = VGG('VGG19')
-net = ResNet18()
+# net = ResNet18()
 # net = PreActResNet18()
 # net = GoogLeNet()
 # net = DenseNet121()
@@ -61,6 +61,7 @@
 # net = DPN92()
 # net = ShuffleNetG2()
 # net = SENet18()
+net = ShuffleNetV2(0.5)
 net = net.to(device)
 if device == 'cuda':
     net = torch.nn.DataParallel(net)
diff --git a/models/__init__.py b/models/__init__.py
index d03b4c61a..7f67a9e55 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -6,6 +6,7 @@
 from densenet import *
 from googlenet import *
 from shufflenet import *
+from shufflenetv2 import *
 from resnet import *
 from resnext import *
 from preact_resnet import *
diff --git a/models/shufflenetv2.py b/models/shufflenetv2.py
new file mode 100644
index 000000000..a6eb00c2a
--- /dev/null
+++ b/models/shufflenetv2.py
@@ -0,0 +1,158 @@
+'''ShuffleNetV2 in PyTorch.
+
+See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details.
+'''
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ShuffleBlock(nn.Module):
+    def __init__(self, groups=2):
+        super(ShuffleBlock, self).__init__()
+        self.groups = groups
+
+    def forward(self, x):
+        '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
+        N, C, H, W = x.size()
+        g = self.groups
+        return x.view(N, g, C // g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W)
+
+
+class SplitBlock(nn.Module):
+    def __init__(self, ratio):
+        super(SplitBlock, self).__init__()
+        self.ratio = ratio
+
+    def forward(self, x):
+        c = int(x.size(1) * self.ratio)
+        return x[:, :c, :, :], x[:, c:, :, :]
+
+
+class BasicBlock(nn.Module):
+    def __init__(self, in_channels, split_ratio=0.5):
+        super(BasicBlock, self).__init__()
+        self.split = SplitBlock(split_ratio)
+        in_channels = int(in_channels * split_ratio)
+        self.conv1 = nn.Conv2d(in_channels, in_channels,
+                               kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(in_channels)
+        self.conv2 = nn.Conv2d(in_channels, in_channels,
+                               kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False)
+        self.bn2 = nn.BatchNorm2d(in_channels)
+        self.conv3 = nn.Conv2d(in_channels, in_channels,
+                               kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(in_channels)
+        self.shuffle = ShuffleBlock()
+
+    def forward(self, x):
+        x1, x2 = self.split(x)
+        out = F.relu(self.bn1(self.conv1(x2)))
+        out = self.bn2(self.conv2(out))
+        out = F.relu(self.bn3(self.conv3(out)))
+        out = torch.cat([x1, out], 1)
+        out = self.shuffle(out)
+        return out
+
+
+class DownBlock(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(DownBlock, self).__init__()
+        mid_channels = out_channels // 2
+        # left branch
+        self.conv1 = nn.Conv2d(in_channels, in_channels,
+                               kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False)
+        self.bn1 = nn.BatchNorm2d(in_channels)
+        self.conv2 = nn.Conv2d(in_channels, mid_channels,
+                               kernel_size=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(mid_channels)
+        # right branch
+        self.conv3 = nn.Conv2d(in_channels, mid_channels,
+                               kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(mid_channels)
+        self.conv4 = nn.Conv2d(mid_channels, mid_channels,
+                               kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False)
+        self.bn4 = nn.BatchNorm2d(mid_channels)
+        self.conv5 = nn.Conv2d(mid_channels, mid_channels,
+                               kernel_size=1, bias=False)
+        self.bn5 = nn.BatchNorm2d(mid_channels)
+
+        self.shuffle = ShuffleBlock()
+
+    def forward(self, x):
+        # left branch
+        out1 = self.bn1(self.conv1(x))
+        out1 = F.relu(self.bn2(self.conv2(out1)))
+        # right branch
+        out2 = F.relu(self.bn3(self.conv3(x)))
+        out2 = self.bn4(self.conv4(out2))
+        out2 = F.relu(self.bn5(self.conv5(out2)))
+        # concat
+        out = torch.cat([out1, out2], 1)
+        out = self.shuffle(out)
+        return out
+
+
+class ShuffleNetV2(nn.Module):
+    def __init__(self, net_size):
+        super(ShuffleNetV2, self).__init__()
+        out_channels = configs[net_size]['out_channels']
+        num_blocks = configs[net_size]['num_blocks']
+
+        self.conv1 = nn.Conv2d(3, 24, kernel_size=3,
+                               stride=2, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(24)
+        self.in_channels = 24
+        self.layer1 = self._make_layer(out_channels[0], num_blocks[0])
+        self.layer2 = self._make_layer(out_channels[1], num_blocks[1])
+        self.layer3 = self._make_layer(out_channels[2], num_blocks[2])
+        self.linear = nn.Linear(out_channels[2], 10)
+
+    def _make_layer(self, out_channels, num_blocks):
+        layers = [DownBlock(self.in_channels, out_channels)]
+        for i in range(num_blocks):
+            layers.append(BasicBlock(out_channels))
+        self.in_channels = out_channels
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        # out = F.max_pool2d(out, 3, stride=2, padding=1)
+        out = self.layer1(out)
+        out = self.layer2(out)
+        out = self.layer3(out)
+        out = F.avg_pool2d(out, 2)
+        out = out.view(out.size(0), -1)
+        out = self.linear(out)
+        return out
+
+
+configs = {
+    0.5: {
+        'out_channels': (48, 96, 192),
+        'num_blocks': (3, 7, 3)
+    },
+
+    1: {
+        'out_channels': (116, 232, 464),
+        'num_blocks': (3, 7, 3)
+    },
+    1.5: {
+        'out_channels': (176, 352, 704),
+        'num_blocks': (3, 7, 3)
+    },
+    2: {
+        'out_channels': (244, 488, 976),
+        'num_blocks': (3, 7, 3)
+    }
+}
+
+
+def test():
+    net = ShuffleNetV2(net_size=0.5)
+    x = torch.randn(3, 3, 32, 32)
+    y = net(x)
+    print(y.shape)
+
+
+# test()
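
Reviewer note (not part of the patch): the channel shuffle in ShuffleBlock is just a grouped transpose of the channel axis, which is easy to sanity-check in isolation. Below is a minimal standalone sketch of that check, plus a smoke test mirroring test(); it assumes PyTorch >= 0.4 (for Tensor.reshape) and, for the commented smoke test, that models/shufflenetv2.py is on the import path.

import torch

# Channel shuffle as in ShuffleBlock (g=2): [N,C,H,W] -> [N,g,C//g,H,W]
# -> swap the two group axes -> [N,C,H,W], interleaving channels across groups.
N, C, H, W, g = 1, 6, 1, 1, 2
x = torch.arange(C).view(N, C, H, W)            # channels 0,1,2,3,4,5 in order
y = x.view(N, g, C // g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W)
print(y.flatten().tolist())                     # [0, 3, 1, 4, 2, 5]

# Smoke test mirroring test() in the patch (uncomment when the module is importable):
# from shufflenetv2 import ShuffleNetV2
# net = ShuffleNetV2(net_size=0.5)
# print(net(torch.randn(3, 3, 32, 32)).shape)   # expected: torch.Size([3, 10])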