forked from Teaonly/trans-torch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcaffe.lua
70 lines (59 loc) · 2.25 KB
/
caffe.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
local ffi = require('ffi')
local C = transTorch._C
--- Copy the parameters of an nn.Linear module into a Caffe layer.
-- The weight and bias are converted with :float(), as toConv and toBatchNorm
-- already do. The converted tensors are held in locals until the C call
-- returns: :cdata() yields a raw pointer that does not keep the owning tensor
-- alive, so a temporary produced by a type conversion could otherwise be
-- garbage-collected before C reads it.
-- @param tm nn.Linear module (reads tm.weight and tm.bias)
-- @param caffeNet presumably the `void*[1]` box returned by
--   transTorch.loadCaffe; its slot [0] is handed to the C writer
-- @param layerName name of the target Caffe layer
local toLinear = function(tm, caffeNet, layerName)
  local weight = tm.weight:float()
  local bias = tm.bias:float()
  C.writeCaffeLinearLayer(caffeNet[0], layerName, weight:cdata(), bias:cdata())
end
--- Copy the parameters of a convolution module into a Caffe layer.
-- Weight and bias are converted with :float() and kept referenced in locals
-- for the duration of the C call. Without this, the converted temporary
-- behind a :cdata() pointer could be collected (e.g. when the next conversion
-- allocates) before C.writeCaffeConvLayer reads it.
-- @param tm convolution module (reads tm.weight and tm.bias)
-- @param caffeNet presumably the `void*[1]` box returned by
--   transTorch.loadCaffe; its slot [0] is handed to the C writer
-- @param layerName name of the target Caffe layer
local toConv = function(tm, caffeNet, layerName)
  local weight = tm.weight:float()
  local bias = tm.bias:float()
  C.writeCaffeConvLayer(caffeNet[0], layerName, weight:cdata(), bias:cdata())
end
--- Copy a (Spatial)BatchNormalization module into Caffe layer(s).
-- When tm.affine is true, two layers are written:
--   * layerName[1] receives the running mean/var (via C.writeCaffeBNLayer)
--   * layerName[2] receives the learned weight/bias (via C.writeCaffeScaleLayer)
-- Otherwise only the running statistics are written, to the single layer
-- named by the string layerName.
-- All tensors are converted with :float() and held in locals until the C
-- calls return, so no temporary behind a :cdata() pointer can be collected
-- early.
-- @param tm batch-normalization module (reads running_mean, running_var,
--   and, when affine, weight and bias)
-- @param caffeNet presumably the `void*[1]` box returned by
--   transTorch.loadCaffe; its slot [0] is handed to the C writers
-- @param layerName {bnName, scaleName} when affine, otherwise a string
local toBatchNorm = function(tm, caffeNet, layerName)
  local affine = (tm.affine == true)
  if affine then
    assert(type(layerName) == 'table')
    assert(#layerName == 2)
  else
    assert(type(layerName) == 'string')
  end

  local mean = tm.running_mean:float()
  local var = tm.running_var:float()

  if affine then
    local weight = tm.weight:float()
    local bias = tm.bias:float()
    C.writeCaffeBNLayer(caffeNet[0], layerName[1], mean:cdata(), var:cdata())
    C.writeCaffeScaleLayer(caffeNet[0], layerName[2], weight:cdata(), bias:cdata())
  else
    -- layerName is a plain string here; the previous `layerName[0]` indexed
    -- the string and evaluated to nil, so the layer name was never passed.
    C.writeCaffeBNLayer(caffeNet[0], layerName, mean:cdata(), var:cdata())
  end
end
--- Load a Caffe network through the C bridge.
-- @param prototxt_name path to the network definition (string, required)
-- @param binary_name optional path to a weights file (string or nil); nil is
--   forwarded to C.loadCaffeNet unchanged
-- @return an FFI `void*[1]` box whose slot [0] holds the handle returned by
--   C.loadCaffeNet; pass this box to the other transTorch Caffe functions
transTorch.loadCaffe = function(prototxt_name, binary_name)
  assert(type(prototxt_name) == 'string')
  assert(binary_name == nil or type(binary_name) == 'string')

  local handle_box = ffi.new("void*[1]")
  handle_box[0] = C.loadCaffeNet(prototxt_name, binary_name)
  return handle_box
end
--- Release a network previously obtained from transTorch.loadCaffe.
-- @param net the `void*[1]` box; its slot [0] is passed to C.releaseCaffeNet
transTorch.releaseCaffe = function(net)
  local handle = net[0]
  C.releaseCaffeNet(handle)
end
--- Save a network to disk through the C bridge.
-- The file name is validated the same way loadCaffe validates its paths, so
-- a missing or non-string name fails here with a Lua error instead of being
-- handed to C.saveCaffeNet.
-- @param net the `void*[1]` box from transTorch.loadCaffe; slot [0] is used
-- @param fileName destination path (string)
transTorch.writeCaffe = function(net, fileName)
  assert(type(fileName) == 'string')
  C.saveCaffeNet(net[0], fileName)
end
--- Export one Torch module's parameters into a layer of a loaded Caffe net.
-- Dispatches on torch.type(tmodel):
--   * 'nn.Linear'                                      -> toLinear
--   * 'nn.BatchNormalization' / 'nn.SpatialBatchNormalization' -> toBatchNorm
--   * any type name containing 'Convolution'           -> toConv
-- Any other type raises an error attributed to the caller.
-- @param tmodel Torch module to export
-- @param caffeNet the `void*[1]` box from transTorch.loadCaffe
-- @param layerName target layer name (a {bn, scale} pair for affine BN)
transTorch.toCaffe = function(tmodel, caffeNet, layerName)
  local mtype = torch.type(tmodel)
  if mtype == 'nn.Linear' then
    toLinear(tmodel, caffeNet, layerName)
  elseif mtype == 'nn.BatchNormalization' or mtype == 'nn.SpatialBatchNormalization' then
    toBatchNorm(tmodel, caffeNet, layerName)
  elseif string.find(mtype, 'Convolution', 1, true) then
    -- NOTE(review): this substring test also accepts modules such as
    -- nn.SpatialFullConvolution (deconvolution), whose weight layout may not
    -- match what C.writeCaffeConvLayer expects -- confirm before relying on it.
    toConv(tmodel, caffeNet, layerName)
  else
    error(" ##ERROR## unsupported layer:" .. mtype, 2)
  end
end