Skip to content

Commit

Permalink
fix for what may be an important detail left out of the QnA paper
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Apr 26, 2022
1 parent 8fe02e7 commit f598cb7
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
3 changes: 3 additions & 0 deletions lightweight_gan/lightweight_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def __init__(
self.num_queries = num_queries

self.rel_pos_bias = nn.Parameter(torch.randn(heads, num_queries, window_size * window_size, 1, 1))
self.attn_weight = nn.Parameter(torch.ones(heads, num_queries, window_size * window_size, 1, 1))

self.queries = nn.Parameter(torch.randn(heads, num_queries, dim_head))
self.to_kv = nn.Conv2d(dim, dim_head * 2, 1, bias = False)
Expand Down Expand Up @@ -312,6 +313,8 @@ def forward(self, x):
sim = sim - sim.amax(dim = -3, keepdim = True).detach()
attn = sim.softmax(dim = -3)

attn = attn * self.attn_weight

# unfold values

v = F.pad(v, ((wsz // 2,) * 4), value = 0.)
Expand Down
2 changes: 1 addition & 1 deletion lightweight_gan/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.23.0'
__version__ = '0.23.1'

0 comments on commit f598cb7

Please sign in to comment.