From bf5cb0f3cdea0a141fa4ec13cd70b437e86c9f8a Mon Sep 17 00:00:00 2001 From: TayTroye <1582706091@qq.com> Date: Fri, 22 Mar 2024 14:33:02 +0800 Subject: [PATCH] fix fearec --- recbole/model/sequential_recommender/fearec.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/recbole/model/sequential_recommender/fearec.py b/recbole/model/sequential_recommender/fearec.py index f16385962..d44def528 100644 --- a/recbole/model/sequential_recommender/fearec.py +++ b/recbole/model/sequential_recommender/fearec.py @@ -210,13 +210,16 @@ def calculate_loss(self, interaction): lens = len(targets_index) if lens == 0: print("error") - while True: - sample_index = random.choice(targets_index) + remaining_indices = targets_index.copy() + while len(remaining_indices) > 0: + sample_index = random.choice(remaining_indices) + remaining_indices = remaining_indices[remaining_indices != sample_index] cur_item_list = interaction[self.ITEM_SEQ][i].to("cpu") sample_item_list = dataset.inter_feat[self.ITEM_SEQ][sample_index] are_equal = torch.equal(cur_item_list, sample_item_list) sample_item_length = dataset.inter_feat[self.ITEM_SEQ_LEN][sample_index] - if not are_equal or lens == 1: + + if not are_equal or len(remaining_indices) == 0: sem_pos_lengths.append(sample_item_length) sem_pos_seqs.append(sample_item_list) break