Skip to content

Commit

Permalink
Fixed the way to handle user_tags and item_tags
Browse files Browse the repository at this point in the history
  • Loading branch information
myui committed Jan 9, 2025
1 parent 0b60efa commit a61973e
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 22 deletions.
12 changes: 6 additions & 6 deletions rtrec/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ def __init__(self, **kwargs: Any):

self.feature_store = FeatureStore()

def register_user_feature(self, user_id: Any, user_tags: List[str]) -> int:
def register_user_feature(self, user: Any, user_tags: List[str]) -> int:
"""
Register user features in the feature store.
:param user_id: User identifier
:param user: User to register features for
:param user_tags: List of user features
:return: User index
"""
user_id = self.user_ids.identify(user_id)
user_id = self.user_ids.identify(user)
self.feature_store.put_user_feature(user_id, user_tags)
return user_id

Expand All @@ -41,14 +41,14 @@ def clear_user_features(self, user_ids: Optional[List[int]] = None) -> None:
"""
self.feature_store.clear_user_features(user_ids)

def register_item_feature(self, item_id: Any, item_tags: List[str]) -> int:
def register_item_feature(self, item: Any, item_tags: List[str]) -> int:
"""
Register item features in the feature store.
:param item_id: Item identifier
:param item: Item to register features for
:param item_tags: List of item features
:return: Item index
"""
item_id = self.item_ids.identify(item_id)
item_id = self.item_ids.identify(item)
self.feature_store.put_item_feature(item_id, item_tags)
return item_id

Expand Down
40 changes: 25 additions & 15 deletions rtrec/recommender.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,29 @@ def partial_fit(self, user_interactions: Iterable[Tuple[int, int, int, float]],
def fit(
self,
train_data: pd.DataFrame,
user_tags: Optional[Dict[Any, List[str]]] = None,
item_tags: Optional[Dict[Any, List[str]]] = None,
batch_size: int = 1_000,
update_interaction: bool = False,
parallel: bool = False,
) -> Self:
"""
Fit the recommender model on the given DataFrame of interactions.
:param train_data (pd.DataFrame): The DataFrame containing interactions with columns (user, item, tstamp, rating).
:param user_tags (Optional[Dict[Any, List[str]]]): Dictionary mapping user IDs to user tags.
:param item_tags (Optional[Dict[Any, List[str]]]): Dictionary mapping item IDs to item tags.
:param batch_size (int): The number of interactions per mini-batch. Defaults to 1000.
:param update_interaction (bool): Whether to update existing interactions. Defaults to False.
:param parallel (bool): Whether to run the fitting process in parallel. Defaults to False.
"""
start_time = time.time()

# If train_data contains user_tags and item_tags columns, add them to the model
if "user_tags" in train_data.columns:
user_tags = train_data[["user", "user_tags"]]
for user, tags in user_tags.itertuples(index=False, name=None):
# register user and item features
if user_tags:
for user, tags in tqdm(user_tags.items(), desc="Register user features"):
self.model.register_user_feature(user, tags)
if "item_tags" in train_data.columns:
item_tags = train_data[["item", "item_tags"]]
for item, tags in item_tags.itertuples(index=False, name=None):
if item_tags:
for item, tags in tqdm(item_tags.items(), desc="Register item features"):
self.model.register_item_feature(item, tags)

# Add interactions to the model
Expand All @@ -65,24 +67,32 @@ def fit(
print(f"Throughput: {len(train_data) / (end_time - start_time):.2f} samples/sec")
return self

def bulk_fit(self, train_data: pd.DataFrame, batch_size: int = 1_000, update_interaction: bool=False, parallel: bool=True) -> Self:
def bulk_fit(
self,
train_data: pd.DataFrame,
user_tags: Optional[Dict[Any, List[str]]] = None,
item_tags: Optional[Dict[Any, List[str]]] = None,
batch_size: int = 1_000,
update_interaction: bool=False,
parallel: bool=True
) -> Self:
"""
Fit the recommender model on the given DataFrame of interactions in a single batch.
:param train_data (pd.DataFrame): The DataFrame containing interactions with columns (user, item, tstamp, rating).
:param user_tags (Optional[Dict[Any, List[str]]]): Dictionary mapping user IDs to user tags.
:param item_tags (Optional[Dict[Any, List[str]]]): Dictionary mapping item IDs to item tags.
:param batch_size (int): The number of interactions per mini-batch. Defaults to 1000.
:param update_interaction (bool): Whether to update existing interactions. Defaults to False.
:param parallel (bool): Whether to run the fitting process in parallel. Defaults to True.
"""
start_time = time.time()

# If train_data contains user_tags and item_tags columns, add them to the model
if "user_tags" in train_data.columns:
user_tags = train_data[["user", "user_tags"]]
for user, tags in user_tags.itertuples(index=False, name=None):
# register user and item features
if user_tags:
for user, tags in tqdm(user_tags.items(), desc="Register user features"):
self.model.register_user_feature(user, tags)
if "item_tags" in train_data.columns:
item_tags = train_data[["item", "item_tags"]]
for item, tags in item_tags.itertuples(index=False, name=None):
if item_tags:
for item, tags in tqdm(item_tags.items(), desc="Register item features"):
self.model.register_item_feature(item, tags)

# Add interactions to the model
Expand Down
2 changes: 1 addition & 1 deletion rtrec/utils/interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,6 @@ def shape(self) -> tuple[int, int]:
Retrieves the shape of the interaction matrix.
Returns:
tuple[int, int]: The shape of the interaction matrix.
tuple[int, int]: The shape of the interaction matrix of the form (n_users, n_items).
"""
return self.max_user_id + 1, self.max_item_id + 1

0 comments on commit a61973e

Please sign in to comment.