This code mentions the issue #104 #131

Open · wants to merge 1 commit into master
148 changes: 148 additions & 0 deletions models/dropblock.py
@@ -0,0 +1,148 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def drop_block_2d(
x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
runs with success, but needs further validation and possibly optimization for lower runtime impact.
"""
B, C, H, W = x.shape
total_size = W * H
clipped_block_size = min(block_size, min(W, H))
# seed_drop_rate, the gamma parameter
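    # Following Eq. 1 of the paper: gamma ~ drop_prob / block_size^2, rescaled by the ratio of the
    # full feature-map area to the region where an entire block fits without clipping.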
gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
(W - block_size + 1) * (H - block_size + 1))

# Forces the block to be inside the feature map.
w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)

if batchwise:
# one mask for whole batch, quite a bit faster
uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
else:
uniform_noise = torch.rand_like(x)
block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
block_mask = -F.max_pool2d(
-block_mask,
kernel_size=clipped_block_size, # block_size,
stride=1,
padding=clipped_block_size // 2)

if with_noise:
normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
if inplace:
x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
else:
x = x * block_mask + normal_noise * (1 - block_mask)
else:
normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
if inplace:
x.mul_(block_mask * normalize_scale)
else:
x = x * block_mask * normalize_scale
return x


def drop_block_fast_2d(
x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
    DropBlock with an experimental gaussian noise option. Simplified from the version above, without
    concern for a valid block mask at the feature-map edges.
"""
B, C, H, W = x.shape
total_size = W * H
clipped_block_size = min(block_size, min(W, H))
gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
(W - block_size + 1) * (H - block_size + 1))

if batchwise:
# one mask for whole batch, quite a bit faster
block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma
else:
# mask per batch element
block_mask = torch.rand_like(x) < gamma
block_mask = F.max_pool2d(
block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)

if with_noise:
normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
if inplace:
x.mul_(1. - block_mask).add_(normal_noise * block_mask)
else:
x = x * (1. - block_mask) + normal_noise * block_mask
else:
block_mask = 1 - block_mask
normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype)
if inplace:
x.mul_(block_mask * normalize_scale)
else:
x = x * block_mask * normalize_scale
return x


class DropBlock2d(nn.Module):
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
"""
def __init__(self,
drop_prob=0.1,
block_size=7,
gamma_scale=1.0,
with_noise=False,
inplace=False,
batchwise=False,
fast=True):
super(DropBlock2d, self).__init__()
self.drop_prob = drop_prob
self.gamma_scale = gamma_scale
self.block_size = block_size
self.with_noise = with_noise
self.inplace = inplace
self.batchwise = batchwise
self.fast = fast # FIXME finish comparisons of fast vs not

def forward(self, x):
if not self.training or not self.drop_prob:
return x
if self.fast:
return drop_block_fast_2d(
x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
else:
return drop_block_2d(
x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)


def drop_path(x, drop_prob: float = 0., training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
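    # keep_prob + U[0, 1), floored to {0, 1}, gives a per-sample Bernoulli(keep_prob) mask; dividing
    # by keep_prob keeps the expected value of the output equal to the input during training.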
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output


class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob

def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
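

# A minimal sanity-check sketch, assuming CPU tensors and the default hyper-parameters; it only
# verifies output shapes and that both layers are identity mappings in eval mode.
if __name__ == '__main__':
    torch.manual_seed(0)
    x = torch.randn(2, 8, 16, 16)
    db = DropBlock2d(drop_prob=0.1, block_size=5)
    db.train()
    assert db(x).shape == x.shape
    db.eval()
    assert torch.equal(db(x), x)  # no-op at inference
    dp = DropPath(drop_prob=0.1)
    dp.train()
    assert dp(x).shape == x.shape
    dp.eval()
    assert torch.equal(dp(x), x)  # no-op at inference
    print('dropblock.py sanity check passed')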
181 changes: 181 additions & 0 deletions models/efficientnetB0_Regularized.py
@@ -0,0 +1,181 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dropblock import DropBlock2d


def swish(x):
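    # Swish / SiLU activation: f(x) = x * sigmoid(x)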
return x * x.sigmoid()


def drop_connect(x, drop_ratio):
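    # Per-sample stochastic depth: a [B, 1, 1, 1] Bernoulli(keep_ratio) mask zeroes entire samples
    # in place, and the 1/keep_ratio rescale keeps the expected activation unchanged.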
keep_ratio = 1.0 - drop_ratio
mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
mask.bernoulli_(keep_ratio)
x.div_(keep_ratio)
x.mul_(mask)
return x


class SE(nn.Module):
'''Squeeze-and-Excitation block with Swish.'''

def __init__(self, in_channels, se_channels):
super(SE, self).__init__()
self.se1 = nn.Conv2d(in_channels, se_channels,
kernel_size=1, bias=True)
self.drop1 = DropBlock2d(block_size=1)
self.se2 = nn.Conv2d(se_channels, in_channels,
kernel_size=1, bias=True)
self.drop2 = DropBlock2d(block_size=1)

def forward(self, x):
out = F.adaptive_avg_pool2d(x, (1, 1))
out = swish(self.se1(out))
out = self.drop1(out)
out = self.se2(out).sigmoid()
out = self.drop2(out)
out = x * out
return out


class Block(nn.Module):
'''expansion + depthwise + pointwise + squeeze-excitation'''

def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
expand_ratio=1,
se_ratio=0.,
drop_rate=0.):
super(Block, self).__init__()
self.stride = stride
self.drop_rate = drop_rate
self.expand_ratio = expand_ratio

# Expansion
channels = expand_ratio * in_channels
self.conv1 = nn.Conv2d(in_channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=False)
self.bn1 = nn.BatchNorm2d(channels)
self.drop1 = DropBlock2d(block_size=1)

# Depthwise conv
self.conv2 = nn.Conv2d(channels,
channels,
kernel_size=kernel_size,
stride=stride,
padding=(1 if kernel_size == 3 else 2),
groups=channels,
bias=False)
self.bn2 = nn.BatchNorm2d(channels)
self.drop2 = DropBlock2d(block_size=1)

# SE layers
se_channels = int(in_channels * se_ratio)
self.se = SE(channels, se_channels)

# Output
self.conv3 = nn.Conv2d(channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=False)
self.bn3 = nn.BatchNorm2d(out_channels)
self.drop3 = DropBlock2d(block_size=1)

        # Skip connection if in and out shapes are the same (MobileNetV2 style)
self.has_skip = (stride == 1) and (in_channels == out_channels)

def forward(self, x):
out = x if self.expand_ratio == 1 else swish(self.bn1(self.conv1(x)))
out = self.drop1(out)
out = swish(self.bn2(self.conv2(out)))
out = self.drop2(out)
out = self.se(out)
out = self.bn3(self.conv3(out))
out = self.drop3(out)

if self.has_skip:
if self.training and self.drop_rate > 0:
out = drop_connect(out, self.drop_rate)
out = out + x
return out


class EfficientNet(nn.Module):
def __init__(self, cfg, num_classes=10):
super(EfficientNet, self).__init__()
self.cfg = cfg
self.conv1 = nn.Conv2d(3,
32,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.bn1 = nn.BatchNorm2d(32)
self.drop1 = DropBlock2d(block_size=1)
self.layers = self._make_layers(in_channels=32)
self.linear = nn.Linear(cfg['out_channels'][-1], num_classes)

def _make_layers(self, in_channels):
layers = []
cfg = [self.cfg[k] for k in ['expansion', 'out_channels', 'num_blocks', 'kernel_size',
'stride']]
b = 0
blocks = sum(self.cfg['num_blocks'])
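        # drop_connect rate ramps linearly with block index b, from 0 up to cfg['drop_connect_rate']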
for expansion, out_channels, num_blocks, kernel_size, stride in zip(*cfg):
strides = [stride] + [1] * (num_blocks - 1)
for stride in strides:
drop_rate = self.cfg['drop_connect_rate'] * b / blocks
layers.append(
Block(in_channels,
out_channels,
kernel_size,
stride,
expansion,
se_ratio=0.25,
drop_rate=drop_rate))
                in_channels = out_channels
                b += 1  # advance block index so drop_rate actually increases with depth
return nn.Sequential(*layers)

def forward(self, x):
out = swish(self.bn1(self.conv1(x)))
out = self.drop1(out)
out = self.layers(out)
out = F.adaptive_avg_pool2d(out, 1)
out = out.view(out.size(0), -1)
dropout_rate = self.cfg['dropout_rate']
if self.training and dropout_rate > 0:
out = F.dropout(out, p=dropout_rate)
out = self.linear(out)
return out


def EfficientNetB0():
cfg = {
'num_blocks': [1, 2, 2, 3, 3, 4, 1],
'expansion': [1, 6, 6, 6, 6, 6, 6],
'out_channels': [16, 24, 40, 80, 112, 192, 320],
'kernel_size': [3, 3, 5, 3, 5, 5, 3],
'stride': [1, 2, 2, 2, 1, 2, 1],
'dropout_rate': 0.2,
'drop_connect_rate': 0.2,
}
return EfficientNet(cfg)


def test():
net = EfficientNetB0()
x = torch.randn(2, 3, 32, 32)
y = net(x)
print(y.shape)
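

# A quick smoke-test entry point; assumes CIFAR-sized 32x32 RGB inputs and the default 10 classes.
if __name__ == '__main__':
    test()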