forked from ai-dawang/PlugNPlay-Modules
-
Notifications
You must be signed in to change notification settings - Fork 0
/
(CVPR 2024)ASSA.py
124 lines (107 loc) · 5.58 KB
/
(CVPR 2024)ASSA.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_
from einops import repeat
# Paper: Adapt or Perish: Adaptive Sparse Transformer with Attentive Feature Refinement for Image Restoration, CVPR 2024.
# Paper PDF: https://openaccess.thecvf.com/content/CVPR2024/papers/Zhou_Adapt_or_Perish_Adaptive_Sparse_Transformer_with_Attentive_Feature_Refinement_CVPR_2024_paper.pdf
# Collection of 100+ plug-and-play modules (GitHub): https://github.com/ai-dawang/PlugNPlay-Modules
class LinearProjection(nn.Module):
    """Linear query / key / value projection for multi-head attention.

    Queries are always projected from ``x``.  Keys and values are projected
    from ``attn_kv`` when it is given, otherwise from ``x`` itself.

    Args:
        dim: Channel size of the input tokens (last dimension of ``x``).
        heads: Number of attention heads.
        dim_head: Channels per head; the projected width is ``heads * dim_head``.
        bias: Whether the projection ``nn.Linear`` layers use a bias term.
    """

    def __init__(self, dim, heads=8, dim_head=64, bias=True):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.to_q = nn.Linear(dim, inner_dim, bias=bias)
        # K and V share one Linear; its output is split into two halves in forward().
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=bias)
        self.dim = dim
        self.inner_dim = inner_dim

    def forward(self, x, attn_kv=None):
        """Project tokens into per-head queries, keys and values.

        Args:
            x: Tensor of shape ``(B_, N, dim)``.
            attn_kv: Optional tensor of key/value tokens, shared by every batch
                item; it is unsqueezed and repeated ``B_`` times, so it is
                expected to be unbatched, presumably ``(N_kv, dim)``.  When
                ``None``, keys and values are projected from ``x``.

        Returns:
            Tuple ``(q, k, v)``: ``q`` has shape ``(B_, heads, N, dim_head)``;
            ``k`` and ``v`` have shape ``(B_, heads, N_kv, dim_head)``.
        """
        B_, N, _ = x.shape
        if attn_kv is not None:
            # Broadcast the shared key/value tokens to every item in the batch.
            attn_kv = attn_kv.unsqueeze(0).repeat(B_, 1, 1)
        else:
            attn_kv = x
        N_kv = attn_kv.size(1)
        # Split heads using the projected width (inner_dim), not the input width C:
        # the two only coincide when heads * dim_head == dim.
        dim_head = self.inner_dim // self.heads
        q = self.to_q(x).reshape(B_, N, self.heads, dim_head).permute(0, 2, 1, 3)
        kv = self.to_kv(attn_kv).reshape(B_, N_kv, 2, self.heads, dim_head)
        kv = kv.permute(2, 0, 3, 1, 4)  # 2, B_, heads, N_kv, dim_head
        k, v = kv[0], kv[1]
        return q, k, v
# Adaptive Sparse Self-Attention (ASSA)
class WindowAttention_sparse(nn.Module):
    """Adaptive sparse window self-attention.

    Attention scores get a learned relative position bias and are then
    normalized two ways: a softmax branch and a squared-ReLU branch (zero for
    non-positive scores).  The two maps are blended with weights given by a
    softmax over the learnable 2-element parameter ``self.w``.

    Args:
        dim: Channel size of the input tokens.
        win_size: ``(Wh, Ww)`` window size; the position bias covers
            ``Wh * Ww`` query tokens.
        num_heads: Number of attention heads.
        token_projection: Projection type; only ``'linear'`` is supported.
        qkv_bias: Whether the Q/K/V projections use a bias term.
        qk_scale: Query scale; a falsy value (``None`` or ``0``) falls back
            to ``head_dim ** -0.5``.
        attn_drop: Dropout probability applied to the blended attention map.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``token_projection`` is not ``'linear'``.
    """

    def __init__(self, dim, win_size, num_heads=8, token_projection='linear', qkv_bias=True, qk_scale=None,
                 attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.win_size = win_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # Learnable bias table: one row per possible (dy, dx) offset, one column per head.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # Pair-wise relative position index for each token inside the window.
        coords_h = torch.arange(self.win_size[0])  # [0,...,Wh-1]
        coords_w = torch.arange(self.win_size[1])  # [0,...,Ww-1]
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.win_size[0] - 1  # shift offsets to start from 0
        relative_coords[:, :, 1] += self.win_size[1] - 1
        # Row-major flattening of the (dy, dx) pair into one bias-table row index.
        relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)
        trunc_normal_(self.relative_position_bias_table, std=.02)

        if token_projection != 'linear':
            # ValueError subclasses Exception, so existing `except Exception` handlers still match.
            raise ValueError("Projection error!")
        self.qkv = LinearProjection(dim, num_heads, dim // num_heads, bias=qkv_bias)
        self.token_projection = token_projection

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)
        self.relu = nn.ReLU()
        # Logits for the two branch weights; softmax-normalized in forward().
        self.w = nn.Parameter(torch.ones(2))

    def forward(self, x, attn_kv=None, mask=None):
        """Apply adaptive sparse attention to a batch of windows.

        Args:
            x: Tensor of shape ``(B_, N, dim)``; the relative position bias is
                sized for ``N == Wh * Ww``.
            attn_kv: Optional key/value tokens, forwarded unchanged to ``self.qkv``.
            mask: Optional additive mask of shape ``(nW, N, N)``; ``B_`` must be
                a multiple of ``nW``.

        Returns:
            Tensor of shape ``(B_, N, dim)``.
        """
        B_, N, C = x.shape
        q, k, v = self.qkv(x, attn_kv)
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)  # B_, nH, N, N_kv

        num_tokens = self.win_size[0] * self.win_size[1]
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(num_tokens, num_tokens, -1)  # Wh*Ww, Wh*Ww, nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        # When there are more keys than window tokens, repeat each bias column
        # `ratio` times so the bias matches the wider key axis.
        ratio = attn.size(-1) // relative_position_bias.size(-1)
        relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d=ratio)
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            mask = repeat(mask, 'nW m n -> nW m (n d)', d=ratio)
            # Group the batch by window so each window's mask is added to all of its heads.
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N * ratio) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N * ratio)

        # Two normalizations of the same scores: softmax and squared ReLU.
        attn0 = self.softmax(attn)
        attn1 = self.relu(attn) ** 2
        # Numerically stable softmax over the two branch logits
        # (replaces the hand-written exp(w_i) / sum(exp(w))).
        weights = torch.softmax(self.w, dim=0)
        attn = attn0 * weights[0] + attn1 * weights[1]
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
if __name__ == '__main__':
    # Smoke test: build one ASSA module and run a single forward pass.
    dim = 64  # channel size of the input tokens
    win_size = (64, 64)  # window size (Wh, Ww); the input must hold Wh * Ww tokens
    window_attention_sparse = WindowAttention_sparse(dim, win_size)

    # Tokens are laid out as (batch, Wh * Ww, channels), not (B, H, W).
    x = torch.randn(1, win_size[0] * win_size[1], dim)

    # Inference only: skip building an autograd graph over the
    # (1, num_heads, 4096, 4096) attention maps.
    with torch.no_grad():
        output = window_attention_sparse(x)

    print(x.size())
    print(output.size())