From 01551fb7c7adb3553d33fa1aab0266d139a80aea Mon Sep 17 00:00:00 2001
From: iLampard
Date: Mon, 16 Sep 2024 23:39:21 +0800
Subject: [PATCH] Add Feedforward layer into THP model

---
 easy_tpp/model/torch_model/torch_baselayer.py |  6 +++++-
 easy_tpp/model/torch_model/torch_thp.py       | 13 +++++++++++--
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/easy_tpp/model/torch_model/torch_baselayer.py b/easy_tpp/model/torch_model/torch_baselayer.py
index 7dfd43b..01efab9 100644
--- a/easy_tpp/model/torch_model/torch_baselayer.py
+++ b/easy_tpp/model/torch_model/torch_baselayer.py
@@ -89,7 +89,11 @@ def forward(self, x, mask):
             else:
                 return x
         else:
-            return self.self_attn(x, x, x, mask)
+            x = self.self_attn(x, x, x, mask)
+            if self.feed_forward is not None:
+                return self.feed_forward(x)
+            else:
+                return x
 
 
 class TimePositionalEncoding(nn.Module):
diff --git a/easy_tpp/model/torch_model/torch_thp.py b/easy_tpp/model/torch_model/torch_thp.py
index 61adfe7..3a01ba5 100644
--- a/easy_tpp/model/torch_model/torch_thp.py
+++ b/easy_tpp/model/torch_model/torch_thp.py
@@ -27,8 +27,8 @@ def __init__(self, model_config):
 
         self.layer_temporal_encoding = TimePositionalEncoding(self.d_model, device=self.device)
 
-        self.factor_intensity_base = torch.empty([1, self.num_event_types], device=self.device)
-        self.factor_intensity_decay = torch.empty([1, self.num_event_types], device=self.device)
+        self.factor_intensity_base = nn.Parameter(torch.empty([1, self.num_event_types], device=self.device))
+        self.factor_intensity_decay = nn.Parameter(torch.empty([1, self.num_event_types], device=self.device))
         nn.init.xavier_normal_(self.factor_intensity_base)
         nn.init.xavier_normal_(self.factor_intensity_decay)
 
@@ -36,6 +36,14 @@ def __init__(self, model_config):
         self.layer_intensity_hidden = nn.Linear(self.d_model, self.num_event_types)
         self.softplus = nn.Softplus()
 
+        # Add MLP layer
+        # Equation (5)
+        self.feed_forward = nn.Sequential(
+            nn.Linear(self.d_model, self.d_model * 2),
+            nn.ReLU(),
+            nn.Linear(self.d_model * 2, self.d_model)
+        )
+
         self.stack_layers = nn.ModuleList(
             [EncoderLayer(
                 self.d_model,
@@ -43,6 +51,7 @@
                 output_linear=False),
 
                 use_residual=False,
+                feed_forward=self.feed_forward,
                 dropout=self.dropout
             ) for _ in range(self.n_layers)])
 