diff --git a/rtrec/recommender.py b/rtrec/recommender.py
index 70796be..daf1598 100644
--- a/rtrec/recommender.py
+++ b/rtrec/recommender.py
@@ -35,6 +35,7 @@ def fit(
         batch_size: int = 1_000,
         update_interaction: bool = False,
         parallel: bool = False,
+        assume_sorted: bool = True
     ) -> Self:
         """
         Fit the recommender model on the given DataFrame of interactions.
@@ -44,6 +45,7 @@ def fit(
         :param batch_size (int): The number of interactions per mini-batch. Defaults to 1000.
         :param update_interaction (bool): Whether to update existing interactions. Defaults to False.
         :param parallel (bool): Whether to run the fitting process in parallel. Defaults to False.
+        :param assume_sorted (bool): Whether the interactions are already sorted by timestamp. Defaults to True.
         """
 
         start_time = time.time()
@@ -57,6 +59,9 @@ def fit(
 
         # Add interactions to the model
         interaction_df = train_data[["user", "item", "tstamp", "rating"]]
+        if not assume_sorted:
+            # sort interactions by timestamp in ascending order
+            interaction_df.sort_values("tstamp", ascending=True, inplace=True)
         total = math.ceil(len(interaction_df) / batch_size)
         for batch in tqdm(generate_batches(interaction_df, batch_size, as_generator=self.use_generator), total=total, desc="Add interactions"):
             self.model.add_interactions(batch, update_interaction=update_interaction, record_interactions=True)
@@ -74,7 +79,8 @@ def bulk_fit(
         item_tags: Optional[Dict[Any, List[str]]] = None,
         batch_size: int = 1_000,
         update_interaction: bool=False,
-        parallel: bool=True
+        parallel: bool=True,
+        assume_sorted: bool=True
     ) -> Self:
         """
         Fit the recommender model on the given DataFrame of interactions in a single batch.
@@ -84,6 +90,7 @@ def bulk_fit(
         :param batch_size (int): The number of interactions per mini-batch. Defaults to 1000.
         :param update_interaction (bool): Whether to update existing interactions. Defaults to False.
         :param parallel (bool): Whether to run the fitting process in parallel. Defaults to True.
+        :param assume_sorted (bool): Whether the interactions are already sorted by timestamp. Defaults to True.
         """
 
         start_time = time.time()
@@ -97,6 +104,9 @@ def bulk_fit(
 
         # Add interactions to the model
        interaction_df = train_data[["user", "item", "tstamp", "rating"]]
+        if not assume_sorted:
+            # sort interactions by timestamp in ascending order, in place
+            interaction_df.sort_values("tstamp", ascending=True, inplace=True)
         total = math.ceil(len(interaction_df) / batch_size)
         for batch in tqdm(generate_batches(interaction_df, batch_size, as_generator=self.use_generator), total=total, desc="Add interactions"):
             self.model.add_interactions(batch, update_interaction=update_interaction)
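
As context for the new `assume_sorted` flag, the sketch below reproduces the sort-then-batch behavior with plain pandas. The toy DataFrame, the batch size, and the loop variable names are illustrative assumptions for the demo; only the column names (`"user"`, `"item"`, `"tstamp"`, `"rating"`), the ascending sort on `"tstamp"`, and the mini-batching step come from the patch itself.

```python
import math

import pandas as pd

# Toy interactions, deliberately out of timestamp order (illustrative data only).
interaction_df = pd.DataFrame({
    "user":   [1, 1, 2, 2],
    "item":   [10, 11, 10, 12],
    "tstamp": [1_700_000_300, 1_700_000_100, 1_700_000_200, 1_700_000_050],
    "rating": [5.0, 3.0, 4.0, 2.0],
})

assume_sorted = False
if not assume_sorted:
    # Same sort the patch applies before adding interactions.
    interaction_df.sort_values("tstamp", ascending=True, inplace=True)

# Mini-batching as in fit()/bulk_fit(); a batch_size of 2 is just for the demo.
batch_size = 2
total = math.ceil(len(interaction_df) / batch_size)
for i in range(total):
    batch = interaction_df.iloc[i * batch_size:(i + 1) * batch_size]
    print(batch)  # each batch now respects chronological order
```

In actual use, a caller whose interaction rows are not already in chronological order would pass `assume_sorted=False` to `fit()` or `bulk_fit()`; with the default `True`, no extra sort is performed, so already-ordered data pays no additional cost.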