-
Notifications
You must be signed in to change notification settings - Fork 236
/
MDCR.py
135 lines (119 loc) · 4.33 KB
/
MDCR.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import math

import torch
import torch.nn as nn
# Paper: https://arxiv.org/pdf/2403.10778v1.pdf
# Title: HCF-Net: Hierarchical Context Fusion Network for Infrared Small Object Detection
class conv_block(nn.Module):
    """Conv2d followed by optional normalisation and optional ReLU.

    Args:
        in_features: Input channels of the convolution.
        out_features: Output channels of the convolution.
        kernel_size, stride, padding, dilation, groups: Forwarded unchanged
            to ``nn.Conv2d``.
        norm_type: ``'bn'`` for ``nn.BatchNorm2d``, ``'gn'`` for
            ``nn.GroupNorm``, or ``None`` for no normalisation.
        activation: If true, apply ``nn.ReLU`` after the (optional)
            normalisation.
        use_bias: Forwarded to ``nn.Conv2d`` as ``bias``.

    Raises:
        ValueError: If ``norm_type`` is not ``'bn'``, ``'gn'`` or ``None``.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 kernel_size=(3, 3),
                 stride=(1, 1),
                 padding=(1, 1),
                 dilation=(1, 1),
                 norm_type='bn',
                 activation=True,
                 use_bias=True,
                 groups=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_features,
                              out_channels=out_features,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              bias=use_bias,
                              groups=groups)
        self.norm_type = norm_type
        self.act = activation
        if norm_type == 'gn':
            # Use up to 32 groups. GroupNorm requires out_features to be
            # divisible by the group count. gcd(32, n) is 32 whenever 32
            # divides n, and otherwise the largest power-of-two divisor of n
            # below 32. Below 32 channels, one group per channel is used.
            num_groups = out_features if out_features < 32 else math.gcd(32, out_features)
            self.norm = nn.GroupNorm(num_groups, out_features)
        elif norm_type == 'bn':
            self.norm = nn.BatchNorm2d(out_features)
        elif norm_type is not None:
            # Without this check self.norm would never be created, and
            # forward() (which normalises whenever norm_type is not None)
            # would fail with AttributeError on the first call.
            raise ValueError(
                f"norm_type must be 'bn', 'gn' or None, got {norm_type!r}"
            )
        if activation:
            self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Apply conv -> (norm) -> (ReLU) to ``x`` and return the result."""
        x = self.conv(x)
        if self.norm_type is not None:
            x = self.norm(x)
        if self.act:
            x = self.relu(x)
        return x
class MDCR(nn.Module):
    """MDCR block from HCF-Net (arXiv:2403.10778).

    The input is split into four equal channel groups, each processed by its
    own 3x3 ``conv_block``. Each branch uses ``groups`` equal to its input
    channel count, so every input channel is convolved separately. The
    dilation, and the padding, of branch ``i`` is ``rate[i]``. With a 3x3
    kernel and stride 1 this keeps H and W unchanged.

    For every channel index ``c``, channel ``c`` of the four branch outputs
    is gathered into a 4-channel tensor and mixed by a shared 1x1
    ``conv_block`` (``out_s``). The mixed groups are concatenated along the
    channel axis and passed through a final 1x1 ``conv_block`` (``out``).

    Args:
        in_features: Input channels; must be a positive multiple of 4.
        out_features: Output channels; must be a positive multiple of 4.
        norm_type: Normalisation forwarded to every ``conv_block``.
        activation: Whether every ``conv_block`` applies ReLU.
        rate: Exactly four dilation rates, one per branch.

    Raises:
        ValueError: If the channel counts are not positive multiples of 4,
            or if ``rate`` does not have exactly four entries.
    """

    def __init__(self, in_features, out_features, norm_type='bn', activation=True, rate=(1, 6, 12, 18)):
        super().__init__()
        # Invalid values here previously failed later and obscurely: an
        # IndexError from torch.chunk, a channel mismatch in self.out, or
        # groups=0. Reject them up front instead.
        if in_features <= 0 or out_features <= 0 or in_features % 4 or out_features % 4:
            raise ValueError(
                f"in_features ({in_features}) and out_features ({out_features}) "
                "must be positive multiples of 4"
            )
        if len(rate) != 4:
            raise ValueError(f"rate must contain exactly 4 dilation rates, got {len(rate)}")

        def branch(dilation):
            # One dilated 3x3 branch working on a quarter of the channels.
            # Padding equal to the dilation preserves the spatial size.
            return conv_block(
                in_features=in_features // 4,
                out_features=out_features // 4,
                padding=dilation,
                dilation=dilation,
                norm_type=norm_type,
                activation=activation,
                groups=in_features // 4,
            )

        # Creation order (block1..block4, out_s, out) and attribute names
        # match the original, so state_dict keys and init order are unchanged.
        self.block1 = branch(rate[0])
        self.block2 = branch(rate[1])
        self.block3 = branch(rate[2])
        self.block4 = branch(rate[3])
        # Shared 1x1 conv mixing the four branch responses of one channel.
        self.out_s = conv_block(
            in_features=4,
            out_features=4,
            kernel_size=(1, 1),
            padding=(0, 0),
            norm_type=norm_type,
            activation=activation,
        )
        self.out = conv_block(
            in_features=out_features,
            out_features=out_features,
            kernel_size=(1, 1),
            padding=(0, 0),
            norm_type=norm_type,
            activation=activation,
        )

    def forward(self, x):
        """Run the four dilated branches, fuse them per channel and project.

        Expects a batched (B, in_features, H, W) tensor and returns a
        (B, out_features, H, W) tensor.
        """
        parts = torch.chunk(x, 4, dim=1)
        branches = (
            self.block1(parts[0]),
            self.block2(parts[1]),
            self.block3(parts[2]),
            self.block4(parts[3]),
        )
        fused = []
        # out_s is applied once per channel rather than in one batched call.
        # This keeps its normalisation statistics identical to the reference
        # implementation.
        for c in range(branches[0].size(1)):
            # Concatenate channel c of every branch along dim 1, the channel
            # axis, giving a (B, 4, H, W) tensor.
            group = torch.cat([b[:, c:c + 1, :, :] for b in branches], dim=1)
            fused.append(self.out_s(group))
        return self.out(torch.cat(fused, dim=1))
if __name__ == '__main__':
    # Smoke test: push a random tensor through MDCR and print the input and
    # output shapes.
    sample = torch.randn(1, 64, 32, 32)  # B C H W
    module = MDCR(in_features=64, out_features=64)
    result = module(sample)
    for shape in (sample.size(), result.size()):
        print(shape)