GHPA.py
import torch
from torch import nn
import torch.nn.functional as F


# This module can be stitched into the up-sampling and down-sampling stages of a network.
class LayerNorm(nn.Module):
    """ From ConvNeXt (https://arxiv.org/pdf/2201.03545.pdf) """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            # Manual layer norm over the channel dimension for (B, C, H, W) tensors.
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x
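

# A minimal sanity-check sketch (not part of the original file): the channels_first
# branch above should agree with F.layer_norm applied after moving channels to the
# last axis. The tensor shape and tolerance here are illustrative assumptions; call
# this helper manually if you want to verify the equivalence.
def _check_layernorm_equivalence():
    ln = LayerNorm(16, data_format="channels_first")
    t = torch.randn(2, 16, 8, 8)
    manual = ln(t)
    reference = F.layer_norm(t.permute(0, 2, 3, 1), ln.normalized_shape,
                             ln.weight, ln.bias, ln.eps).permute(0, 3, 1, 2)
    assert torch.allclose(manual, reference, atol=1e-5)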
class Grouped_multi_axis_Hadamard_Product_Attention(nn.Module):
    def __init__(self, dim_in, dim_out, x=8, y=8):
        super().__init__()
        c_dim_in = dim_in // 4
        k_size = 3
        pad = (k_size - 1) // 2

        # Learnable attention map for the H-W (xy) plane, refined by a depthwise 2D conv.
        self.params_xy = nn.Parameter(torch.Tensor(1, c_dim_in, x, y), requires_grad=True)
        nn.init.ones_(self.params_xy)
        self.conv_xy = nn.Sequential(
            nn.Conv2d(c_dim_in, c_dim_in, kernel_size=k_size, padding=pad, groups=c_dim_in),
            nn.GELU(),
            nn.Conv2d(c_dim_in, c_dim_in, 1))

        # Learnable attention map for the C-H (zx) plane, refined by a depthwise 1D conv.
        self.params_zx = nn.Parameter(torch.Tensor(1, 1, c_dim_in, x), requires_grad=True)
        nn.init.ones_(self.params_zx)
        self.conv_zx = nn.Sequential(
            nn.Conv1d(c_dim_in, c_dim_in, kernel_size=k_size, padding=pad, groups=c_dim_in),
            nn.GELU(),
            nn.Conv1d(c_dim_in, c_dim_in, 1))

        # Learnable attention map for the C-W (zy) plane, refined by a depthwise 1D conv.
        self.params_zy = nn.Parameter(torch.Tensor(1, 1, c_dim_in, y), requires_grad=True)
        nn.init.ones_(self.params_zy)
        self.conv_zy = nn.Sequential(
            nn.Conv1d(c_dim_in, c_dim_in, kernel_size=k_size, padding=pad, groups=c_dim_in),
            nn.GELU(),
            nn.Conv1d(c_dim_in, c_dim_in, 1))

        # Plain depthwise branch for the fourth channel group.
        self.dw = nn.Sequential(
            nn.Conv2d(c_dim_in, c_dim_in, 1),
            nn.GELU(),
            nn.Conv2d(c_dim_in, c_dim_in, kernel_size=3, padding=1, groups=c_dim_in)
        )

        self.norm1 = LayerNorm(dim_in, eps=1e-6, data_format='channels_first')
        self.norm2 = LayerNorm(dim_in, eps=1e-6, data_format='channels_first')

        # Final depthwise + pointwise projection to dim_out.
        self.ldw = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size=3, padding=1, groups=dim_in),
            nn.GELU(),
            nn.Conv2d(dim_in, dim_out, 1),
        )
    def forward(self, x):
        x = self.norm1(x)
        # Split the channels into four groups, one per branch.
        x1, x2, x3, x4 = torch.chunk(x, 4, dim=1)
        B, C, H, W = x1.size()

        # ----------xy: Hadamard product with the H-W attention map----------#
        params_xy = self.params_xy
        x1 = x1 * self.conv_xy(F.interpolate(params_xy, size=x1.shape[2:4], mode='bilinear', align_corners=True))

        # ----------zx: Hadamard product with the C-H attention map----------#
        x2 = x2.permute(0, 3, 1, 2)
        params_zx = self.params_zx
        x2 = x2 * self.conv_zx(
            F.interpolate(params_zx, size=x2.shape[2:4], mode='bilinear', align_corners=True).squeeze(0)).unsqueeze(0)
        x2 = x2.permute(0, 2, 3, 1)

        # ----------zy: Hadamard product with the C-W attention map----------#
        x3 = x3.permute(0, 2, 1, 3)
        params_zy = self.params_zy
        x3 = x3 * self.conv_zy(
            F.interpolate(params_zy, size=x3.shape[2:4], mode='bilinear', align_corners=True).squeeze(0)).unsqueeze(0)
        x3 = x3.permute(0, 2, 1, 3)

        # ----------dw: depthwise branch----------#
        x4 = self.dw(x4)

        # ----------concat the four groups back together----------#
        x = torch.cat([x1, x2, x3, x4], dim=1)

        # ----------ldw: normalize and project to dim_out----------#
        x = self.norm2(x)
        x = self.ldw(x)
        return x
if __name__ == '__main__':
    block = Grouped_multi_axis_Hadamard_Product_Attention(dim_in=64, dim_out=128)
    input = torch.randn(1, 64, 32, 32)  # B C H W
    print(input.size())
    output = block(input)
    print(output.size())
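
    # A minimal usage sketch (an assumption, not part of the original file): the comment
    # at the top of the file suggests GHPA can be stitched into the down-/up-sampling
    # stages of a network. One plausible placement is right after a strided-conv
    # downsample; the channel sizes and stride below are illustrative.
    downsample_stage = nn.Sequential(
        nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),  # 32x32 -> 16x16
        Grouped_multi_axis_Hadamard_Product_Attention(dim_in=64, dim_out=128),
    )
    feat = downsample_stage(input)
    print(feat.size())  # expected: torch.Size([1, 128, 16, 16])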