-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathRNNPP.py
209 lines (159 loc) · 10.2 KB
/
RNNPP.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import numpy as np
import matplotlib,sys
matplotlib.use('agg')
import tensorflow as tf
from BatchIterator import PaddedDataIterator
from generation import *
##############################################################################
# parameters
BATCH_SIZE = 256 # Batch size
MAX_STEPS = 300 # maximum length of your sequence
ITERS = 30000 # how many generator iterations to train for
REG = 0.1 # tradeoff between time and mark loss
LR = 1e-4 # learning rate
TYPE = sys.argv[1] # model type: joint event timeseries
NUM_steps_timeseries = 7 # timeseries steps before one event
Timeseries_feature = 4 # time series feature size
SEED = 12345 # set graph-level seed to make the random sequences generated by all ops be repeatable across sessions
tf.set_random_seed(SEED)
np.random.seed(SEED)
##############################################################################
# prepare data
# This is just toy data to exercise the code: a homogeneous marked Poisson
# process with DIM_SIZE marks, each with unit intensity, simulated on [0, 15].
DIM_SIZE = 7 # equal num of classes
mi = MarkedIntensityHomogenuosPoisson(DIM_SIZE)
for u in range(DIM_SIZE):
    mi.initialize(1.0, u)
# 1000 simulated sequences; iterator pads them and (MARK, DIFF) yields
# (mark, inter-event time) pairs per step.
simulated_sequences = generate_samples_marked(mi, 15.0, 1000)
event_iterator = PaddedDataIterator(simulated_sequences,0,MARK=True,DIFF=True)
# time_series_data = np.ones((BATCH_SIZE,real_batch[0].shape[1],NUM_steps_timeseries,4))
###############################################################################
# define model
def RNNPP(rnn_inputs_event, # batch_size x num_steps x input_size(mark&time), pad with 0 if variable length
          rnn_inputs_timeseries, # batch_size x num_steps x num_steps_timeseries x input_size, pad with 0 if variable length
          seqlen, # sequence length for each sequence, tf.int32 vector
          lower_triangular_ones, # lower triangular ones matrix, used to build the length mask
          num_classes=7, # number of mark classes for the event sequence
          loss='mse', # loss type for time: 'mse' or 'intensity' (intensity loss follows Du et al., KDD'16)
          start=3, # predict forward starting at the start-th event of each sequence
          reg=REG, # loss trade-off between mark and time
          state_size_event=16, # RNN state size for event sequence
          state_size_timeseries=32, # RNN state size for time series
          batch_size=BATCH_SIZE,
          scope_reuse=False):
    """Build the RNN point-process loss graph.

    Depending on the module-level TYPE ('joint', 'event' or 'timeseries'),
    runs an RNN over the event sequence, over the per-event time-series
    window, or both (concatenated), then predicts the next inter-event time
    (linear head) and the next mark (softmax head).

    Returns:
        (total_loss, mark_loss, time_loss) — scalar tensors; time_loss is
        recovered as total_loss - mark_loss, so it carries the `reg` factor.
    """
    epsilon = tf.constant(1e-3, tf.float32)  # lower bound on |w_t| to keep the intensity loss numerically stable
    with tf.variable_scope("RNNPP") as scope:
        if scope_reuse:
            scope.reuse_variables()
        num_steps = tf.shape(rnn_inputs_event)[1]
        event_size = tf.shape(rnn_inputs_event)[2]
        # Targets: the next event's (mark, dt) at every step. The last step
        # wraps around to the first event; it is masked out by seqlen_mask below.
        y = tf.concat([rnn_inputs_event[:, 1:, :], rnn_inputs_event[:, :1, :]], axis=1)
        y = tf.reshape(y, [-1, event_size])
        if TYPE == 'joint' or TYPE == 'event':
            with tf.variable_scope("event") as scope:
                # RNN over the event sequence; marks are one-hot encoded.
                rnn_input_onehot = tf.one_hot(tf.cast(rnn_inputs_event[:, :, 0], tf.int32), num_classes)  # when num_classes is large, use tf embedding
                rnn_inputs_event = tf.concat([rnn_input_onehot, rnn_inputs_event[:, :, 1:]], axis=2)
                cell = tf.contrib.rnn.BasicRNNCell(state_size_event)  # tf.contrib.rnn.LSTMCell(state_size, state_is_tuple=True) is preferred
                init_state = cell.zero_state(batch_size, tf.float32)
                rnn_outputs_event, final_state = tf.nn.dynamic_rnn(cell, rnn_inputs_event, sequence_length=seqlen, initial_state=init_state)
        if TYPE == 'joint' or TYPE == 'timeseries':
            with tf.variable_scope("timeseries") as scope:
                # RNN over the time-series window preceding each event; each
                # (sequence, step) pair is treated as an independent sub-sequence.
                cell = tf.contrib.rnn.BasicRNNCell(state_size_timeseries)  # tf.contrib.rnn.LSTMCell(state_size, state_is_tuple=True) is preferred
                init_state = cell.zero_state(batch_size * num_steps, tf.float32)
                rnn_inputs_timeseries = tf.reshape(rnn_inputs_timeseries, [-1, NUM_steps_timeseries, Timeseries_feature])
                rnn_outputs_timeseries, final_state = tf.nn.dynamic_rnn(cell, rnn_inputs_timeseries, initial_state=init_state)
                # Keep only the final hidden state of each window.
                rnn_outputs_timeseries = tf.reshape(rnn_outputs_timeseries[:, -1, :], [batch_size, num_steps, state_size_timeseries])
        # reshape rnn_outputs
        if TYPE == 'joint':
            rnn_outputs = tf.concat([rnn_outputs_event, rnn_outputs_timeseries], axis=2)
        elif TYPE == 'event':
            rnn_outputs = rnn_outputs_event
        elif TYPE == 'timeseries':
            rnn_outputs = rnn_outputs_timeseries
        rnn_outputs_shape = tf.shape(rnn_outputs)
        rnn_outputs = tf.reshape(rnn_outputs, [-1, rnn_outputs_shape[-1]])
        # linear layer for the time prediction
        with tf.variable_scope('time'):
            if TYPE == 'joint':
                W_t = tf.get_variable('Wt', [state_size_event + state_size_timeseries, 1])
            elif TYPE == 'event':
                W_t = tf.get_variable('Wt', [state_size_event, 1])
            elif TYPE == 'timeseries':
                W_t = tf.get_variable('Wt', [state_size_timeseries, 1])
            w_t = tf.get_variable('wt', [1], initializer=tf.constant_initializer(1.0))
            b_t = tf.get_variable('bt', [1], initializer=tf.constant_initializer(0.0))
            if loss == 'intensity':
                # BUG FIX: the original lambdas referenced the not-yet-defined
                # local `wt` (NameError at graph construction); clamp the
                # variable w_t away from zero instead, preserving its sign.
                wt = tf.cond(tf.less(tf.reshape(tf.abs(w_t), []), epsilon),
                             lambda: tf.sign(w_t) * epsilon,
                             lambda: w_t)
                part1 = tf.matmul(rnn_outputs, W_t) + b_t
                part2 = wt * y[:, 1]
                # Negative log-likelihood of the exponential-form intensity (Du et al., KDD'16).
                time_loglike = part1 + part2 + (tf.exp(part1) - tf.exp(part1 + part2)) / wt
                time_loss = -time_loglike
            elif loss == 'mse':
                # Despite the name 'mse', this is an absolute-error (L1) loss.
                time_hat = tf.matmul(rnn_outputs, W_t) + b_t
                time_loss = tf.abs(tf.reshape(time_hat, [-1]) - y[:, 1])
        # Softmax layer for the mark prediction
        with tf.variable_scope('softmax'):
            if TYPE == 'joint':
                W_l = tf.get_variable('Wl', [state_size_event + state_size_timeseries, num_classes])
            elif TYPE == 'event':
                W_l = tf.get_variable('Wl', [state_size_event, num_classes])
            elif TYPE == 'timeseries':
                W_l = tf.get_variable('Wl', [state_size_timeseries, num_classes])
            b_l = tf.get_variable('bl', [num_classes], initializer=tf.constant_initializer(0.0))
            mark_logits = tf.matmul(rnn_outputs, W_l) + b_l
            mark_true = tf.one_hot(tf.cast(y[:, 0], tf.int32), num_classes)
            mark_loss = tf.nn.softmax_cross_entropy_with_logits(logits=mark_logits, labels=mark_true)
        total_loss = mark_loss + reg * time_loss
        # Mask out padded steps and the first `start` events; gather row
        # (seqlen-2) of the lower-triangular ones to drop the last (wrapped) target.
        seqlen_mask = tf.slice(tf.gather(lower_triangular_ones, seqlen - 2), [0, start], [batch_size, num_steps - start])
        zeros_pad = tf.zeros([batch_size, start])
        seqlen_mask = tf.concat([zeros_pad, seqlen_mask], axis=1)
        mark_loss = tf.reshape(mark_loss, [batch_size, num_steps])
        mark_loss *= seqlen_mask
        # Sum over steps, average over the batch.
        mark_loss = tf.reduce_sum(mark_loss, axis=1)
        mark_loss = tf.reduce_mean(mark_loss)
        total_loss = tf.reshape(total_loss, [batch_size, num_steps])
        total_loss *= seqlen_mask
        # Sum over steps, average over the batch.
        total_loss = tf.reduce_sum(total_loss, axis=1)
        total_loss = tf.reduce_mean(total_loss)
        time_loss = total_loss - mark_loss  # NOTE: includes the `reg` factor
    return total_loss, mark_loss, time_loss
# Placeholders: event sequence carries (mark, inter-event time) pairs per step.
event_sequence = tf.placeholder(tf.float32, shape=[BATCH_SIZE, None, 2])
time_series = tf.placeholder(tf.float32, shape=[BATCH_SIZE, None, NUM_steps_timeseries, Timeseries_feature])
seqlen = tf.placeholder(tf.int32, shape=[BATCH_SIZE])  # true (unpadded) length of each sequence
# Lower-triangular ones matrix used inside RNNPP to build the length mask.
lower_triangular_ones = tf.constant(np.tril(np.ones([MAX_STEPS, MAX_STEPS])), dtype=tf.float32)
total_loss, mark_loss, time_loss = RNNPP(event_sequence, time_series, seqlen, lower_triangular_ones)
train_variables = tf.trainable_variables()
joint_variables = [v for v in train_variables if v.name.startswith("RNNPP")]
# BUG FIX: print(map(...)) prints an opaque map object on Python 3;
# materialize the names as a list so they are actually shown.
print([v.op.name for v in joint_variables])
train_op = tf.train.RMSPropOptimizer(learning_rate=LR).minimize(total_loss, var_list=joint_variables)
##################################################################################
# run
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1.0, allow_growth=True)
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options))
sess.run(tf.global_variables_initializer())
# train — single consolidated loop; the original duplicated the whole body
# across three near-identical per-TYPE branches.
for it in range(ITERS):
    real_batch = event_iterator.next_batch(BATCH_SIZE)
    feed = {event_sequence: real_batch[0], seqlen: real_batch[1]}
    if TYPE in ('joint', 'timeseries'):
        # Toy constant covariates; use Timeseries_feature instead of the
        # previously hard-coded 4 so it tracks the declared feature size.
        feed[time_series] = np.ones((BATCH_SIZE, real_batch[0].shape[1],
                                     NUM_steps_timeseries, Timeseries_feature))
    total_loss_curr, mark_loss_curr, time_loss_curr, _ = sess.run(
        [total_loss, mark_loss, time_loss, train_op], feed_dict=feed)
    print('Iter: {}; Total loss: {:.4}; Mark loss: {:.4}; Time loss: {:.4}'.format(
        it, total_loss_curr, mark_loss_curr, time_loss_curr))