Skip to content

Commit

Permalink
Added assume_sorted argument
Browse files Browse the repository at this point in the history
  • Loading branch information
myui committed Jan 31, 2025
1 parent e933309 commit 91cc3e2
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion rtrec/recommender.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def fit(
batch_size: int = 1_000,
update_interaction: bool = False,
parallel: bool = False,
assume_sorted: bool = True
) -> Self:
"""
Fit the recommender model on the given DataFrame of interactions.
Expand All @@ -44,6 +45,7 @@ def fit(
:param batch_size (int): The number of interactions per mini-batch. Defaults to 1000.
:param update_interaction (bool): Whether to update existing interactions. Defaults to False.
:param parallel (bool): Whether to run the fitting process in parallel. Defaults to False.
:param assume_sorted (bool): Whether the interactions are already sorted by timestamp. If False, they are sorted by "tstamp" in ascending order before being added. Defaults to True.
"""
start_time = time.time()

Expand All @@ -57,6 +59,9 @@ def fit(

# Add interactions to the model
interaction_df = train_data[["user", "item", "tstamp", "rating"]]
if not assume_sorted:
# sort interactions by timestamp in ascending order
interaction_df.sort_values("tstamp", ascending=True, inplace=True)
total = math.ceil(len(interaction_df) / batch_size)
for batch in tqdm(generate_batches(interaction_df, batch_size, as_generator=self.use_generator), total=total, desc="Add interactions"):
self.model.add_interactions(batch, update_interaction=update_interaction, record_interactions=True)
Expand All @@ -74,7 +79,8 @@ def bulk_fit(
item_tags: Optional[Dict[Any, List[str]]] = None,
batch_size: int = 1_000,
update_interaction: bool=False,
parallel: bool=True
parallel: bool=True,
assume_sorted: bool=True
) -> Self:
"""
Fit the recommender model on the given DataFrame of interactions in a single batch.
Expand All @@ -84,6 +90,7 @@ def bulk_fit(
:param batch_size (int): The number of interactions per mini-batch. Defaults to 1000.
:param update_interaction (bool): Whether to update existing interactions. Defaults to False.
:param parallel (bool): Whether to run the fitting process in parallel. Defaults to True.
:param assume_sorted (bool): Whether the interactions are already sorted by timestamp. If False, they are sorted by "tstamp" in ascending order before being added. Defaults to True.
"""
start_time = time.time()

Expand All @@ -97,6 +104,9 @@ def bulk_fit(

# Add interactions to the model
interaction_df = train_data[["user", "item", "tstamp", "rating"]]
if not assume_sorted:
# sort interactions by timestamp in ascending order, in place
interaction_df.sort_values("tstamp", ascending=True, inplace=True)
total = math.ceil(len(interaction_df) / batch_size)
for batch in tqdm(generate_batches(interaction_df, batch_size, as_generator=self.use_generator), total=total, desc="Add interactions"):
self.model.add_interactions(batch, update_interaction=update_interaction)
Expand Down

0 comments on commit 91cc3e2

Please sign in to comment.