-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathtorch_rmtpp.py
117 lines (91 loc) · 4.95 KB
/
torch_rmtpp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
from torch import nn
import math
from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel
class RMTPP(TorchBaseModel):
    """Torch implementation of Recurrent Marked Temporal Point Processes, KDD 2016.

    https://www.kdd.org/kdd2016/papers/files/rpp1081-duA.pdf

    Suffixes on tensor names describe their axes:
    B = batch size, N = sequence length, H = hidden size,
    M = number of event types, G = number of time offsets per event.
    """

    # Upper bound on any returned intensity. The log-intensity is clamped at
    # log(MAX_INTENSITY) before exponentiation so exp() cannot overflow.
    # Subclasses may override this value.
    MAX_INTENSITY = 1e5

    def __init__(self, model_config):
        """Initialize the model.

        Args:
            model_config (EasyTPP.ModelConfig): config of model specs.
        """
        super().__init__(model_config)
        # NOTE(review): hidden_size, num_event_types and layer_type_emb are not
        # defined in this class; presumably set by TorchBaseModel.__init__.

        # Projects a scalar timestamp into the hidden space.
        self.layer_temporal_emb = nn.Linear(1, self.hidden_size)
        self.layer_rnn = nn.RNN(input_size=self.hidden_size, hidden_size=self.hidden_size,
                                num_layers=1, batch_first=True)
        # Maps a hidden state to one log-intensity term per event type.
        self.hidden_to_intensity_logits = nn.Linear(self.hidden_size, self.num_event_types)

        # Per-event-type bias (b_t) and elapsed-time coefficient (w_t) of the
        # log-intensity; both are re-initialised right after creation.
        self.b_t = nn.Parameter(torch.zeros(1, self.num_event_types))
        self.w_t = nn.Parameter(torch.zeros(1, self.num_event_types))
        nn.init.xavier_normal_(self.b_t)
        nn.init.xavier_normal_(self.w_t)

    def evolve_and_get_intentsity(self, right_hiddens_BNH, dts_BNG):
        """Eq.11 that computes intensity.

        Evaluates exp(min(v(h) + w_t * dt + b_t, log(MAX_INTENSITY))) for every
        hidden state h and every elapsed time dt paired with it.

        The method name keeps its original spelling so that existing callers
        continue to work.

        Args:
            right_hiddens_BNH (tensor): [..., hidden_size], hidden states right
                after each event.
            dts_BNG (tensor): [..., G], elapsed times since the matching event;
                leading axes must broadcast with those of ``right_hiddens_BNH``.

        Returns:
            tensor: [..., G, num_event_types], intensity per offset and event type.
        """
        # A singleton axis is inserted before the feature axis so the hidden
        # contribution broadcasts over the G time offsets.
        past_influence_BN1M = self.hidden_to_intensity_logits(right_hiddens_BNH[..., None, :])
        log_intensity_BNGM = (past_influence_BN1M
                              + self.w_t[None, None, :] * dts_BNG[..., None]
                              + self.b_t[None, None, :])
        # Clamp in log space, then exponentiate: the result is always positive
        # and never exceeds MAX_INTENSITY.
        intensity_BNGM = log_intensity_BNGM.clamp(max=math.log(self.MAX_INTENSITY)).exp()
        return intensity_BNGM

    def forward(self, batch):
        """Run the RNN over a batch and evaluate intensities at the next events.

        Suppose we have inputs with original sequence length N+1
        ts: [t0, t1, ..., t_N]
        dts: [0, t1 - t0, t2 - t1, ..., t_N - t_{N-1}]
        marks: [k0, k1, ..., k_N] (k0 and kN could be padded marks if t0 and tN
        correspond to left and right windows)

        Args:
            batch (tuple): 5-tuple (times, time deltas, marks, _, _); the last
                two entries are ignored here.

        Return:
            left limits of *intensity* at [t_1, ..., t_N] of shape:
                (batch_size, seq_len - 1, num_event_types)
            right limits of *hidden states* [t_0, ..., t_{N-1}, t_N] of shape:
                (batch_size, seq_len, hidden_dim)
            We need the right limit of t_N to sample continuation.
        """
        t_BN, dt_BN, marks_BN, _, _ = batch

        mark_emb_BNH = self.layer_type_emb(marks_BN)
        time_emb_BNH = self.layer_temporal_emb(t_BN[..., None])
        right_hiddens_BNH, _ = self.layer_rnn(mark_emb_BNH + time_emb_BNH)

        # Pair the hidden state after event i with the gap to event i+1. The gap
        # is given a singleton offset axis (G = 1), which is squeezed out again.
        dt_to_next_B_Nm1_1 = dt_BN[:, 1:][..., None]
        left_intensity_B_Nm1_M = self.evolve_and_get_intentsity(
            right_hiddens_BNH[:, :-1, :], dt_to_next_B_Nm1_1).squeeze(-2)
        return left_intensity_B_Nm1_M, right_hiddens_BNH

    def loglike_loss(self, batch):
        """Compute the log-likelihood loss.

        Args:
            batch (list): batch input; 5 entries (times, time deltas, marks,
                non-pad mask, _), of which the last is ignored.

        Returns:
            tuple: loglikelihood loss and num of events.
        """
        ts_BN, dts_BN, marks_BN, batch_non_pad_mask, _ = batch

        # compute left intensity and hidden states at event time
        # left limits of intensity at [t_1, ..., t_N]
        # right limits of hidden states at [t_0, ..., t_{N-1}, t_N]
        left_intensity_B_Nm1_M, right_hiddens_BNH = self.forward((ts_BN, dts_BN, marks_BN, None, None))
        right_hiddens_B_Nm1_H = right_hiddens_BNH[..., :-1, :]  # discard right limit at t_N for logL

        # Intensities at sampled offsets inside each inter-event interval,
        # passed below as ``lambdas_loss_samples``.
        dts_sample_B_Nm1_G = self.make_dtime_loss_samples(dts_BN[:, 1:])
        intensity_dts_B_Nm1_G_M = self.evolve_and_get_intentsity(right_hiddens_B_Nm1_H, dts_sample_B_Nm1_G)

        # The first position is dropped from every sequence-aligned input
        # because intensities are only available for t_1 .. t_N.
        event_ll, non_event_ll, num_events = self.compute_loglikelihood(
            lambda_at_event=left_intensity_B_Nm1_M,
            lambdas_loss_samples=intensity_dts_B_Nm1_G_M,
            time_delta_seq=dts_BN[:, 1:],
            seq_mask=batch_non_pad_mask[:, 1:],
            type_seq=marks_BN[:, 1:]
        )

        # compute loss to minimize
        loss = - (event_ll - non_event_ll).sum()
        return loss, num_events

    def compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_seqs, sample_dtimes, **kwargs):
        """Compute the intensity at sampled times, not only event times.

        Args:
            time_seqs (tensor): [batch_size, seq_len], times seqs.
            time_delta_seqs (tensor): [batch_size, seq_len], time delta seqs.
            type_seqs (tensor): [batch_size, seq_len], event type seqs.
            sample_dtimes (tensor): [batch_size, seq_len, num_sample], sampled inter-event timestamps.
            **kwargs: ``compute_last_step_only`` (bool, default False) restricts
                the computation to the final sequence position.

        Returns:
            tensor: [batch_size, num_times, num_mc_sample, num_event_types],
                intensity at each timestamp for each event type. ``num_times``
                is 1 when ``compute_last_step_only`` is set, else seq_len.
        """
        compute_last_step_only = kwargs.get('compute_last_step_only', False)

        _input = time_seqs, time_delta_seqs, type_seqs, None, None
        _, right_hiddens_BNH = self.forward(_input)

        if compute_last_step_only:
            # Slicing with -1: keeps the sequence axis, so the output stays 4-D.
            sampled_intensities = self.evolve_and_get_intentsity(right_hiddens_BNH[:, -1:, :], sample_dtimes[:, -1:, :])
        else:
            sampled_intensities = self.evolve_and_get_intentsity(right_hiddens_BNH, sample_dtimes)  # shape: [B, N, G, M]
        return sampled_intensities