Skip to content

Commit

Permalink
Revised to_csr() for efficiency
Browse files Browse the repository at this point in the history
  • Loading branch information
myui committed Dec 11, 2024
1 parent 8038566 commit 63600a3
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions rtrec/utils/interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def __init__(self, min_value: int = -5, max_value: int = 10, decay_in_days: Opti
# Half-life decay in time: decay_rate = 1 - ln(2) / decay_in_days
# https://dl.acm.org/doi/10.1145/1099554.1099689
self.decay_rate = 1.0 - (math.log(2) / decay_in_days)
self.max_user_id = 0
self.max_item_id = 0

def get_decay_rate(self) -> Optional[float]:
"""
Expand Down Expand Up @@ -88,6 +90,8 @@ def add_interaction(self, user_id: int, item_id: int, tstamp: float, delta: floa
# Store the updated value with the current timestamp
self.interactions[user_id][item_id] = (new_value, tstamp)
self.all_item_ids.add(item_id)
self.max_user_id = max(self.max_user_id, user_id)
self.max_item_id = max(self.max_item_id, item_id)

def get_user_item_rating(self, user_id: int, item_id: int, default_rating: float = 0.0) -> float:
"""
Expand Down Expand Up @@ -181,34 +185,34 @@ def get_all_non_negative_items(self, user_id: int) -> List[int]:

def to_csr(self, select_users: Optional[List[int]] = None) -> csr_matrix:
    """
    Build a CSR matrix of the stored interactions (rows = users, columns = items).

    Each stored ``(rating, tstamp)`` pair is passed through
    ``self._apply_decay`` and the result is used as the matrix value.

    Args:
        select_users: If given, only interactions of these user ids are
            emitted. Ids with no stored interactions contribute nothing.
            If None, every user in ``self.interactions`` is emitted.

    Returns:
        A csr_matrix of shape ``(self.max_user_id + 1, self.max_item_id + 1)``,
        regardless of ``select_users``.
    """
    rows, cols, data = [], [], []

    if select_users is None:
        per_user = self.interactions.items()
    else:
        # De-duplicate while keeping order: the (data, (rows, cols))
        # constructor sums repeated coordinates, so a user id listed twice
        # would otherwise have its ratings doubled.
        # .get() avoids inserting empty entries for unknown users.
        per_user = (
            (user, self.interactions.get(user, {}))
            for user in dict.fromkeys(select_users)
        )

    for user, inner_dict in per_user:
        for item, (rating, tstamp) in inner_dict.items():
            rows.append(user)
            cols.append(item)
            data.append(self._apply_decay(rating, tstamp))

    # Ids are used directly as 0-based indices, so each dimension must be
    # (largest id + 1); a shape of exactly max id would reject the entry
    # stored at the largest id as out of bounds.
    shape = (self.max_user_id + 1, self.max_item_id + 1)
    return csr_matrix((data, (rows, cols)), shape=shape)

def to_csc(self, select_items: Optional[List[int]] = None) -> csc_matrix:
    """
    Build a CSC matrix of the stored interactions (rows = users, columns = items).

    Each stored ``(rating, tstamp)`` pair is passed through
    ``self._apply_decay`` and the result is used as the matrix value.

    Args:
        select_items: If given, only interactions whose item id is in this
            collection are emitted. If None, every interaction is emitted.

    Returns:
        A csc_matrix of shape ``(self.max_user_id + 1, self.max_item_id + 1)``,
        regardless of ``select_items``.
    """
    rows, cols, data = [], [], []

    # Build the membership set once so the per-interaction test is O(1)
    # instead of a linear scan of the list.
    wanted = None if select_items is None else set(select_items)

    for user, inner_dict in self.interactions.items():
        for item, (rating, tstamp) in inner_dict.items():
            if wanted is not None and item not in wanted:
                continue
            rows.append(user)
            cols.append(item)
            data.append(self._apply_decay(rating, tstamp))

    # Ids are used directly as 0-based indices, so each dimension must be
    # (largest id + 1); a shape of exactly max id would reject the entry
    # stored at the largest id as out of bounds.
    shape = (self.max_user_id + 1, self.max_item_id + 1)
    return csc_matrix((data, (rows, cols)), shape=shape)

0 comments on commit 63600a3

Please sign in to comment.