Commit 8fe02e7
testing out an attention-based upsampling that https://arxiv.org/abs/… (#125)

* attention-based upsampling https://arxiv.org/abs/2112.11435 - better results than either bilinear upsample or convtranspose2d

* 0.23.0
lucidrains authored Apr 26, 2022
1 parent 72d6a98 commit 8fe02e7
Showing 3 changed files with 120 additions and 5 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -311,4 +311,14 @@ If you want the current state of the art GAN, you can find it at https://github.
}
```

```bibtex
@article{Arar2021LearnedQF,
title = {Learned Queries for Efficient Local Attention},
author = {Moab Arar and Ariel Shamir and Amit H. Bermano},
journal = {ArXiv},
year = {2021},
volume = {abs/2112.11435}
}
```

*What I cannot create, I do not understand* - Richard Feynman
113 changes: 109 additions & 4 deletions lightweight_gan/lightweight_gan.py
@@ -242,6 +242,111 @@ def Conv2dSame(dim_in, dim_out, kernel_size, bias = True):
nn.Conv2d(dim_in, dim_out, kernel_size, bias = bias)
)

# attention-based upsampling
# from https://arxiv.org/abs/2112.11435

class QueryAndAttend(nn.Module):
def __init__(
self,
*,
dim,
num_queries = 1,
dim_head = 32,
heads = 8,
window_size = 3
):
super().__init__()
self.scale = dim_head ** -0.5
inner_dim = dim_head * heads
self.heads = heads
self.dim_head = dim_head
self.window_size = window_size
self.num_queries = num_queries

self.rel_pos_bias = nn.Parameter(torch.randn(heads, num_queries, window_size * window_size, 1, 1))

self.queries = nn.Parameter(torch.randn(heads, num_queries, dim_head))
self.to_kv = nn.Conv2d(dim, dim_head * 2, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)

def forward(self, x):
"""
einstein notation
b - batch
h - heads
l - num queries
d - head dimension
x - height
y - width
j - source sequence for attending to (kernel size squared in this case)
"""

wsz, heads, dim_head, num_queries = self.window_size, self.heads, self.dim_head, self.num_queries
batch, _, height, width = x.shape

is_one_query = self.num_queries == 1

# queries, keys, values

q = self.queries * self.scale
k, v = self.to_kv(x).chunk(2, dim = 1)

# similarities

sim = einsum('h l d, b d x y -> b h l x y', q, k)
sim = rearrange(sim, 'b ... x y -> b (...) x y')

# unfold the similarity scores, with float(-inf) as padding value

mask_value = -torch.finfo(sim.dtype).max
sim = F.pad(sim, ((wsz // 2,) * 4), value = mask_value)
sim = F.unfold(sim, kernel_size = wsz)
sim = rearrange(sim, 'b (h l j) (x y) -> b h l j x y', h = heads, l = num_queries, x = height, y = width)

# rel pos bias

sim = sim + self.rel_pos_bias

# numerically stable attention

sim = sim - sim.amax(dim = -3, keepdim = True).detach()
attn = sim.softmax(dim = -3)

# unfold values

v = F.pad(v, ((wsz // 2,) * 4), value = 0.)
v = F.unfold(v, kernel_size = wsz)
v = rearrange(v, 'b (d j) (x y) -> b d j x y', d = dim_head, x = height, y = width)

# aggregate values

out = einsum('b h l j x y, b d j x y -> b l h d x y', attn, v)

# combine heads

out = rearrange(out, 'b l h d x y -> (b l) (h d) x y')
out = self.to_out(out)
out = rearrange(out, '(b l) d x y -> b l d x y', b = batch)

# return original input if one query

if is_one_query:
out = rearrange(out, 'b 1 ... -> b ...')

return out
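
For orientation, here is a minimal, hypothetical shape check for `QueryAndAttend` as defined above (it assumes the class is importable from `lightweight_gan.lightweight_gan` once this commit is installed; the input sizes are arbitrary):

```python
# hypothetical shape check, not part of the commit
import torch
from lightweight_gan.lightweight_gan import QueryAndAttend

x = torch.randn(2, 64, 16, 16)                      # (batch, dim, height, width)

qna = QueryAndAttend(dim = 64, num_queries = 4)     # 4 learned queries per head, 3x3 local window
out = qna(x)
print(out.shape)                                    # torch.Size([2, 4, 64, 16, 16]) - one output map per query

single = QueryAndAttend(dim = 64, num_queries = 1)  # a single query squeezes out the query dimension
print(single(x).shape)                              # torch.Size([2, 64, 16, 16])
```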

class QueryAttnUpsample(nn.Module):
def __init__(self, dim, **kwargs):
super().__init__()
self.norm = ChanNorm(dim)
self.qna = QueryAndAttend(dim = dim, num_queries = 4, **kwargs)

def forward(self, x):
x = self.norm(x)
out = self.qna(x)
out = rearrange(out, 'b (w1 w2) c h w -> b c (h w1) (w w2)', w1 = 2, w2 = 2)
return out
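
Similarly, a small hypothetical usage sketch for `QueryAttnUpsample` (same import assumption as above): the four query outputs are rearranged into a 2x2 block per spatial position, pixel-shuffle style, so the module doubles height and width while keeping the channel count:

```python
# hypothetical usage sketch, not part of the commit
import torch
from lightweight_gan.lightweight_gan import QueryAttnUpsample

x = torch.randn(2, 64, 16, 16)
up = QueryAttnUpsample(dim = 64)

out = up(x)
print(out.shape)   # torch.Size([2, 64, 32, 32]) - 2x spatial upsample, channels unchanged
```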

# attention

class DepthWiseConv2d(nn.Module):
@@ -430,8 +535,8 @@ def forward(self, images, prob = 0., types = [], detach = False, **kwargs):

norm_class = nn.BatchNorm2d

def upsample(scale_factor = 2):
return nn.Upsample(scale_factor = scale_factor)
def upsample(dim):
return QueryAttnUpsample(dim = dim)

# squeeze excitation classes

@@ -588,7 +693,7 @@ def __init__(

layer = nn.ModuleList([
nn.Sequential(
upsample(),
upsample(chan_in),
Blur(),
Conv2dSame(chan_in, chan_out * 2, 4),
Noise(),
@@ -644,7 +749,7 @@ def __init__(
last_layer = ind == (num_upsamples - 1)
chan_out = chans if not last_layer else final_chan * 2
layer = nn.Sequential(
upsample(),
upsample(chans),
nn.Conv2d(chans, chan_out, 3, padding = 1),
nn.GLU(dim = 1)
)
2 changes: 1 addition & 1 deletion lightweight_gan/version.py
@@ -1 +1 @@
__version__ = '0.22.3'
__version__ = '0.23.0'
