-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy pathtf_attnhp.py
341 lines (268 loc) · 14.4 KB
/
tf_attnhp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
import math
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import layers
from easy_tpp.model.tf_model.tf_baselayer import EncoderLayer
from easy_tpp.model.tf_model.tf_basemodel import TfBaseModel
from easy_tpp.utils.tf_utils import get_shape_list
if tf.__version__ >= '2.0':
tf = tf.compat.v1
tf.disable_v2_behavior()
class AttNHP(TfBaseModel):
def __init__(self, model_config):
super(AttNHP, self).__init__(model_config)
self.d_model = model_config.hidden_size
self.d_time = model_config.data_specs['time_emb_size']
self.use_norm = model_config.data_specs['use_ln']
self.n_layers = model_config.data_specs['num_layers']
self.n_head = model_config.data_specs['num_heads']
self.dropout = model_config.dropout_rate
# position vector, used for temporal encoding
self.div_term = self.make_div_term()
def make_div_term(self):
"""Initialize the division term used in temporal embedding.
Returns:
np.array: diviser.
"""
div_term_ = np.exp(np.arange(0, self.d_time, 2) * -(math.log(10000.0) / self.d_time))
div_term_ = np.reshape(div_term_, (1, 1, -1))
div_term = np.ones((1, 1, self.d_time))
for i in range(self.d_time):
div_term[..., i] = div_term_[..., i // 2]
return div_term
def build_graph(self):
"""Build up the network
"""
with tf.variable_scope('AttNHP'):
self.build_input_graph()
self.layer_intensity = layers.Dense(self.num_event_types, activation=tf.nn.softplus)
self.heads = []
for i in range(self.n_head):
self.heads.append([EncoderLayer(hidden_size=self.d_model,
num_heads=self.n_head,
dropout_rate=self.dropout) for _ in range(self.n_layers)])
self.loss, self.num_event = self.loglike_loss()
# Make predictions
if self.event_sampler and self.gen_config.num_step_gen == 1:
self.dtime_predict_one_step, self.type_predict_one_step = \
self.predict_one_step_at_every_event(self.time_seqs,
self.time_delta_seqs,
self.type_seqs)
if self.event_sampler and self.gen_config.num_step_gen > 1:
# make generations
self.dtime_generation, self.type_generation = \
self.predict_multi_step_since_last_event(self.time_seqs,
self.time_delta_seqs,
self.type_seqs,
num_step=self.gen_config.num_step_gen)
def compute_temporal_embedding(self, time_seqs):
"""Compute the temporal embedding.
Args:
time_seqs (tensor): [batch_size, seq_len].
Returns:
tensor: [batch_size, seq_len, emb_size].
"""
batch_size, seq_len = get_shape_list(time_seqs) # dynamic
# [self.d_time]
position_mask = np.array([1] * self.d_time)
position_mask[1::2] = 0 # dim 2i+1
position_mask = tf.convert_to_tensor(position_mask, tf.int32)
# [batch_size, max_len, d_time]
position_mask = tf.tile(position_mask[None, None, ...], [batch_size, seq_len, 1])
time_seqs_ = time_seqs[..., None]
position_enc = tf.where(tf.equal(position_mask, 0), # dim 2i+1
tf.cos(time_seqs_ * self.div_term), # dim 2i+1
tf.sin(time_seqs_ * self.div_term)) # dim 2i
# [batch_size, max_len, hidden_dim]
return position_enc
def seq_encoding(self, time_seqs, type_seqs):
"""Encode the sequence.
Args:
time_seqs (tensor): time seqs input, [batch_size, seq_len].
event_seqs (_type_): event type seqs input, [batch_size, seq_len].
Returns:
tuple: event embedding, time embedding and type embedding.
"""
# [batch_size, seq_len, hidden_size]
time_emb = self.compute_temporal_embedding(time_seqs)
# [batch_size, seq_len, hidden_size]
type_emb = tf.tanh(self.layer_type_emb(type_seqs))
# [batch_size, seq_len, hidden_size*2]
event_emb = tf.concat([type_emb, time_emb], axis=-1)
return event_emb, time_emb, type_emb
def make_layer_mask(self, attention_mask):
"""Create a tensor to do masking on layers.
Args:
attention_mask (tensor): mask for attention operation, [batch_size, seq_len, seq_len]
Returns:
tensor: aim to keep the current layer, the same size of attention mask
a diagonal matrix, [batch_size, seq_len, seq_len]
"""
# [batch_size, seq_len, seq_len]
layer_mask = tf.eye(get_shape_list(attention_mask)[1]) < 1
layer_mask = layer_mask[None, ...]
layer_mask = tf.tile(layer_mask, [get_shape_list(attention_mask)[0], 1, 1])
return tf.cast(layer_mask, tf.int32)
def make_combined_att_mask(self, attention_mask, layer_mask):
"""Combined attention mask and layer mask.
Args:
attention_mask (tensor): mask for attention operation, [batch_size, seq_len, seq_len]
layer_mask (tensor): mask for other layers, [batch_size, seq_len, seq_len]
Returns:
tensor: [batch_size, seq_len * 2, seq_len * 2]
"""
# [batch_size, seq_len, seq_len * 2]
combined_mask = tf.concat([attention_mask, layer_mask], axis=-1)
# [batch_size, seq_len, seq_len * 2]
contextual_mask = tf.concat([attention_mask, tf.ones_like(layer_mask)], axis=-1)
# [batch_size, seq_len * 2, seq_len * 2]
combined_mask = tf.concat([contextual_mask, combined_mask], axis=1)
return combined_mask
def forward_pass(self, init_cur_layer, time_emb, sample_time_emb, event_emb, combined_mask):
"""update the structure sequentially.
Args:
init_cur_layer (tensor): [batch_size, seq_len, hidden_size]
time_emb (tensor): [batch_size, seq_len, hidden_size]
sample_time_emb (tensor): [batch_size, seq_len, hidden_size]
event_emb (tensor): [batch_size, seq_len, hidden_size]
combined_mask (tensor): [batch_size, seq_len, hidden_size]
Returns:
tensor: [batch_size, seq_len, hidden_size*2]
"""
cur_layers = []
seq_len = get_shape_list(time_emb)[1]
for head_i in range(self.n_head):
# [batch_size, seq_len, hidden_size]
cur_layer_ = init_cur_layer
for layer_i in range(self.n_layers):
# each layer concats the temporal emb
# [batch_size, seq_len, hidden_size*2]
layer_ = tf.concat([cur_layer_, sample_time_emb], axis=-1)
# make combined input from event emb + layer emb
# [batch_size, seq_len*2, hidden_size*2]
_combined_input = tf.concat([event_emb, layer_], axis=1)
enc_layer = self.heads[head_i][layer_i]
# compute the output
enc_output = enc_layer((_combined_input, combined_mask),
training=self.is_training)
# the layer output
# [batch_size, seq_len, hidden_size]
_cur_layer_ = enc_output[:, seq_len:, :]
# add residual connection
cur_layer_ = tf.tanh(_cur_layer_) + cur_layer_
# event emb
event_emb = tf.concat([enc_output[:, :seq_len, :], time_emb], axis=-1)
if self.use_norm:
cur_layer_ = self.norm(cur_layer_)
cur_layers.append(cur_layer_)
cur_layer_ = tf.concat(cur_layers, axis=-1)
return cur_layer_
def forward(self, time_seqs, type_seqs, attention_mask, sample_times=None):
""" Move forward through the network """
# [batch_size, seq_len, hidden_size]
event_emb, time_emb, type_emb = self.seq_encoding(time_seqs, type_seqs)
init_cur_layer = tf.zeros_like(type_emb)
layer_mask = self.make_layer_mask(attention_mask)
if sample_times is None:
sample_time_emb = time_emb
else:
sample_time_emb = self.compute_temporal_embedding(sample_times)
combined_mask = self.make_combined_att_mask(attention_mask, layer_mask)
encoder_output = self.forward_pass(init_cur_layer, time_emb, sample_time_emb, event_emb, combined_mask)
return encoder_output
def loglike_loss(self):
# 1. compute event-loglik
enc_out = self.forward(self.time_seqs[:, :-1],
self.type_seqs[:, :-1],
self.attention_mask[:, 1:, :-1],
self.time_seqs[:, 1:])
# [batch_size, seq_len, num_event_types]
lambda_at_event = self.layer_intensity(enc_out)
# 2. compute non-event-loglik (using MC sampling to compute integral)
# 2.1 sample times
# [batch_size, seq_len, num_sample]
temp_time = self.make_dtime_loss_samples(self.time_delta_seqs[:, 1:])
# [batch_size, seq_len, num_sample]
sample_times = temp_time + self.time_seqs[:, :-1][..., None]
# 2.2 compute intensities at sampled times
# [batch_size, num_times = max_len - 1, num_sample, event_num]
lambda_t_sample = self.compute_intensities_at_sample_times(self.time_seqs[:, :-1],
self.time_delta_seqs[:, :-1], # not used
self.type_seqs[:, :-1],
sample_times,
attention_mask=self.attention_mask[:, 1:, :-1])
event_ll, non_event_ll, num_events = self.compute_loglikelihood(lambda_at_event=lambda_at_event,
lambdas_loss_samples=lambda_t_sample,
time_delta_seq=self.time_delta_seqs[:, 1:],
seq_mask=self.batch_non_pad_mask[:, 1:],
lambda_type_mask=self.type_mask[:, 1:])
# return enc_inten to compute accuracy
loss = - tf.reduce_sum(event_ll - non_event_ll)
return loss, num_events
def compute_states_at_sample_times(self,
time_seqs,
type_seqs,
attention_mask,
sample_times):
"""
Args:
time_seqs: [batch_size, seq_len]
type_seqs: [batch_size, seq_len]
attention_mask: [batch_size, seq_len, seq_len]
sample_times: [batch_size, seq_len, num_samples]
Returns:
hidden states at all sampled times: [batch_size, seq_len, num_samples, hidden_size]
"""
batch_size, seq_len = get_shape_list(type_seqs)
num_samples = get_shape_list(sample_times)[-1]
# [num_samples, batch_size, seq_len]
sample_times = tf.transpose(sample_times, perm=(2, 0, 1))
# [num_samples * batch_size, seq_len]
_sample_time = tf.reshape(sample_times, (num_samples * batch_size, -1))
# [num_samples * batch_size, seq_len]
_types = tf.reshape(tf.tile(type_seqs[None, ...], (num_samples, 1, 1)), (num_samples * batch_size, -1))
# [num_samples * batch_size, seq_len]
_times = tf.reshape(tf.tile(time_seqs[None, ...], (num_samples, 1, 1)), (num_samples * batch_size, -1))
# [num_samples * batch_size, seq_len]
_attn_mask = tf.tile(attention_mask[None, ...], (num_samples, 1, 1, 1))
_attn_mask = tf.reshape(_attn_mask, (num_samples * batch_size, seq_len, seq_len))
# [num_samples * batch_size, seq_len, hidden_size]
encoder_output = self.forward(_times,
_types,
_attn_mask,
_sample_time)
# [num_samples, batch_size, seq_len, hidden_size]
encoder_output = tf.reshape(encoder_output, (num_samples, batch_size, seq_len, -1))
# [batch_size, seq_len, num_samples, hidden_size]
encoder_output = tf.transpose(encoder_output, perm=(1, 2, 0, 3))
return encoder_output
def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_seqs, sample_times, **kwargs):
"""
Args:
time_seqs: [batch_size, seq_len]
time_delta_seqs: [batch_size, seq_len]
type_seqs: [batch_size, seq_len]
sample_times: [batch_size, seq_len, num_samples]
Returns:
intensities at sample times: [batch_size, seq_len, num_samples, num_event_types]
"""
attention_mask = kwargs.get('attention_mask', None)
compute_last_step_only = kwargs.get('compute_last_step_only', False)
if attention_mask is None:
batch_size, seq_len = get_shape_list(time_seqs)
attention_mask = tf.ones((seq_len, seq_len))
# only keep the strict upper triangular
lower_diag_masks = tf.linalg.LinearOperatorLowerTriangular(tf.ones_like(attention_mask)).to_dense()
attention_mask = tf.where(tf.equal(lower_diag_masks, 0),
attention_mask,
tf.zeros_like(attention_mask))
attention_mask = tf.tile(attention_mask[None, ...], (batch_size, 1, 1))
attention_mask = tf.cast(attention_mask, tf.int32)
# [batch_size, seq_len, num_samples, hidden_size]
encoder_output = self.compute_states_at_sample_times(time_seqs, type_seqs, attention_mask, sample_times)
if compute_last_step_only:
lambdas = self.layer_intensity(encoder_output[:, -1:, :, :])
else:
# [batch_size, seq_len, num_samples, num_event_types]
lambdas = self.layer_intensity(encoder_output)
return lambdas