-
Notifications
You must be signed in to change notification settings - Fork 236
/
AFPN.py
422 lines (322 loc) · 15.9 KB
/
AFPN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
# Paper: AFPN: Asymptotic Feature Pyramid Network for Object Detection
# Paper URL: https://arxiv.org/pdf/2306.15988
# GitHub repository: https://github.com/gyyang23/AFPN
def BasicConv(filter_in, filter_out, kernel_size, stride=1, pad=None):
    """Build a Conv2d -> BatchNorm2d -> ReLU stack as a named ``nn.Sequential``.

    Args:
        filter_in: Number of input channels of the convolution.
        filter_out: Number of output channels of the convolution.
        kernel_size: Integer kernel size of the convolution.
        stride: Convolution stride.
        pad: Explicit padding. ``None`` (the default) selects
            ``(kernel_size - 1) // 2``; any other value, including ``0``,
            is passed to the convolution unchanged.

    Returns:
        An ``nn.Sequential`` whose sub-modules are named ``conv``, ``bn``
        and ``relu``. The convolution is created with ``bias=False``.
    """
    # Test against None instead of truthiness: an explicit ``pad=0`` must not
    # be mistaken for "unspecified" and replaced by the derived padding.
    if pad is None:
        pad = (kernel_size - 1) // 2 if kernel_size else 0
    return nn.Sequential(OrderedDict([
        ("conv", nn.Conv2d(filter_in, filter_out, kernel_size=kernel_size,
                           stride=stride, padding=pad, bias=False)),
        ("bn", nn.BatchNorm2d(filter_out)),
        ("relu", nn.ReLU(inplace=True)),
    ]))
class BasicBlock(nn.Module):
    """Residual unit: two 3x3 conv + batch-norm layers around an identity skip.

    The block input is added in place to the second batch-norm output before
    the final ReLU. No projection is applied on the skip path, so the input
    has to be addable to a ``filter_out``-channel tensor (in practice
    ``filter_in == filter_out``).
    """

    expansion = 1

    def __init__(self, filter_in, filter_out):
        super().__init__()
        # First conv/bn pair maps filter_in -> filter_out channels.
        self.conv1 = nn.Conv2d(filter_in, filter_out, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(filter_out, momentum=0.1)
        # One in-place ReLU module is shared by both activation sites.
        self.relu = nn.ReLU(inplace=True)
        # Second conv/bn pair keeps the channel count.
        self.conv2 = nn.Conv2d(filter_out, filter_out, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(filter_out, momentum=0.1)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        hidden = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(hidden))
        out += x  # identity shortcut
        return self.relu(out)
class Upsample(nn.Module):
    """Project channels with a 1x1 ``BasicConv``, then upsample bilinearly.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the 1x1 projection.
        scale_factor: Factor handed to ``nn.Upsample``.
    """

    def __init__(self, in_channels, out_channels, scale_factor=2):
        super().__init__()
        project = BasicConv(in_channels, out_channels, 1)
        resize = nn.Upsample(scale_factor=scale_factor, mode='bilinear')
        # NOTE: the original file carried a disabled alternative that swaps
        # nn.Upsample for mmcv.ops.CARAFEPack(out_channels,
        # scale_factor=scale_factor); it is not used here.
        self.upsample = nn.Sequential(project, resize)

    def forward(self, x):
        """Apply the projection followed by the resize."""
        return self.upsample(x)
class Downsample_x2(nn.Module):
    """Downsampling stage built from one ``BasicConv`` with kernel 2, stride 2.

    ``BasicConv`` is called with a ``pad`` argument of 0.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        strided_conv = BasicConv(in_channels, out_channels, 2, 2, 0)
        # Kept inside a Sequential so parameter names stay "downsample.0.*".
        self.downsample = nn.Sequential(strided_conv)

    def forward(self, x):
        """Run ``x`` through the strided convolution stack."""
        return self.downsample(x)
class Downsample_x4(nn.Module):
    """Downsampling stage built from one ``BasicConv`` with kernel 4, stride 4.

    ``BasicConv`` is called with a ``pad`` argument of 0.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        strided_conv = BasicConv(in_channels, out_channels, 4, 4, 0)
        # Kept inside a Sequential so parameter names stay "downsample.0.*".
        self.downsample = nn.Sequential(strided_conv)

    def forward(self, x):
        """Run ``x`` through the strided convolution stack."""
        return self.downsample(x)
class Downsample_x8(nn.Module):
    """Downsampling stage built from one ``BasicConv`` with kernel 8, stride 8.

    ``BasicConv`` is called with a ``pad`` argument of 0.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        strided_conv = BasicConv(in_channels, out_channels, 8, 8, 0)
        # Kept inside a Sequential so parameter names stay "downsample.0.*".
        self.downsample = nn.Sequential(strided_conv)

    def forward(self, x):
        """Run ``x`` through the strided convolution stack."""
        return self.downsample(x)
class ASFF_2(nn.Module):
    """Fuse two feature maps with learned, per-position softmax weights.

    Each input is compressed to 8 channels by a 1x1 ``BasicConv``; the
    concatenated result is mapped to two weight maps, normalised with a
    softmax over the level axis, and used to blend the inputs. A 3x3
    ``BasicConv`` refines the blended map.

    Args:
        inter_dim: Channel count expected of both inputs and of the output.
    """

    def __init__(self, inter_dim=512):
        super().__init__()
        self.inter_dim = inter_dim
        compress_c = 8

        self.weight_level_1 = BasicConv(inter_dim, compress_c, 1, 1)
        self.weight_level_2 = BasicConv(inter_dim, compress_c, 1, 1)
        self.weight_levels = nn.Conv2d(compress_c * 2, 2, kernel_size=1, stride=1, padding=0)
        self.conv = BasicConv(inter_dim, inter_dim, 3, 1)

    def forward(self, input1, input2):
        """Return the refined weighted sum of ``input1`` and ``input2``."""
        compressed = torch.cat(
            (self.weight_level_1(input1), self.weight_level_2(input2)), 1
        )
        weights = F.softmax(self.weight_levels(compressed), dim=1)
        # One single-channel weight map per input level.
        w1, w2 = weights.split(1, dim=1)
        fused = input1 * w1 + input2 * w2
        return self.conv(fused)
class ASFF_3(nn.Module):
    """Fuse three feature maps with learned, per-position softmax weights.

    Each input is compressed to 8 channels by a 1x1 ``BasicConv``; the
    concatenated result is mapped to three weight maps, normalised with a
    softmax over the level axis, and used to blend the inputs. A 3x3
    ``BasicConv`` refines the blended map.

    Args:
        inter_dim: Channel count expected of all inputs and of the output.
    """

    def __init__(self, inter_dim=512):
        super().__init__()
        self.inter_dim = inter_dim
        compress_c = 8

        self.weight_level_1 = BasicConv(inter_dim, compress_c, 1, 1)
        self.weight_level_2 = BasicConv(inter_dim, compress_c, 1, 1)
        self.weight_level_3 = BasicConv(inter_dim, compress_c, 1, 1)
        self.weight_levels = nn.Conv2d(compress_c * 3, 3, kernel_size=1, stride=1, padding=0)
        self.conv = BasicConv(inter_dim, inter_dim, 3, 1)

    def forward(self, input1, input2, input3):
        """Return the refined weighted sum of the three inputs."""
        compressed = torch.cat(
            (
                self.weight_level_1(input1),
                self.weight_level_2(input2),
                self.weight_level_3(input3),
            ),
            1,
        )
        weights = F.softmax(self.weight_levels(compressed), dim=1)
        # One single-channel weight map per input level.
        w1, w2, w3 = weights.split(1, dim=1)
        fused = input1 * w1 + input2 * w2 + input3 * w3
        return self.conv(fused)
class ASFF_4(nn.Module):
    """Fuse four feature maps with learned, per-position softmax weights.

    Each input is compressed to 8 channels by a 1x1 ``BasicConv``; the
    concatenated result is mapped to four weight maps, normalised with a
    softmax over the level axis, and used to blend the inputs. A 3x3
    ``BasicConv`` refines the blended map.

    Args:
        inter_dim: Channel count expected of all inputs and of the output.
    """

    def __init__(self, inter_dim=512):
        super().__init__()
        self.inter_dim = inter_dim
        compress_c = 8

        self.weight_level_0 = BasicConv(inter_dim, compress_c, 1, 1)
        self.weight_level_1 = BasicConv(inter_dim, compress_c, 1, 1)
        self.weight_level_2 = BasicConv(inter_dim, compress_c, 1, 1)
        self.weight_level_3 = BasicConv(inter_dim, compress_c, 1, 1)
        self.weight_levels = nn.Conv2d(compress_c * 4, 4, kernel_size=1, stride=1, padding=0)
        self.conv = BasicConv(inter_dim, inter_dim, 3, 1)

    def forward(self, input0, input1, input2, input3):
        """Return the refined weighted sum of the four inputs."""
        compressed = torch.cat(
            (
                self.weight_level_0(input0),
                self.weight_level_1(input1),
                self.weight_level_2(input2),
                self.weight_level_3(input3),
            ),
            1,
        )
        weights = F.softmax(self.weight_levels(compressed), dim=1)
        # One single-channel weight map per input level.
        w0, w1, w2, w3 = weights.split(1, dim=1)
        fused = input0 * w0 + input1 * w1 + input2 * w2 + input3 * w3
        return self.conv(fused)
class BlockBody(nn.Module):
    """Fusion body operating on four feature levels.

    ``forward`` fuses levels 0-1 first, then levels 0-2, and finally all four
    levels; after each of the first two fusion rounds the fused maps pass
    through a stack of four ``BasicBlock`` modules, and once more after the
    last round. Cross-level inputs are resampled with ``Upsample`` /
    ``Downsample_x*`` modules using factors 2, 4 and 8, so level ``i + 1`` is
    expected to have half the spatial size of level ``i``.

    Args:
        channels: Indexable collection whose first four entries are the
            channel counts of levels 0..3. The default is a tuple rather
            than a list so the shared default object cannot be mutated.
    """

    def __init__(self, channels=(64, 128, 256, 512)):
        super().__init__()
        c0 = channels[0]
        c1 = channels[1]
        c2 = channels[2]
        c3 = channels[3]

        # NOTE: sub-modules are registered in the same order as before so
        # that state_dict key order and parameter-initialisation order are
        # unchanged.

        # Per-level 1x1 convolutions applied before any fusion.
        self.blocks_scalezero1 = nn.Sequential(BasicConv(c0, c0, 1))
        self.blocks_scaleone1 = nn.Sequential(BasicConv(c1, c1, 1))
        self.blocks_scaletwo1 = nn.Sequential(BasicConv(c2, c2, 1))
        self.blocks_scalethree1 = nn.Sequential(BasicConv(c3, c3, 1))

        # Round 1: fuse levels 0 and 1.
        self.downsample_scalezero1_2 = Downsample_x2(c0, c1)
        self.upsample_scaleone1_2 = Upsample(c1, c0, scale_factor=2)

        self.asff_scalezero1 = ASFF_2(inter_dim=c0)
        self.asff_scaleone1 = ASFF_2(inter_dim=c1)

        self.blocks_scalezero2 = self._make_stage(c0)
        self.blocks_scaleone2 = self._make_stage(c1)

        # Round 2: fuse levels 0, 1 and 2.
        self.downsample_scalezero2_2 = Downsample_x2(c0, c1)
        self.downsample_scalezero2_4 = Downsample_x4(c0, c2)
        self.downsample_scaleone2_2 = Downsample_x2(c1, c2)
        self.upsample_scaleone2_2 = Upsample(c1, c0, scale_factor=2)
        self.upsample_scaletwo2_2 = Upsample(c2, c1, scale_factor=2)
        self.upsample_scaletwo2_4 = Upsample(c2, c0, scale_factor=4)

        self.asff_scalezero2 = ASFF_3(inter_dim=c0)
        self.asff_scaleone2 = ASFF_3(inter_dim=c1)
        self.asff_scaletwo2 = ASFF_3(inter_dim=c2)

        self.blocks_scalezero3 = self._make_stage(c0)
        self.blocks_scaleone3 = self._make_stage(c1)
        self.blocks_scaletwo3 = self._make_stage(c2)

        # Round 3: fuse all four levels.
        self.downsample_scalezero3_2 = Downsample_x2(c0, c1)
        self.downsample_scalezero3_4 = Downsample_x4(c0, c2)
        self.downsample_scalezero3_8 = Downsample_x8(c0, c3)
        self.upsample_scaleone3_2 = Upsample(c1, c0, scale_factor=2)
        self.downsample_scaleone3_2 = Downsample_x2(c1, c2)
        self.downsample_scaleone3_4 = Downsample_x4(c1, c3)
        self.upsample_scaletwo3_4 = Upsample(c2, c0, scale_factor=4)
        self.upsample_scaletwo3_2 = Upsample(c2, c1, scale_factor=2)
        self.downsample_scaletwo3_2 = Downsample_x2(c2, c3)
        self.upsample_scalethree3_8 = Upsample(c3, c0, scale_factor=8)
        self.upsample_scalethree3_4 = Upsample(c3, c1, scale_factor=4)
        self.upsample_scalethree3_2 = Upsample(c3, c2, scale_factor=2)

        self.asff_scalezero3 = ASFF_4(inter_dim=c0)
        self.asff_scaleone3 = ASFF_4(inter_dim=c1)
        self.asff_scaletwo3 = ASFF_4(inter_dim=c2)
        self.asff_scalethree3 = ASFF_4(inter_dim=c3)

        self.blocks_scalezero4 = self._make_stage(c0)
        self.blocks_scaleone4 = self._make_stage(c1)
        self.blocks_scaletwo4 = self._make_stage(c2)
        self.blocks_scalethree4 = self._make_stage(c3)

    @staticmethod
    def _make_stage(width, depth=4):
        """Return ``depth`` chained ``BasicBlock(width, width)`` modules."""
        return nn.Sequential(*(BasicBlock(width, width) for _ in range(depth)))

    def forward(self, x):
        """Fuse four feature maps and return four maps of the same levels.

        Args:
            x: Iterable of exactly four tensors, one per level (0..3).

        Returns:
            Tuple ``(scalezero, scaleone, scaletwo, scalethree)``.
        """
        x0, x1, x2, x3 = x

        x0 = self.blocks_scalezero1(x0)
        x1 = self.blocks_scaleone1(x1)
        x2 = self.blocks_scaletwo1(x2)
        x3 = self.blocks_scalethree1(x3)

        # Round 1: levels 0 and 1 exchange information.
        scalezero = self.asff_scalezero1(x0, self.upsample_scaleone1_2(x1))
        scaleone = self.asff_scaleone1(self.downsample_scalezero1_2(x0), x1)

        x0 = self.blocks_scalezero2(scalezero)
        x1 = self.blocks_scaleone2(scaleone)

        # Round 2: level 2 joins the fusion.
        scalezero = self.asff_scalezero2(
            x0,
            self.upsample_scaleone2_2(x1),
            self.upsample_scaletwo2_4(x2),
        )
        scaleone = self.asff_scaleone2(
            self.downsample_scalezero2_2(x0),
            x1,
            self.upsample_scaletwo2_2(x2),
        )
        scaletwo = self.asff_scaletwo2(
            self.downsample_scalezero2_4(x0),
            self.downsample_scaleone2_2(x1),
            x2,
        )

        x0 = self.blocks_scalezero3(scalezero)
        x1 = self.blocks_scaleone3(scaleone)
        x2 = self.blocks_scaletwo3(scaletwo)

        # Round 3: level 3 joins; every level sees all four.
        scalezero = self.asff_scalezero3(
            x0,
            self.upsample_scaleone3_2(x1),
            self.upsample_scaletwo3_4(x2),
            self.upsample_scalethree3_8(x3),
        )
        scaleone = self.asff_scaleone3(
            self.downsample_scalezero3_2(x0),
            x1,
            self.upsample_scaletwo3_2(x2),
            self.upsample_scalethree3_4(x3),
        )
        scaletwo = self.asff_scaletwo3(
            self.downsample_scalezero3_4(x0),
            self.downsample_scaleone3_2(x1),
            x2,
            self.upsample_scalethree3_2(x3),
        )
        scalethree = self.asff_scalethree3(
            self.downsample_scalezero3_8(x0),
            self.downsample_scaleone3_4(x1),
            self.downsample_scaletwo3_2(x2),
            x3,
        )

        scalezero = self.blocks_scalezero4(scalezero)
        scaleone = self.blocks_scaleone4(scaleone)
        scaletwo = self.blocks_scaletwo4(scaletwo)
        scalethree = self.blocks_scalethree4(scalethree)

        return scalezero, scaleone, scaletwo, scalethree
class AFPN(nn.Module):
    """Feature-pyramid neck wrapping ``BlockBody`` with 1x1 channel adapters.

    Each of the four inputs is first reduced to ``in_channels[i] // 8``
    channels, the reduced maps are fused by ``BlockBody``, and each fused
    map is then projected to ``out_channels`` channels.

    Args:
        in_channels: Indexable collection whose first four entries are the
            channel counts of the four input feature maps. The default is a
            tuple rather than a list so the shared default object cannot be
            mutated.
        out_channels: Channel count of every returned feature map.
    """

    def __init__(self,
                 in_channels=(256, 512, 1024, 2048),
                 out_channels=256):
        super().__init__()

        self.fp16_enabled = False

        # Channel counts used inside the body (integer division by 8).
        reduced = [in_channels[i] // 8 for i in range(4)]

        # Input adapters: in_channels[i] -> reduced[i].
        self.conv0 = BasicConv(in_channels[0], reduced[0], 1)
        self.conv1 = BasicConv(in_channels[1], reduced[1], 1)
        self.conv2 = BasicConv(in_channels[2], reduced[2], 1)
        self.conv3 = BasicConv(in_channels[3], reduced[3], 1)

        # Wrapped in a Sequential so parameter names stay "body.0.*".
        self.body = nn.Sequential(BlockBody(reduced))

        # Output adapters: reduced[i] -> out_channels.
        self.conv00 = BasicConv(reduced[0], out_channels, 1)
        self.conv11 = BasicConv(reduced[1], out_channels, 1)
        self.conv22 = BasicConv(reduced[2], out_channels, 1)
        self.conv33 = BasicConv(reduced[3], out_channels, 1)

        # Registered but not called by forward(); kept so the attribute
        # remains available to external code.
        self.conv44 = nn.MaxPool2d(kernel_size=1, stride=2)

        # Weight initialisation. Only Conv2d weights and BatchNorm2d
        # weight/bias are touched; Conv2d biases keep their constructor
        # initialisation.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight, gain=0.02)
            elif isinstance(m, nn.BatchNorm2d):
                torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
                torch.nn.init.constant_(m.bias.data, 0.0)

    def forward(self, x):
        """Fuse four feature maps.

        Args:
            x: Iterable of exactly four tensors whose channel counts match
                ``in_channels`` in order.

        Returns:
            Tuple ``(out0, out1, out2, out3)``; each element has
            ``out_channels`` channels.
        """
        x0, x1, x2, x3 = x

        x0 = self.conv0(x0)
        x1 = self.conv1(x1)
        x2 = self.conv2(x2)
        x3 = self.conv3(x3)

        out0, out1, out2, out3 = self.body([x0, x1, x2, x3])

        out0 = self.conv00(out0)
        out1 = self.conv11(out1)
        out2 = self.conv22(out2)
        out3 = self.conv33(out3)

        return out0, out1, out2, out3
if __name__ == "__main__":
    # Smoke test: build the network, push four random feature maps through
    # it in eval mode, and print every input size followed by every output
    # size.
    afpn = AFPN(in_channels=[256, 512, 1024, 2048], out_channels=256)
    afpn.eval()

    x0 = torch.randn(1, 256, 64, 64)
    x1 = torch.randn(1, 512, 32, 32)
    x2 = torch.randn(1, 1024, 16, 16)
    x3 = torch.randn(1, 2048, 8, 8)

    out0, out1, out2, out3 = afpn([x0, x1, x2, x3])

    for tensor in (x0, x1, x2, x3, out0, out1, out2, out3):
        print(tensor.size())