From 7cebbbeb9a685ae55d8636a31864e46c9045159a Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Thu, 25 Aug 2022 10:42:20 -0700
Subject: [PATCH] upgrade to best downsample type

---
 README.md                          | 10 ++++++++++
 lightweight_gan/lightweight_gan.py | 22 ++++++++++++++++------
 lightweight_gan/version.py         |  2 +-
 3 files changed, 27 insertions(+), 7 deletions(-)

diff --git a/README.md b/README.md
index 0381450..530e878 100644
--- a/README.md
+++ b/README.md
@@ -311,4 +311,14 @@ If you want the current state of the art GAN, you can find it at https://github.
 }
 ```
 
+```bibtex
+@article{Sunkara2022NoMS,
+    title   = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
+    author  = {Raja Sunkara and Tie Luo},
+    journal = {ArXiv},
+    year    = {2022},
+    volume  = {abs/2208.03641}
+}
+```
+
 *What I cannot create, I do not understand* - Richard Feynman
diff --git a/lightweight_gan/lightweight_gan.py b/lightweight_gan/lightweight_gan.py
index d14fc4f..23f93b7 100644
--- a/lightweight_gan/lightweight_gan.py
+++ b/lightweight_gan/lightweight_gan.py
@@ -29,6 +29,7 @@
 from tqdm import tqdm
 
 from einops import rearrange, reduce, repeat
+from einops.layers.torch import Rearrange
 
 from adabelief_pytorch import AdaBelief
 
@@ -456,6 +457,15 @@ def init_conv_(self, conv):
     def forward(self, x):
         return self.net(x)
 
+def SPConvDownsample(dim, dim_out = None):
+    # https://arxiv.org/abs/2208.03641 shows this is the optimal way to downsample
+    # named SP-conv in the paper, but basically a pixel unshuffle
+    dim_out = default(dim_out, dim)
+    return nn.Sequential(
+        Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),
+        nn.Conv2d(dim * 4, dim_out, 1)
+    )
+
 # squeeze excitation classes
 
 # global context network
@@ -611,9 +621,9 @@ def __init__(
 
             layer = nn.ModuleList([
                 nn.Sequential(
-                    PixelShuffleUpsample(chan_in, chan_out),
+                    PixelShuffleUpsample(chan_in),
                     Blur(),
-                    Conv2dSame(chan_out, chan_out * 2, 4),
+                    Conv2dSame(chan_in, chan_out * 2, 4),
                     Noise(),
                     norm_class(chan_out * 2),
                     nn.GLU(dim = 1)
@@ -667,8 +677,8 @@ def __init__(
             last_layer = ind == (num_upsamples - 1)
             chan_out = chans if not last_layer else final_chan * 2
             layer = nn.Sequential(
-                PixelShuffleUpsample(chans, chan_out),
-                nn.Conv2d(chan_out, chan_out, 3, padding = 1),
+                PixelShuffleUpsample(chans),
+                nn.Conv2d(chans, chan_out, 3, padding = 1),
                 nn.GLU(dim = 1)
             )
             self.layers.append(layer)
@@ -743,7 +753,7 @@ def __init__(
                 SumBranches([
                     nn.Sequential(
                         Blur(),
-                        nn.Conv2d(chan_in, chan_out, 4, stride = 2, padding = 1),
+                        SPConvDownsample(chan_in, chan_out),
                         nn.LeakyReLU(0.1),
                         nn.Conv2d(chan_out, chan_out, 3, padding = 1),
                         nn.LeakyReLU(0.1)
@@ -779,7 +789,7 @@ def __init__(
             SumBranches([
                 nn.Sequential(
                     Blur(),
-                    nn.Conv2d(64, 32, 4, stride = 2, padding = 1),
+                    SPConvDownsample(64, 32),
                     nn.LeakyReLU(0.1),
                     nn.Conv2d(32, 32, 3, padding = 1),
                     nn.LeakyReLU(0.1)
diff --git a/lightweight_gan/version.py b/lightweight_gan/version.py
index 1f356cc..1a72d32 100644
--- a/lightweight_gan/version.py
+++ b/lightweight_gan/version.py
@@ -1 +1 @@
-__version__ = '1.0.0'
+__version__ = '1.1.0'
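
For reference, a minimal standalone sketch of the SP-conv downsample this patch introduces. The function name `sp_conv_downsample` and the example shapes are illustrative only; the sketch assumes `torch` and `einops` are installed and inlines the behavior of the repo's `default` helper.

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

def sp_conv_downsample(dim, dim_out = None):
    # fold each 2x2 spatial block into the channel dimension (a pixel unshuffle),
    # then mix the 4x channels down to dim_out with a 1x1 convolution
    dim_out = dim_out if dim_out is not None else dim
    return nn.Sequential(
        Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),
        nn.Conv2d(dim * 4, dim_out, 1)
    )

x = torch.randn(1, 64, 32, 32)    # (batch, channels, height, width)
down = sp_conv_downsample(64, 32)
print(down(x).shape)              # torch.Size([1, 32, 16, 16])
```

Unlike the strided 4x4 convolution it replaces in the discriminator, the rearrange step discards no pixels; every input value is preserved in the channel dimension, which is the property Sunkara and Luo argue matters for low-resolution images and small objects.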