diff --git a/recbole/model/sequential_recommender/fearec.py b/recbole/model/sequential_recommender/fearec.py index f16385962..125fc7e43 100644 --- a/recbole/model/sequential_recommender/fearec.py +++ b/recbole/model/sequential_recommender/fearec.py @@ -210,13 +210,16 @@ def calculate_loss(self, interaction): lens = len(targets_index) if lens == 0: print("error") - while True: - sample_index = random.choice(targets_index) + remaining_indices = targets_index.copy() + while len(remaining_indices) > 0: + sample_index = random.choice(remaining_indices) + remaining_indices = remaining_indices[remaining_indices != sample_index] cur_item_list = interaction[self.ITEM_SEQ][i].to("cpu") sample_item_list = dataset.inter_feat[self.ITEM_SEQ][sample_index] are_equal = torch.equal(cur_item_list, sample_item_list) sample_item_length = dataset.inter_feat[self.ITEM_SEQ_LEN][sample_index] - if not are_equal or lens == 1: + + if not are_equal or len(remaining_indices) == 0: sem_pos_lengths.append(sample_item_length) sem_pos_seqs.append(sample_item_list) break