-- simple-bisequencer-network-variable.lua
-- Example: a bidirectional LSTM (BLSTM) for variable-length (right-padded) sequences
require 'rnn'
torch.manualSeed(0)
math.randomseed(0)
-- hyper-parameters
batchSize = 8
seqlen = 10 -- sequence length
hiddenSize = 5
nIndex = 10
lr = 0.1
maxIter = 100
local sharedLookupTable = nn.LookupTableMaskZero(nIndex, hiddenSize)
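-- LookupTableMaskZero behaves like nn.LookupTable except that index 0 maps to an
-- all-zero embedding; this is what lets the right-padded (zero) time-steps below
-- flow through the network as zeros. A minimal standalone sketch, for illustration
-- only (not part of the training script):
--   local lt = nn.LookupTableMaskZero(10, 5)
--   print(lt:forward(torch.LongTensor{0, 3})) -- first row is all zeros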
-- forward rnn
local fwd = nn.Sequential()
   :add(sharedLookupTable)
   :add(nn.RecLSTM(hiddenSize, hiddenSize):maskZero(true))
-- Internally, the rnn will be wrapped in a Recursor to make it an AbstractRecurrent instance.
fwdSeq = nn.Sequencer(fwd)
-- backward rnn (will be applied in reverse order of the input sequence)
local bwd = nn.Sequential()
   :add(sharedLookupTable:sharedClone())
   :add(nn.RecLSTM(hiddenSize, hiddenSize):maskZero(true))
bwdSeq = nn.Sequencer(bwd)
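-- Note: sharedClone() ties the backward branch's lookup table to sharedLookupTable
-- (shared weights and gradients), so both directions use the same learned embedding.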
-- merges the output of one time-step of fwd and bwd rnns.
-- You could also try nn.CAddTable(), nn.Identity(), etc.
local merge = nn.JoinTable(1, 1)
mergeSeq = nn.Sequencer(merge)
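-- At each time-step the fwd and bwd LSTMs each emit a batchSize x hiddenSize tensor;
-- JoinTable(1, 1) concatenates them along the feature dimension into a
-- batchSize x (2*hiddenSize) tensor. A rough standalone sketch (illustration only):
--   local merged = nn.JoinTable(1, 1):forward{torch.zeros(8, 5), torch.ones(8, 5)}
--   print(merged:size()) -- 8 x 10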
-- Assume that two input sequences are given (original and reverse, both are right-padded).
-- Instead of ConcatTable, we use ParallelTable here.
local parallel = nn.ParallelTable()
parallel:add(fwdSeq):add(bwdSeq)
local brnn = nn.Sequential()
   :add(parallel)
   :add(nn.ZipTable())
   :add(mergeSeq)
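-- Data flow: parallel runs fwdSeq on the original sequence and bwdSeq on the reversed
-- one, producing two tables of seqlen outputs; ZipTable regroups them into a single
-- table of seqlen {fwd, bwd} pairs, which mergeSeq then joins step by step.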
local rnn = nn.Sequential()
   :add(brnn)
   :add(nn.Sequencer(nn.MaskZero(nn.Linear(hiddenSize*2, nIndex), true))) -- times two due to JoinTable
   :add(nn.Sequencer(nn.MaskZero(nn.LogSoftMax(), true)))
print(rnn)
-- build criterion
criterion = nn.SequencerCriterion(nn.MaskZeroCriterion(nn.ClassNLLCriterion(), true))
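-- SequencerCriterion applies the wrapped criterion to every time-step, while
-- MaskZeroCriterion zeroes the loss and gradients for samples whose target is 0,
-- so the right-padded steps do not contribute to training.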
-- build dummy dataset (task is to predict next item, given previous)
sequence_ = torch.LongTensor():range(1,10) -- 1,2,3,4,5,6,7,8,9,10
sequence = torch.LongTensor(100,10):copy(sequence_:view(1,10):expand(100,10))
sequence:resize(100*10) -- one long sequence of 1,2,3...,10,1,2,3...10...
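-- The dataset is thus the 1000-element sequence 1,2,...,10,1,2,...,10,..., so the
-- correct prediction for an input token k is the next token (10 wraps around to 1).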
offsets = {}
maxStep = {}
for i=1,batchSize do
   table.insert(offsets, math.ceil(math.random()*sequence:size(1)))
   -- variable length for each sample
   table.insert(maxStep, math.random(seqlen))
end
offsets = torch.LongTensor(offsets)
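-- offsets holds a random starting position in the long sequence for each of the
-- batchSize samples; maxStep holds each sample's true length (1..seqlen). Steps
-- beyond maxStep[j] are zero-padded in the batches built below.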
-- training
for iteration = 1, maxIter do
   -- 1. create a sequence of seqlen time-steps
   local inputs, inputs_rev, targets = {}, {}, {}
   for step=1,seqlen do
      -- a batch of inputs
      inputs[step] = sequence:index(1, offsets)
      -- increment indices
      offsets:add(1)
      for j=1,batchSize do
         if offsets[j] > sequence:size(1) then
            offsets[j] = 1
         end
      end
      targets[step] = sequence:index(1, offsets)
      -- padding
      for j=1,batchSize do
         if step > maxStep[j] then
            inputs[step][j] = 0
            targets[step][j] = 0
         end
      end
   end
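   -- For example (illustration only): a sample j with maxStep[j] == 3 gets real tokens
   -- at steps 1..3 and zeros at steps 4..seqlen in both inputs and targets.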
   -- reverse
   for step=1,seqlen do
      inputs_rev[step] = torch.LongTensor(batchSize)
      for j=1,batchSize do
         if step <= maxStep[j] then
            inputs_rev[step][j] = inputs[maxStep[j]-step+1][j]
         else
            inputs_rev[step][j] = 0
         end
      end
   end
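   -- The reversed copy keeps the right-padding in place: e.g. a padded row
   -- (5,6,7,0,...,0) in inputs becomes (7,6,5,0,...,0) in inputs_rev.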
   -- 2. forward sequence through rnn
   rnn:zeroGradParameters()
   local outputs = rnn:forward({inputs, inputs_rev})
   local err = criterion:forward(outputs, targets)

   local correct = 0
   local total = 0
   for step=1,seqlen do
      local probs = outputs[step]
      local _, preds = probs:max(2)
      for j=1,batchSize do
         local cur_x = inputs[step][j]
         local cur_y = targets[step][j]
         local cur_t = preds[j][1]
         -- print(string.format("x=%d ; y=%d ; pred=%d", cur_x, cur_y, cur_t))
         if step <= maxStep[j] then
            if cur_y == cur_t then correct = correct + 1 end
            total = total + 1
         end
      end
   end
   local acc = correct*1.0/total
   print(string.format("Iteration %d ; NLL err = %f ; ACC = %.2f ", iteration, err, acc))
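   -- Accuracy is counted only over the unpadded steps (step <= maxStep[j]), mirroring
   -- what the masked criterion sees.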
   -- 3. backward sequence through rnn (i.e. backprop through time)
   local gradOutputs = criterion:backward(outputs, targets)
   local gradInputs = rnn:backward({inputs, inputs_rev}, gradOutputs)
   -- 4. update
   rnn:updateParameters(lr)
end
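-- To try it (assuming Torch7 and the 'rnn' package are installed):
--   th simple-bisequencer-network-variable.lua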