Skip to content

Commit

Permalink
Revised to filter row at to_csr()
Browse files Browse the repository at this point in the history
  • Loading branch information
myui committed Dec 10, 2024
1 parent 97c5951 commit 37fb69c
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
3 changes: 1 addition & 2 deletions rtrec/models/internal/slim_elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,9 +358,8 @@ def recommend(self, user_id: int, interaction_matrix: sp.csr_matrix, candidate_i
# Get the top-K items by sorting the predicted scores in descending order
# [::-1] reverses the order to get the items with the highest scores first
top_items = np.argsort(user_scores)[-top_k:][::-1]

return top_items

def similar_items(self, item_id: int, top_k: int=10):
"""
Get the top-K most similar items to a given item.
Expand Down
2 changes: 1 addition & 1 deletion rtrec/models/slim.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _recommend(self, user_id: int, candidate_item_ids: List[int], top_k: int = 1
:param filter_interacted: Whether to filter out items the user has already interacted with
:return: List of top-K item indices recommended for the user
"""
interaction_matrix = self.interactions.to_csr()
interaction_matrix = self.interactions.to_csr(select_users=[user_id])
return self.model.recommend(user_id, interaction_matrix, candidate_item_ids, top_k=top_k, filter_interacted=filter_interacted)

def _similar_items(self, query_item_id: int, top_k: int = 10) -> List[int]:
Expand Down
8 changes: 5 additions & 3 deletions rtrec/utils/interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,17 +171,19 @@ def get_all_non_negative_items(self, user_id: int) -> List[int]:
return [item_id for item_id in self.all_item_ids
if self.get_user_item_rating(user_id, item_id, default_rating=0.0) >= 0.0]

def to_csr(self) -> csr_matrix:
def to_csr(self, select_users: List[int] = None) -> csr_matrix:
rows, cols, data = [], [], []
max_row, max_col = 0, 0

for user, inner_dict in self.interactions.items():
for item, (rating, tstamp) in inner_dict.items():
max_row = max(max_row, user)
max_col = max(max_col, item)
if select_users is not None and user not in select_users:
continue
rows.append(user)
cols.append(item)
data.append(self._apply_decay(rating, tstamp))
max_row = max(max_row, user)
max_col = max(max_col, item)

# Create the csr_matrix
return csr_matrix((data, (rows, cols)), shape=(max_row + 1, max_col + 1))
Expand Down

0 comments on commit 37fb69c

Please sign in to comment.