-- ConvGRU.lua
require 'cudnn'
require 'nn'
require 'rnn'       -- defines nn.GRU, the parent class extended below
require 'extracunn' -- provides nn.SpatialConvolutionNoBias and nn.SAdd
require 'image'

local ConvGRU, parent = torch.class('nn.ConvGRU', 'nn.GRU')
-- inputSize and outputSize are channel counts; kc and km are the kernel sizes
-- of the input-to-hidden and hidden-to-hidden convolutions; rho is the maximum
-- number of steps to back-propagate through time.
function ConvGRU:__init(inputSize, outputSize, rho, kc, km, stride)
   self.kc = kc
   self.km = km
   self.padc = torch.floor(kc/2) -- 'same' padding (odd kernels, stride 1)
   self.padm = torch.floor(km/2)
   self.stride = stride or 1
   parent.__init(self, inputSize, outputSize, rho or 10)
end
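
-- A minimal usage sketch (sizes are illustrative assumptions, not values
-- prescribed by this file), kept as a comment so that requiring this file
-- stays side-effect free:
--
--   local gru = nn.ConvGRU(3, 16, 5, 3, 3):cuda() -- 3 -> 16 channels, rho = 5, 3x3 kernels
--   local x = torch.CudaTensor(3, 64, 64):normal()
--   local s = gru:forward(x) -- 16 x 64 x 64 with stride 1 and 'same' padding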
function ConvGRU:buildGate()
   -- Note: input is {input(t), output(t-1)}
   local gate = nn.Sequential()
   local input2gate = cudnn.SpatialConvolution(self.inputSize, self.outputSize, self.kc, self.kc, self.stride, self.stride, self.padc, self.padc)
   local output2gate = nn.SpatialConvolutionNoBias(self.outputSize, self.outputSize, self.km, self.km, self.stride, self.stride, self.padm, self.padm)
   local para = nn.ParallelTable()
   para:add(input2gate):add(output2gate)
   gate:add(para)
   gate:add(nn.CAddTable())
   gate:add(cudnn.Sigmoid())
   return gate
end
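
-- In GRU notation, each gate built here evaluates, per spatial location,
--   gate(t) = sigmoid( Wx * x(t) + Ws * s(t-1) )
-- where * denotes convolution; only the Wx (input-to-gate) path carries a bias.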
-------------------------- factory methods -----------------------------
function ConvGRU:buildModel()
   -- input : {input, prevOutput}
   -- output : s(t)
   -- Note: both gates are structurally identical, so the names are
   -- interchangeable mathematically; with the wiring below, self.inputGate
   -- ends up used as the reset gate r and self.resetGate as the update gate z.
   self.inputGate = self:buildGate()
   self.resetGate = self:buildGate()
   local concat = nn.ConcatTable():add(nn.Identity()):add(self.inputGate):add(self.resetGate)
   local seq = nn.Sequential()
   seq:add(concat)
   seq:add(nn.FlattenTable()) -- x(t), s(t-1), r, z
   -- rearrange to x(t), s(t-1), r, z, s(t-1)
   local concat = nn.ConcatTable()
   concat:add(nn.NarrowTable(1,4)):add(nn.SelectTable(2))
   seq:add(concat):add(nn.FlattenTable())
   -- candidate hidden state h = tanh(convc(x(t)) + convm(r . s(t-1)))
   local t1 = nn.Sequential()
   t1:add(nn.SelectTable(1))
   local t2 = nn.Sequential()
   t2:add(nn.NarrowTable(2,2)):add(nn.CMulTable())
   t1:add(cudnn.SpatialConvolution(self.inputSize, self.outputSize, self.kc, self.kc, self.stride, self.stride, self.padc, self.padc))
   t2:add(nn.SpatialConvolutionNoBias(self.outputSize, self.outputSize, self.km, self.km, self.stride, self.stride, self.padm, self.padm))
   local concat = nn.ConcatTable()
   concat:add(t1):add(t2)
   local hidden = nn.Sequential()
   hidden:add(concat):add(nn.CAddTable()):add(nn.Tanh())
   -- 1 - z
   local z1 = nn.Sequential()
   z1:add(nn.SelectTable(4))
   z1:add(nn.SAdd(-1, true)) -- scalar add then negation: -(z - 1) = 1 - z
   -- z * s(t-1)
   local z2 = nn.Sequential()
   z2:add(nn.NarrowTable(4,2))
   z2:add(nn.CMulTable())
   -- (1 - z) * h
   local o1 = nn.Sequential()
   local concat = nn.ConcatTable()
   concat:add(hidden):add(z1)
   o1:add(concat):add(nn.CMulTable())
   -- s(t) = (1 - z) * h + z * s(t-1)
   local o2 = nn.Sequential()
   local concat = nn.ConcatTable()
   concat:add(o1):add(z2)
   o2:add(concat):add(nn.CAddTable())
   seq:add(o2)
   return seq
end
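
-- Putting the pieces together, the assembled graph computes the standard
-- convolutional GRU update (with * convolution and . element-wise product):
--   r    = sigmoid(Wxr * x(t) + Whr * s(t-1))
--   z    = sigmoid(Wxz * x(t) + Whz * s(t-1))
--   h    = tanh(Wxh * x(t) + Whh * (r . s(t-1)))
--   s(t) = (1 - z) . h + z . s(t-1)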
------------------------- forward backward -----------------------------
function ConvGRU:updateOutput(input)
   -- input is expected to be a 3D tensor of size inputSize x height x width
   local prevOutput
   if self.step == 1 then
      prevOutput = self.userPrevOutput or self.zeroTensor
      self.zeroTensor:resize(self.outputSize, input:size(2), input:size(3)):zero()
   else
      -- previous output and memory of this module
      prevOutput = self.output
   end

   -- output(t) = gru{input(t), output(t-1)}
   local output
   if self.train ~= false then
      self:recycle()
      local recurrentModule = self:getStepModule(self.step)
      -- the actual forward propagation
      output = recurrentModule:updateOutput{input, prevOutput}
   else
      output = self.recurrentModule:updateOutput{input, prevOutput}
   end

   self.outputs[self.step] = output
   self.output = output
   self.step = self.step + 1
   self.gradPrevOutput = nil
   self.updateGradInputStep = nil
   self.accGradParametersStep = nil
   self.gradParametersAccumulated = false
   -- unlike an LSTM there is no separate cell state; the output is the state
   return self.output
end
function ConvGRU:initBias(forgetBias, otherBias)
   -- Despite the LSTM-style name, forgetBias initialises the bias of
   -- self.resetGate, which buildModel wires up as the update gate z; biasing
   -- it toward 1 makes the unit favour keeping its previous state early in
   -- training. modules[1].modules[1] is the biased input-to-gate convolution.
   local oBias = otherBias or 0
   local rBias = forgetBias or 1
   self.inputGate.modules[1].modules[1].bias:fill(oBias)
   self.resetGate.modules[1].modules[1].bias:fill(rBias)
end
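
-- Usage note: call initBias after construction; the values here are
-- illustrative and match the defaults above:
--
--   local gru = nn.ConvGRU(3, 16, 5, 3, 3):cuda()
--   gru:initBias(1, 0) -- update-gate bias 1, other gate bias 0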