-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathvgg_m_1024_torch.py
57 lines (50 loc) · 1.91 KB
/
vgg_m_1024_torch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
import torch.nn as nn
import torch.legacy.nn
from torch.autograd import Variable
from functools import reduce
from LRN import SpatialCrossMapLRN
class LambdaBase(nn.Sequential):
    """Sequential container that carries an arbitrary callable.

    Child modules given as ``*args`` are registered exactly as in
    ``nn.Sequential`` but are NOT chained: ``forward_prepare`` feeds the same
    input to each child independently.  Subclasses then apply the stored
    callable (``lambda_func``) to whatever ``forward_prepare`` produced.
    """

    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        # Plain attribute, not a submodule: the callable holds no parameters
        # visible to ``state_dict``/``parameters()``.
        self.lambda_func = fn

    def forward_prepare(self, input):
        """Run every registered child on ``input``.

        Returns a list holding one output per child (in registration order),
        or ``input`` itself, untouched, when the container has no children.
        """
        registered = list(self._modules.values())
        if not registered:
            return input
        return [child(input) for child in registered]
class Lambda(LambdaBase):
    """Container that applies its stored callable once to the prepared input.

    Without child modules this is just ``lambda_func(input)``; with children
    the callable receives the list of their outputs.
    """

    def forward(self, input):
        prepared = self.forward_prepare(input)
        return self.lambda_func(prepared)
class LambdaMap(LambdaBase):
    """Container that applies its stored callable item by item.

    ``lambda_func`` is called on each element of the prepared input and the
    results are returned as a list.  With no child modules the prepared input
    is ``input`` itself, so it is iterated directly (for a tensor that means
    iterating along its first dimension).
    """

    def forward(self, input):
        prepared = self.forward_prepare(input)
        return [self.lambda_func(item) for item in prepared]
class LambdaReduce(LambdaBase):
    """Container that folds the prepared input with its stored callable.

    Behaves like ``functools.reduce(lambda_func, prepared)``: elements are
    combined left to right, and an empty iterable raises ``TypeError``.
    """

    def forward(self, input):
        prepared = self.forward_prepare(input)
        combine = self.lambda_func
        return reduce(combine, prepared)
# VGG-M-1024 network, as produced by a Torch7 -> PyTorch conversion.
#
# Input-size note: fc6 below expects 18432 = 512 * 6 * 6 features.  Working
# through the conv/pool arithmetic, a 3x224x224 input gives spatial sizes
# 109 -> 54 -> 26 -> 13 -> 13 -> 6, i.e. a 512x6x6 map; other input sizes will
# fail with a shape mismatch at fc6.
#
# Keep the layer order and the Sequential wrappers around the Linear layers:
# parameter names in any saved state_dict (e.g. "16.1.weight" for fc6) depend
# on these indices.
#
# The LRN layers bind a SpatialCrossMapLRN instance as a lambda default
# argument, so each instance is created once and reused on every forward call;
# it is not registered as a submodule.  Arguments are presumably
# (size, alpha, beta, k) as in Torch7's SpatialCrossMapLRN -- TODO confirm
# against LRN.py.
vgg_m_1024_torch = nn.Sequential(
    # conv1 block
    nn.Conv2d(3, 96, (7, 7), (2, 2)),
    nn.ReLU(),
    Lambda(lambda x, lrn=SpatialCrossMapLRN(5, 0.0005, 0.75, 2): lrn.forward(x)),
    nn.MaxPool2d((3, 3), (2, 2), (0, 0), ceil_mode=True),
    # conv2 block
    nn.Conv2d(96, 256, (5, 5), (2, 2), (1, 1)),
    nn.ReLU(),
    Lambda(lambda x, lrn=SpatialCrossMapLRN(5, 0.0005, 0.75, 2): lrn.forward(x)),
    nn.MaxPool2d((3, 3), (2, 2), (0, 0), ceil_mode=True),
    # conv3 - conv5 (spatial size preserved by 3x3 / stride 1 / pad 1)
    nn.Conv2d(256, 512, (3, 3), (1, 1), (1, 1)),
    nn.ReLU(),
    nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1)),
    nn.ReLU(),
    nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1)),
    nn.ReLU(),
    nn.MaxPool2d((3, 3), (2, 2), (0, 0), ceil_mode=True),
    # Flatten to (batch, features).
    Lambda(lambda x: x.view(x.size(0), -1)),
    # fc6 / fc7 / fc8.  The leading Lambda in each wrapper promotes a 1-D
    # input to (1, features); after the flatten above the input is already
    # 2-D, so it passes through unchanged.
    nn.Sequential(
        Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x),
        nn.Linear(18432, 4096),
    ),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Sequential(
        Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x),
        nn.Linear(4096, 1024),
    ),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Sequential(
        Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x),
        nn.Linear(1024, 1000),
    ),
    # Explicit class axis: the output here is (batch, 1000), so dim=1 matches
    # what the deprecated implicit-dim default chose for 2-D input, without
    # the warning.  Note the model returns probabilities, not logits; do not
    # feed it to a loss that applies its own softmax (e.g. CrossEntropyLoss).
    nn.Softmax(dim=1),
)