Skip to content

Commit

Permalink
undo the changes in RepeatNet
Browse files Browse the repository at this point in the history
  • Loading branch information
ken77921 committed Nov 13, 2023
1 parent dbe0f7a commit 4b72c3a
Showing 1 changed file with 50 additions and 42 deletions.
92 changes: 50 additions & 42 deletions recbole/model/sequential_recommender/repeatnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ class RepeatNet(SequentialRecommender):
input_type = InputType.POINTWISE

def __init__(self, config, dataset):

super(RepeatNet, self).__init__(config, dataset)

# load the dataset information
Expand All @@ -50,24 +49,29 @@ def __init__(self, config, dataset):
self.dropout_prob = config["dropout_prob"]

# define the layers and loss function
self.item_matrix = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
self.item_matrix = nn.Embedding(
self.n_items, self.embedding_size, padding_idx=0
)
self.gru = nn.GRU(self.embedding_size, self.hidden_size, batch_first=True)
self.repeat_explore_mechanism = Repeat_Explore_Mechanism(
self.device, hidden_size=self.hidden_size, seq_len=self.max_seq_length, dropout_prob=self.dropout_prob
self.device,
hidden_size=self.hidden_size,
seq_len=self.max_seq_length,
dropout_prob=self.dropout_prob,
)
self.repeat_recommendation_decoder = Repeat_Recommendation_Decoder(
self.device,
hidden_size=self.hidden_size,
seq_len=self.max_seq_length,
num_item=self.n_items,
dropout_prob=self.dropout_prob
dropout_prob=self.dropout_prob,
)
self.explore_recommendation_decoder = Explore_Recommendation_Decoder(
hidden_size=self.hidden_size,
seq_len=self.max_seq_length,
num_item=self.n_items,
device=self.device,
dropout_prob=self.dropout_prob
dropout_prob=self.dropout_prob,
)

self.loss_fct = F.nll_loss
Expand All @@ -76,7 +80,6 @@ def __init__(self, config, dataset):
self.apply(self._init_weights)

def _init_weights(self, module):

if isinstance(module, nn.Embedding):
xavier_normal_(module.weight.data)
elif isinstance(module, nn.Linear):
Expand All @@ -85,34 +88,45 @@ def _init_weights(self, module):
constant_(module.bias.data, 0)

def forward(self, item_seq, item_seq_len):

batch_seq_item_embedding = self.item_matrix(item_seq)
# batch_size * seq_len == embedding ==>> batch_size * seq_len * embedding_size

all_memory, _ = self.gru(batch_seq_item_embedding)
last_memory = self.gather_indexes(all_memory, item_seq_len - 1)
# all_memory: batch_size * item_seq * hidden_size
# last_memory: batch_size * hidden_size
timeline_mask = (item_seq == 0)
timeline_mask = item_seq == 0

self.repeat_explore = self.repeat_explore_mechanism.forward(all_memory=all_memory, last_memory=last_memory)
self.repeat_explore = self.repeat_explore_mechanism.forward(
all_memory=all_memory, last_memory=last_memory
)
# batch_size * 2
repeat_recommendation_decoder = self.repeat_recommendation_decoder.forward(
all_memory=all_memory, last_memory=last_memory, item_seq=item_seq, mask=timeline_mask
all_memory=all_memory,
last_memory=last_memory,
item_seq=item_seq,
mask=timeline_mask,
)
# batch_size * num_item
explore_recommendation_decoder = self.explore_recommendation_decoder.forward(
all_memory=all_memory, last_memory=last_memory, item_seq=item_seq, mask=timeline_mask
all_memory=all_memory,
last_memory=last_memory,
item_seq=item_seq,
mask=timeline_mask,
)
# batch_size * num_item
prediction = repeat_recommendation_decoder * self.repeat_explore[:, 0].unsqueeze(1) \
+ explore_recommendation_decoder * self.repeat_explore[:, 1].unsqueeze(1)
prediction = repeat_recommendation_decoder * self.repeat_explore[
:, 0
].unsqueeze(1) + explore_recommendation_decoder * self.repeat_explore[
:, 1
].unsqueeze(
1
)
# batch_size * num_item

return prediction

def calculate_loss(self, interaction):

item_seq = interaction[self.ITEM_SEQ]
item_seq_len = interaction[self.ITEM_SEQ_LEN]
pos_item = interaction[self.POS_ITEM_ID]
Expand All @@ -124,30 +138,33 @@ def calculate_loss(self, interaction):
return loss

def repeat_explore_loss(self, item_seq, pos_item):

batch_size = item_seq.size(0)
repeat, explore = torch.zeros(batch_size).to(self.device), torch.ones(batch_size).to(self.device)
repeat, explore = torch.zeros(batch_size).to(self.device), torch.ones(
batch_size
).to(self.device)
index = 0
for seq_item_ex, pos_item_ex in zip(item_seq, pos_item):
if pos_item_ex in seq_item_ex:
repeat[index] = 1
explore[index] = 0
index += 1
repeat_loss = torch.mul(repeat.unsqueeze(1), torch.log(self.repeat_explore[:, 0] + 1e-8)).mean()
explore_loss = torch.mul(explore.unsqueeze(1), torch.log(self.repeat_explore[:, 1] + 1e-8)).mean()
repeat_loss = torch.mul(
repeat.unsqueeze(1), torch.log(self.repeat_explore[:, 0] + 1e-8)
).mean()
explore_loss = torch.mul(
explore.unsqueeze(1), torch.log(self.repeat_explore[:, 1] + 1e-8)
).mean()

return (-repeat_loss - explore_loss) / 2

def full_sort_predict(self, interaction):

item_seq = interaction[self.ITEM_SEQ]
item_seq_len = interaction[self.ITEM_SEQ_LEN]
prediction = self.forward(item_seq, item_seq_len)

return prediction

def predict(self, interaction):

item_seq = interaction[self.ITEM_SEQ]
test_item = interaction[self.ITEM_ID]
item_seq_len = interaction[self.ITEM_SEQ_LEN]
Expand All @@ -161,7 +178,6 @@ def predict(self, interaction):


class Repeat_Explore_Mechanism(nn.Module):

def __init__(self, device, hidden_size, seq_len, dropout_prob):
super(Repeat_Explore_Mechanism, self).__init__()
self.dropout = nn.Dropout(dropout_prob)
Expand Down Expand Up @@ -202,7 +218,6 @@ def forward(self, all_memory, last_memory):


class Repeat_Recommendation_Decoder(nn.Module):

def __init__(self, device, hidden_size, seq_len, num_item, dropout_prob):
super(Repeat_Recommendation_Decoder, self).__init__()
self.dropout = nn.Dropout(dropout_prob)
Expand Down Expand Up @@ -233,21 +248,16 @@ def forward(self, all_memory, last_memory, item_seq, mask=None):
output_er.masked_fill_(mask, -1e9)

output_er = nn.Softmax(dim=-1)(output_er)
output_er = output_er.unsqueeze(1)

batch_size, b_len = item_seq.size()
repeat_recommendation_decoder = torch.zeros([batch_size, self.num_item], device=self.device)
repeat_recommendation_decoder.scatter_add_(1, item_seq, output_er) #(bsz, vocab_size) <- (bsz, context_item) x (bsz, context_item)
#output_er = output_er.unsqueeze(1)
#map_matrix = build_map(item_seq, self.device, max_index=self.num_item)
#output_er = torch.matmul(output_er, map_matrix).squeeze(1).to(self.device)

#repeat_recommendation_decoder = output_er.squeeze(1).to(self.device)
map_matrix = build_map(item_seq, self.device, max_index=self.num_item)
output_er = torch.matmul(output_er, map_matrix).squeeze(1).to(self.device)
repeat_recommendation_decoder = output_er.squeeze(1).to(self.device)

return repeat_recommendation_decoder.to(self.device)


class Explore_Recommendation_Decoder(nn.Module):

def __init__(self, hidden_size, seq_len, num_item, device, dropout_prob):
super(Explore_Recommendation_Decoder, self).__init__()
self.dropout = nn.Dropout(dropout_prob)
Expand All @@ -259,7 +269,9 @@ def __init__(self, hidden_size, seq_len, num_item, device, dropout_prob):
self.Ue = nn.Linear(hidden_size, hidden_size)
self.tanh = nn.Tanh()
self.Ve = nn.Linear(hidden_size, 1)
self.matrix_for_explore = nn.Linear(2 * self.hidden_size, self.num_item, bias=False)
self.matrix_for_explore = nn.Linear(
2 * self.hidden_size, self.num_item, bias=False
)

def forward(self, all_memory, last_memory, item_seq, mask=None):
"""
Expand Down Expand Up @@ -287,15 +299,11 @@ def forward(self, all_memory, last_memory, item_seq, mask=None):
output_e = torch.cat([output_e, last_memory_values], dim=1)
output_e = self.dropout(self.matrix_for_explore(output_e))

#The modification comes from https://github.com/iesl/softmax_CPR_recommend
#map_matrix = build_map(item_seq, self.device, max_index=self.num_item)
#explore_mask = torch.bmm((item_seq > 0).float().unsqueeze(1), map_matrix).squeeze(1)
#output_e = output_e.masked_fill(explore_mask.bool(), float('-inf'))
item_seq_first = item_seq[:,0].unsqueeze(1).expand_as(item_seq)
#item_seq_first[item_seq > 0] = 0 #make the padding 0 become the first item in the sequence
item_seq_first = item_seq_first.masked_fill(item_seq > 0, 0)
item_seq_first.requires_grad_(False)
output_e.scatter_add_(1, item_seq + item_seq_first, float('-inf') * torch.ones_like(item_seq))
map_matrix = build_map(item_seq, self.device, max_index=self.num_item)
explore_mask = torch.bmm(
(item_seq > 0).float().unsqueeze(1), map_matrix
).squeeze(1)
output_e = output_e.masked_fill(explore_mask.bool(), float("-inf"))
explore_recommendation_decoder = nn.Softmax(1)(output_e)

return explore_recommendation_decoder
Expand Down Expand Up @@ -335,6 +343,6 @@ def build_map(b_map, device, max_index=None):
b_map_ = torch.FloatTensor(batch_size, b_len, max_index).fill_(0).to(device)
else:
b_map_ = torch.zeros(batch_size, b_len, max_index)
b_map_.scatter_(2, b_map.unsqueeze(2), 1.)
b_map_.scatter_(2, b_map.unsqueeze(2), 1.0)
b_map_.requires_grad = False
return b_map_

0 comments on commit 4b72c3a

Please sign in to comment.