-
Notifications
You must be signed in to change notification settings - Fork 77
/
model.py
91 lines (80 loc) · 3.02 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
class Resnet50FPN(nn.Module):
    """ResNet-50 backbone that returns two intermediate feature maps.

    Despite the name, no FPN top-down/lateral fusion happens here: ``forward``
    simply chains the first few backbone stages and returns the outputs of the
    last two of them.
    """

    def __init__(self):
        super(Resnet50FPN, self).__init__()
        # Requests torchvision's pretrained ResNet-50 weights.
        # NOTE(review): `pretrained=` is deprecated in newer torchvision in
        # favour of `weights=`; kept as-is for compatibility with older releases.
        self.resnet = torchvision.models.resnet50(pretrained=True)
        # NOTE(review): the whole network stays registered under `self.resnet`,
        # so every child after index 6 (never used by forward) still shows up in
        # parameters()/state_dict(); conv1..conv4 below share module objects
        # with it rather than copying them.
        backbone_parts = list(self.resnet.children())
        # The first four children form the stem (presumably conv/bn/relu/maxpool
        # in torchvision's ResNet; confirm against the installed version).
        self.conv1 = nn.Sequential(*backbone_parts[:4])
        # The next three children are the first three residual stages
        # (presumably layer1..layer3). Unpacking assigns left to right, so the
        # registration order conv2 -> conv3 -> conv4 is preserved.
        self.conv2, self.conv3, self.conv4 = backbone_parts[4:7]

    def forward(self, im_data):
        """Run the stem and three stages and return the last two stage outputs.

        Returns an OrderedDict with key 'map3' (output of ``conv3``) followed by
        key 'map4' (output of ``conv4``).
        """
        stage2_out = self.conv2(self.conv1(im_data))
        map3 = self.conv3(stage2_out)
        map4 = self.conv4(map3)
        return OrderedDict([('map3', map3), ('map4', map4)])
class CountRegressor(nn.Module):
    """Convolutional head that maps stacked feature maps to a density map.

    Each sample's stack is passed through ``self.regressor`` as a Conv2d batch.
    The resulting maps are then reduced over that stack axis (dim 0) with mean
    or max pooling.
    """

    _POOL_MODES = ('mean', 'max')

    def __init__(self, input_channels, pool='mean'):
        """
        Args:
            input_channels: channel count of the feature maps fed to the first
                convolution.
            pool: reduction over dim 0 of the regressor output; 'mean' or 'max'.
                Validated when ``forward`` is called.
        """
        super(CountRegressor, self).__init__()
        self.pool = pool
        # All convolutions use padding = kernel_size // 2, so they preserve
        # spatial size. The three x2 bilinear upsamplings therefore make the
        # output 8x the input resolution. The final ReLU keeps the output
        # non-negative.
        self.regressor = nn.Sequential(
            nn.Conv2d(input_channels, 196, 7, padding=3),
            nn.ReLU(),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(196, 128, 5, padding=2),
            nn.ReLU(),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(64, 32, 1),
            nn.ReLU(),
            nn.Conv2d(32, 1, 1),
            nn.ReLU(),
        )

    def _pool(self, output):
        """Reduce ``output`` over dim 0 (keeping the dim) using ``self.pool``."""
        if self.pool == 'mean':
            return torch.mean(output, dim=0, keepdim=True)
        # self.pool == 'max' (validated by forward); drop the argmax indices.
        pooled, _ = torch.max(output, 0, keepdim=True)
        return pooled

    def forward(self, im):
        """Regress and pool a density map for every sample along dim 0 of ``im``.

        Args:
            im: tensor whose dim-0 slices ``im[i]`` are each fed to
                ``self.regressor``. Presumably of shape
                (num_samples, num_exemplars, input_channels, H, W); TODO confirm
                with the caller.

        Returns:
            The per-sample pooled maps concatenated along dim 0. Under the
            shape assumption above, this is (num_samples, 1, 8*H, 8*W).

        Raises:
            ValueError: if ``self.pool`` is not 'mean' or 'max', or if ``im``
                has an empty leading dimension.
        """
        # Previously an unknown pool returned None for a single sample and
        # un-pooled maps otherwise; fail loudly and consistently instead.
        if self.pool not in self._POOL_MODES:
            raise ValueError(
                "CountRegressor: pool must be 'mean' or 'max', got %r" % (self.pool,))
        # Previously this surfaced as an UnboundLocalError.
        if im.shape[0] == 0:
            raise ValueError('CountRegressor.forward received an empty batch')
        # Samples are processed one at a time because each carries its own
        # stack that is pooled over dim 0. Concatenate once at the end rather
        # than growing the result with a torch.cat per iteration.
        outputs = [self._pool(self.regressor(sample)) for sample in im]
        return torch.cat(outputs, dim=0)
def weights_normal_init(model, dev=0.01):
    """Re-initialise Conv2d and Linear weights from a normal distribution.

    Args:
        model: a module, or a (possibly nested) list of modules. Lists are
            walked recursively.
        dev: standard deviation of the zero-mean normal used for weights.

    Conv2d biases (when present) are set to zero. Linear biases are left
    untouched. Tensors are modified in place through ``.data``.
    """
    if isinstance(model, list):
        for sub_model in model:
            weights_normal_init(sub_model, dev)
        return

    for layer in model.modules():
        is_conv = isinstance(layer, nn.Conv2d)
        if not (is_conv or isinstance(layer, nn.Linear)):
            continue
        layer.weight.data.normal_(0.0, dev)
        if is_conv and layer.bias is not None:
            layer.bias.data.fill_(0.0)
def weights_xavier_init(m):
    """Xavier-normal initialise a Conv2d layer; a no-op for any other module.

    Weights use the gain recommended for ReLU. The bias, if present, is zeroed.
    Intended for use with ``module.apply(weights_xavier_init)``; this is an
    assumption based on the single-module signature.
    """
    if not isinstance(m, nn.Conv2d):
        return
    relu_gain = nn.init.calculate_gain('relu')
    torch.nn.init.xavier_normal_(m.weight, gain=relu_gain)
    if m.bias is not None:
        torch.nn.init.zeros_(m.bias)