BFAM.py
# Paper: B2CNet: A Progressive Change Boundary-to-Center Refinement Network for Multitemporal Remote Sensing Images Change Detection
# Paper link: https://ieeexplore.ieee.org/document/10547405
import torch
import torch.nn as nn
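# This file implements the Bitemporal Feature Aggregation Module (BFAM) from
# B2CNet, together with the parameter-free SimAM attention it builds on.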
# SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks (ICML 2021)
class simam_module(torch.nn.Module):
    def __init__(self, e_lambda=1e-4):
        super(simam_module, self).__init__()
        self.activation = nn.Sigmoid()
        self.e_lambda = e_lambda

    def forward(self, x):
        b, c, h, w = x.size()
        n = w * h - 1
        # Squared deviation of each activation from its channel's spatial mean.
        x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        # Inverse of the SimAM energy function; e_lambda stabilizes the denominator.
        y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5
        # Sigmoid gating: reweight each activation by its estimated importance.
        return x * self.activation(y)
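# A minimal usage sketch (illustrative, not from the paper's code): SimAM has
# no learnable parameters and preserves shape, so it can be dropped onto any
# 4-D feature map.
#
#   attn = simam_module()
#   feat = torch.rand(1, 64, 16, 16)
#   assert attn(feat).shape == feat.shape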
# Bitemporal Feature Aggregation Module (BFAM)
class BFAM(nn.Module):
    def __init__(self, inp, out):
        super(BFAM, self).__init__()
        self.pre_siam = simam_module()
        self.lat_siam = simam_module()

        out_1 = int(inp / 2)
        # Four parallel grouped 3x3 convolutions with dilation rates 1-4,
        # capturing multi-scale context from the concatenated bitemporal features.
        self.conv_1 = nn.Conv2d(inp, out_1, padding=1, kernel_size=3, groups=out_1, dilation=1)
        self.conv_2 = nn.Conv2d(inp, out_1, padding=2, kernel_size=3, groups=out_1, dilation=2)
        self.conv_3 = nn.Conv2d(inp, out_1, padding=3, kernel_size=3, groups=out_1, dilation=3)
        self.conv_4 = nn.Conv2d(inp, out_1, padding=4, kernel_size=3, groups=out_1, dilation=4)

        # Fuse the four dilated branches back down to out_1 channels.
        self.fuse = nn.Sequential(
            nn.Conv2d(out_1 * 4, out_1, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_1),
            nn.ReLU(inplace=True)
        )
        self.fuse_siam = simam_module()

        self.out = nn.Sequential(
            nn.Conv2d(out_1, out, kernel_size=3, padding=1),
            nn.BatchNorm2d(out),
            nn.ReLU(inplace=True)
        )
    def forward(self, inp1, inp2, last_feature=None):
        # Concatenate the two temporal features along the channel dimension.
        x = torch.cat([inp1, inp2], dim=1)
        c1 = self.conv_1(x)
        c2 = self.conv_2(x)
        c3 = self.conv_3(x)
        c4 = self.conv_4(x)
        cat = torch.cat([c1, c2, c3, c4], dim=1)
        fuse = self.fuse(cat)

        # Attention-weight each temporal branch, then modulate by the fused map.
        inp1_siam = self.pre_siam(inp1)
        inp2_siam = self.lat_siam(inp2)
        inp1_mul = torch.mul(inp1_siam, fuse)
        inp2_mul = torch.mul(inp2_siam, fuse)
        fuse = self.fuse_siam(fuse)

        # Residual aggregation; last_feature is an optional extra skip input.
        if last_feature is None:
            out = self.out(fuse + inp1 + inp2 + inp1_mul + inp2_mul)
        else:
            out = self.out(fuse + inp1 + inp2 + inp1_mul + inp2_mul + last_feature)
        out = self.fuse_siam(out)
        return out
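# Channel contract (inferred from the code above, not stated in the original):
#   - inp1 and inp2 must each carry inp // 2 channels, so their concatenation
#     has inp channels.
#   - inp must be even, so that groups = inp // 2 divides both the input (inp)
#     and output (inp // 2) channels of the grouped convolutions.
#   - last_feature, when provided, must match the fused map: inp // 2 channels
#     at the same spatial resolution.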
if __name__ == '__main__':
    block = BFAM(inp=128, out=256)
    inp1 = torch.rand(1, 128 // 2, 16, 16)          # B C H W
    inp2 = torch.rand(1, 128 // 2, 16, 16)          # B C H W
    last_feature = torch.rand(1, 128 // 2, 16, 16)  # B C H W
    # Run the BFAM module; last_feature is optional and may be None.
    output = block(inp1, inp2, last_feature)
    # Print input and output shapes.
    print(inp1.size())
    print(inp2.size())
    print(output.size())
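    # A small extra check (illustrative, not in the original script): the
    # last_feature=None branch should produce the same output shape.
    output_no_last = block(inp1, inp2)
    print(output_no_last.size())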