tlstm.py
import torch
import torch.nn as nn

class TimeLSTM(nn.Module):
    """Time-aware LSTM: decays a short-term component of the cell state
    by the elapsed time between inputs. Assumes batch_first inputs."""

    def __init__(self, input_size, hidden_size, cuda_flag=False, bidirectional=False):
        super(TimeLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.cuda_flag = cuda_flag
        # W_all and U_all produce the four gate pre-activations
        # (forget, input, output, candidate) from h and x respectively.
        self.W_all = nn.Linear(hidden_size, hidden_size * 4)
        self.U_all = nn.Linear(input_size, hidden_size * 4)
        # W_d extracts the short-term part of the cell state that the
        # time decay is applied to.
        self.W_d = nn.Linear(hidden_size, hidden_size)
        self.bidirectional = bidirectional
    def forward(self, inputs, timestamps, lens, reverse=False):
        # inputs:     [b, seq, embed]
        # timestamps: [b, seq], elapsed time associated with each step
        # lens:       unused here; kept for interface compatibility
        # h, c:       [b, hid]
        b, seq, embed = inputs.size()
        h = torch.zeros(b, self.hidden_size, requires_grad=False)
        c = torch.zeros(b, self.hidden_size, requires_grad=False)
        if self.cuda_flag:
            h = h.cuda()
            c = c.cuda()
        outputs = []
        for s in range(seq):
            # Split the cell state: c_s1 is a short-term component, c_l the
            # remaining long-term component. The short-term part is scaled by
            # the elapsed time and recombined into the adjusted state c_adj.
            c_s1 = torch.tanh(self.W_d(c))
            c_s2 = c_s1 * timestamps[:, s:s + 1].expand_as(c_s1)
            c_l = c - c_s1
            c_adj = c_l + c_s2
            # Standard LSTM gating, applied to the time-adjusted cell state.
            outs = self.W_all(h) + self.U_all(inputs[:, s])
            f, i, o, c_tmp = torch.chunk(outs, 4, 1)
            f = torch.sigmoid(f)
            i = torch.sigmoid(i)
            o = torch.sigmoid(o)
            # Note: the candidate cell uses a sigmoid here, unlike the tanh of
            # a conventional LSTM; this preserves the original behavior.
            c_tmp = torch.sigmoid(c_tmp)
            c = f * c_adj + i * c_tmp
            h = o * torch.tanh(c)
            outputs.append(h)
        if reverse:
            outputs.reverse()
        outputs = torch.stack(outputs, 1)
        return outputs
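

# A minimal usage sketch with dummy data. The shapes and variable names below
# are illustrative assumptions, not values taken from the original repository.
if __name__ == "__main__":
    batch, seq_len, embed_dim, hidden_dim = 4, 10, 32, 64
    model = TimeLSTM(input_size=embed_dim, hidden_size=hidden_dim)
    x = torch.randn(batch, seq_len, embed_dim)  # [b, seq, embed]
    deltas = torch.rand(batch, seq_len)         # elapsed time per step
    lens = torch.full((batch,), seq_len)        # unused by forward()
    out = model(x, deltas, lens)                # [b, seq, hidden]
    print(out.shape)                            # torch.Size([4, 10, 64])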