From 08be7158685e53607f62543b06d75a8818fe2e81 Mon Sep 17 00:00:00 2001 From: mkhe93 Date: Sat, 30 Nov 2024 14:07:28 +0100 Subject: [PATCH 1/5] [patch] implemented leave-k-out split --- recbole/data/dataset/dataset.py | 72 +++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/recbole/data/dataset/dataset.py b/recbole/data/dataset/dataset.py index 35fce89c6..67ae9a728 100644 --- a/recbole/data/dataset/dataset.py +++ b/recbole/data/dataset/dataset.py @@ -1729,6 +1729,74 @@ def leave_one_out(self, group_by, leave_one_mode): next_ds = [self.copy(_) for _ in next_df] return next_ds + def _split_index_by_leave_k_out(self, grouped_index, leave_k_num, k): + """Split indexes by strategy leave one out. + + Args: + grouped_index (list of list of int): Index to be split. + leave_k_num (int): Number of parts whose length is expected to be ``1``. + + Returns: + list: List of index that has been split. + """ + #print(list(grouped_index)[0]) + next_index = [[] for _ in range(leave_k_num + 1)] + for index in grouped_index: + index = list(index) + tot_cnt = len(index) + legal_leave_k_num = min(leave_k_num, tot_cnt - 1) + pr = tot_cnt - k + next_index[0].extend(index[:pr]) + for i in range(legal_leave_k_num): + next_index[-legal_leave_k_num + i].extend(index[pr:]) + pr += 1 + #print(next_index[0][:len(list(grouped_index)[0])]) + return next_index + + def leave_k_out(self, group_by, leave_k_mode, k): + """Split interaction records by leave k out strategy. + + Args: + group_by (str): Field name that interaction records should grouped by before splitting. + leave_k_mode (str): The way to leave one out. It can only take three values: + 'valid_and_test', 'valid_only' and 'test_only'. + + Returns: + list: List of :class:`~Dataset`, whose interaction features has been split. + """ + self.logger.debug( + f"leave k out, group_by=[{group_by}], leave_k_mode=[{leave_k_mode}]" + ) + if group_by is None: + raise ValueError("leave one out strategy require a group field") + + grouped_inter_feat_index = self._grouped_index( + self.inter_feat[group_by].numpy() + ) + if leave_k_mode == "valid_and_test": + next_index = self._split_index_by_leave_k_out( + grouped_inter_feat_index, leave_k_num=2, k=k + ) + elif leave_k_mode == "valid_only": + next_index = self._split_index_by_leave_k_out( + grouped_inter_feat_index, leave_k_num=1, k=k + ) + next_index.append([]) + elif leave_k_mode == "test_only": + next_index = self._split_index_by_leave_k_out( + grouped_inter_feat_index, leave_k_num=1, k=k + ) + next_index = [next_index[0], [], next_index[1]] + else: + raise NotImplementedError( + f"The leave_one_mode [{leave_k_mode}] has not been implemented." + ) + + self._drop_unused_col() + next_df = [self.inter_feat[index] for index in next_index] + next_ds = [self.copy(_) for _ in next_df] + return next_ds + def shuffle(self): """Shuffle the interaction records inplace.""" self.inter_feat.shuffle() @@ -1799,6 +1867,10 @@ def build(self): datasets = self.leave_one_out( group_by=self.uid_field, leave_one_mode=split_args["LS"] ) + elif split_mode == "LK": + datasets = self.leave_k_out( + group_by=self.uid_field, leave_k_mode=split_args["LK"][0], k=split_args["LK"][1] + ) else: raise NotImplementedError( f"The splitting_method [{split_mode}] has not been implemented." From c0b7908be5ac52bddf29d2bfdee0b3175eff38cd Mon Sep 17 00:00:00 2001 From: mkhe93 Date: Sat, 30 Nov 2024 14:16:40 +0100 Subject: [PATCH 2/5] docu for leave-k-out --- docs/source/user_guide/config/evaluation_settings.rst | 2 +- docs/source/user_guide/train_eval_intro.rst | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/user_guide/config/evaluation_settings.rst b/docs/source/user_guide/config/evaluation_settings.rst index a3e7b16c9..17263dc53 100644 --- a/docs/source/user_guide/config/evaluation_settings.rst +++ b/docs/source/user_guide/config/evaluation_settings.rst @@ -12,7 +12,7 @@ Evaluation settings are designed to set parameters about model evaluation. - ``order (str)``: decides how we sort the data in `.inter`. Now we support two kinds of ordering strategies: ``['RO', 'TO']``, which denotes the random ordering and temporal ordering. For ``RO``, we will shuffle the data and then split them in this order. For ``TO``, we will sort the data by the column of `TIME_FIELD` in ascending order and the split them in this order. The default value is ``RO``. - - ``split (dict)``: decides how we split the data in `.inter`. Now we support two kinds of splitting strategies: ``['RS','LS']``, which denotes the ratio-based data splitting and leave-one-out data splitting. If the key of ``split`` is ``RS``, you need to set the splitting ratio like ``[0.8,0.1,0.1]``, ``[7,2,1]`` or ``[8,0,2]``, which denotes the ratio of training set, validation set and testing set respectively. If the key of ``split`` is ``LS``, now we support three kinds of ``LS`` mode: ``['valid_and_test', 'valid_only', 'test_only']`` and you should choose one mode as the value of ``LS``. The default value of ``split`` is ``{'RS': [0.8,0.1,0.1]}``. + - ``split (dict)``: decides how we split the data in `.inter`. Now we support two kinds of splitting strategies: ``['RS','LS','LK']``, which denotes the ratio-based data splitting, leave-one-out data splitting, and leave-k-out data splitting. If the key of ``split`` is ``RS``, you need to set the splitting ratio like ``[0.8,0.1,0.1]``, ``[7,2,1]`` or ``[8,0,2]``, which denotes the ratio of training set, validation set and testing set respectively. If the key of ``split`` is ``LS`` (or ``LK``), now we support three kinds of ``LS`` (``LK``) mode: ``['valid_and_test', 'valid_only', 'test_only']`` and you should choose one mode as the value of ``LS`` (``LK``). For ``LK``, you also need to set the mode and the number ``k`` by providing a list in the following format: ``['valid_and_test', k]``. The number ``k`` represents the number of elements that will be left out according to the specified mode. The default value of ``split`` is ``{'RS': [0.8,0.1,0.1]}``. - ``mode (str|dict)``: decides the data range when we evaluate the model during ``valid`` and ``test`` phase. Now we support four kinds of evaluation mode: ``['full','unixxx','popxxx','labeled']``. ``full`` , ``unixxx`` and ``popxxx`` are designed for the evaluation on implicit feedback (data without label). For implicit feedback, we regard the items with observed interactions as positive items and those without observed interactions as negative items. ``full`` means evaluating the model on the set of all items. ``unixxx``, for example ``uni100``, means uniformly sample 100 negative items for each positive item in testing set, and evaluate the model on these positive items with their sampled negative items. ``popxxx``, for example ``pop100``, means sample 100 negative items for each positive item in testing set based on item popularity (:obj:`Counter(item)` in `.inter` file), and evaluate the model on these positive items with their sampled negative items. Here the `xxx` must be an integer. For explicit feedback (data with label), you should set the mode as ``labeled`` and we will evaluate the model based on your label. You can use ``valid`` and ``test`` as the dict key to set specific ``mode`` in different phases. The default value is ``full``, which is equivalent to ``{'valid': 'full', 'test': 'full'}``. diff --git a/docs/source/user_guide/train_eval_intro.rst b/docs/source/user_guide/train_eval_intro.rst index 472fc4fc7..a6aef1330 100644 --- a/docs/source/user_guide/train_eval_intro.rst +++ b/docs/source/user_guide/train_eval_intro.rst @@ -42,6 +42,7 @@ items or a sampled-based ranking. RO Random Ordering TO Temporal Ordering LS Leave-one-out Splitting + LK Leave-k-out Splitting RS Ratio-based Splitting full full ranking with all item candidates uniN sample-based ranking: each positive item is paired with N sampled negative items in uniform distribution @@ -54,7 +55,7 @@ The parameters used to control the evaluation method are as follows: including ``split``, ``group_by``, ``order`` and ``mode``. - ``split (dict)``: Control the splitting of dataset and the split ratio. The key is splitting method - and value is the list of split ratio. The range of key is ``[RS,LS]``. Defaults to ``{'RS':[0.8, 0.1, 0.1]}`` + and value is the list of split ratio. The range of key is ``[RS,LS,LK]``. Defaults to ``{'RS':[0.8, 0.1, 0.1]}`` - ``group_by (str)``: Whether to split dataset with the group of user. Range in ``[None, user]`` and defaults to ``user``. - ``order (str)``: Control the ordering of data and affect the splitting of data. From bf0abea7c31e8709c132703b4d7bd162fe9ffd82 Mon Sep 17 00:00:00 2001 From: mkhe93 Date: Sat, 30 Nov 2024 16:18:21 +0100 Subject: [PATCH 3/5] [patch] Extended the existing ItemKNN approach to include UserKNN by passing the appropriate knn_method parameter in the config. --- recbole/model/general_recommender/itemknn.py | 40 +++++++++++++------- recbole/properties/model/ItemKNN.yaml | 3 +- tests/model/test_model_auto.py | 8 ++++ 3 files changed, 36 insertions(+), 15 deletions(-) diff --git a/recbole/model/general_recommender/itemknn.py b/recbole/model/general_recommender/itemknn.py index 47aef23c3..ed6711730 100644 --- a/recbole/model/general_recommender/itemknn.py +++ b/recbole/model/general_recommender/itemknn.py @@ -20,7 +20,7 @@ class ComputeSimilarity: - def __init__(self, dataMatrix, topk=100, shrink=0, normalize=True): + def __init__(self, dataMatrix, topk=100, shrink=0, method='item', normalize=True): r"""Computes the cosine similarity of dataMatrix If it is computed on :math:`URM=|users| \times |items|`, pass the URM. @@ -31,6 +31,7 @@ def __init__(self, dataMatrix, topk=100, shrink=0, normalize=True): dataMatrix (scipy.sparse.csr_matrix): The sparse data matrix. topk (int) : The k value in KNN. shrink (int) : hyper-parameter in calculate cosine distance. + method (str) : Calculate the similarity of users if method is 'user', otherwise, calculate the similarity of items. normalize (bool): If True divide the dot product by the product of the norms. """ @@ -38,17 +39,21 @@ def __init__(self, dataMatrix, topk=100, shrink=0, normalize=True): self.shrink = shrink self.normalize = normalize + self.method = method self.n_rows, self.n_columns = dataMatrix.shape - self.TopK = min(topk, self.n_columns) + + if self.method == 'user': + self.TopK = min(topk, self.n_rows) + else: + self.TopK = min(topk, self.n_columns) self.dataMatrix = dataMatrix.copy() - def compute_similarity(self, method, block_size=100): + def compute_similarity(self, block_size=100): r"""Compute the similarity for the given dataset Args: - method (str) : Caculate the similarity of users if method is 'user', otherwise, calculate the similarity of items. block_size (int): divide matrix to :math:`n\_rows \div block\_size` to calculate cosine_distance if method is 'user', otherwise, divide matrix to :math:`n\_columns \div block\_size`. @@ -68,10 +73,10 @@ def compute_similarity(self, method, block_size=100): self.dataMatrix = self.dataMatrix.astype(np.float32) # Compute sum of squared values to be used in normalization - if method == "user": + if self.method == "user": sumOfSquared = np.array(self.dataMatrix.power(2).sum(axis=1)).ravel() end_local = self.n_rows - elif method == "item": + elif self.method == "item": sumOfSquared = np.array(self.dataMatrix.power(2).sum(axis=0)).ravel() end_local = self.n_columns else: @@ -86,7 +91,7 @@ def compute_similarity(self, method, block_size=100): this_block_size = end_block - start_block # All data points for a given user or item - if method == "user": + if self.method == "user": data = self.dataMatrix[start_block:end_block, :] else: data = self.dataMatrix[:, start_block:end_block] @@ -94,7 +99,7 @@ def compute_similarity(self, method, block_size=100): # Compute similarities - if method == "user": + if self.method == "user": this_block_weights = self.dataMatrix.dot(data.T) else: this_block_weights = self.dataMatrix.T.dot(data) @@ -134,7 +139,7 @@ def compute_similarity(self, method, block_size=100): numNotZeros = np.sum(notZerosMask) values.extend(this_line_weights[top_k_idx][notZerosMask]) - if method == "user": + if self.method == "user": rows.extend(np.ones(numNotZeros) * Index) cols.extend(top_k_idx[notZerosMask]) else: @@ -144,7 +149,7 @@ def compute_similarity(self, method, block_size=100): start_block += block_size # End while - if method == "user": + if self.method == "user": W_sparse = sp.csr_matrix( (values, (rows, cols)), shape=(self.n_rows, self.n_rows), @@ -160,7 +165,9 @@ def compute_similarity(self, method, block_size=100): class ItemKNN(GeneralRecommender): - r"""ItemKNN is a basic model that compute item similarity with the interaction matrix.""" + r"""ItemKNN is a basic model that compute item similarity with the interaction matrix. + Adjusting the value of 'knn_method' in the config file sets the method to either ItemKNN or UserKNN, respectively. + """ input_type = InputType.POINTWISE type = ModelType.TRADITIONAL @@ -170,15 +177,20 @@ def __init__(self, config, dataset): # load parameters info self.k = config["k"] + self.method = config["knn_method"] self.shrink = config["shrink"] if "shrink" in config else 0.0 self.interaction_matrix = dataset.inter_matrix(form="csr").astype(np.float32) shape = self.interaction_matrix.shape assert self.n_users == shape[0] and self.n_items == shape[1] _, self.w = ComputeSimilarity( - self.interaction_matrix, topk=self.k, shrink=self.shrink - ).compute_similarity("item") - self.pred_mat = self.interaction_matrix.dot(self.w).tolil() + self.interaction_matrix, topk=self.k, shrink=self.shrink, method=self.method + ).compute_similarity() + + if self.method == "user": + self.pred_mat = self.w.dot(self.interaction_matrix).tolil() + else: + self.pred_mat = self.interaction_matrix.dot(self.w).tolil() self.fake_loss = torch.nn.Parameter(torch.zeros(1)) self.other_parameter_name = ["w", "pred_mat"] diff --git a/recbole/properties/model/ItemKNN.yaml b/recbole/properties/model/ItemKNN.yaml index 9ce30000c..155f2f835 100644 --- a/recbole/properties/model/ItemKNN.yaml +++ b/recbole/properties/model/ItemKNN.yaml @@ -1,2 +1,3 @@ k: 100 # (int) The neighborhood size. -shrink: 0.0 # (float) A normalization parameter to calculate cosine distance. \ No newline at end of file +shrink: 0.0 # (float) A normalization parameter to calculate cosine distance. +knn_method: 'item' # (string) The method to calculate the similarity matrix ['item','user'] \ No newline at end of file diff --git a/tests/model/test_model_auto.py b/tests/model/test_model_auto.py index 9c18b56c6..965dc728e 100644 --- a/tests/model/test_model_auto.py +++ b/tests/model/test_model_auto.py @@ -39,6 +39,14 @@ def test_random(self): def test_itemknn(self): config_dict = { "model": "ItemKNN", + "knn_method": "item" + } + quick_test(config_dict) + + def test_userknn(self): + config_dict = { + "model": "ItemKNN", + "knn_method": "user" } quick_test(config_dict) From d72d4cbc7f317616178463e6996cd0b675ae48c4 Mon Sep 17 00:00:00 2001 From: mkhe93 Date: Sun, 1 Dec 2024 11:08:34 +0100 Subject: [PATCH 4/5] Implemented AsymKNN as in Aiolli (2013), https://dl.acm.org/doi/pdf/10.1145/2507157.2507189 --- recbole/model/general_recommender/asymknn.py | 225 +++++++++++++++++++ 1 file changed, 225 insertions(+) create mode 100644 recbole/model/general_recommender/asymknn.py diff --git a/recbole/model/general_recommender/asymknn.py b/recbole/model/general_recommender/asymknn.py new file mode 100644 index 000000000..3a0820e49 --- /dev/null +++ b/recbole/model/general_recommender/asymknn.py @@ -0,0 +1,225 @@ +import numpy as np +import scipy.sparse as sp +import torch +from recbole.model.abstract_recommender import GeneralRecommender +from recbole.utils import InputType, ModelType + +class ComputeSimilarity: + def __init__(self, dataMatrix, topk=100, alpha=0.5, method='item'): + r"""Computes the asymmetric cosine similarity of dataMatrix with alpha parameter. + + Args: + dataMatrix (scipy.sparse.csr_matrix): The sparse data matrix. + topk (int) : The k value in KNN. + alpha (float): Asymmetry control parameter in cosine similarity calculation. + method (str) : Caculate the similarity of users if method is 'user', otherwise, calculate the similarity of items. + """ + + super(ComputeSimilarity, self).__init__() + + self.method = method + self.alpha = alpha + + self.n_rows, self.n_columns = dataMatrix.shape + + if self.method == 'user': + self.TopK = min(topk, self.n_rows) + else: + self.TopK = min(topk, self.n_columns) + + self.dataMatrix = dataMatrix.copy() + + def compute_asym_similarity(self, block_size=100): + r"""Compute the asymmetric cosine similarity for the given dataset. + + Args: + block_size (int): Divide matrix into blocks for efficient calculation. + + Returns: + list: The similar nodes, if method is 'user', the shape is [number of users, neigh_num], + else, the shape is [number of items, neigh_num]. + scipy.sparse.csr_matrix: sparse matrix W, if method is 'user', the shape is [self.n_rows, self.n_rows], + else, the shape is [self.n_columns, self.n_columns]. + """ + + values = [] + rows = [] + cols = [] + neigh = [] + + self.dataMatrix = self.dataMatrix.astype(np.float32) + + if self.method == "user": + sumOfMatrix = np.array(self.dataMatrix.sum(axis=1)).ravel() + end_local = self.n_rows + elif self.method == "item": + sumOfMatrix = np.array(self.dataMatrix.sum(axis=0)).ravel() + end_local = self.n_columns + else: + raise NotImplementedError("Make sure 'method' is in ['user', 'item']!") + + start_block = 0 + + # Compute all similarities using vectorization + while start_block < end_local: + end_block = min(start_block + block_size, end_local) + this_block_size = end_block - start_block + + # All data points for a given user or item + if self.method == "user": + data = self.dataMatrix[start_block:end_block, :] + else: + data = self.dataMatrix[:, start_block:end_block] + data = data.toarray() + + # Compute similarities + if self.method == "user": + this_block_weights = self.dataMatrix.dot(data.T) + else: + this_block_weights = self.dataMatrix.T.dot(data) + + for index_in_block in range(this_block_size): + this_line_weights = this_block_weights[:, index_in_block] + + Index = index_in_block + start_block + this_line_weights[Index] = 0.0 + + # Apply asymmetric cosine normalization + denominator = ( + (sumOfMatrix[Index] ** self.alpha) * + (sumOfMatrix ** (1 - self.alpha)) + 1e-6 + ) + this_line_weights = np.multiply(this_line_weights, 1 / denominator) + + # Sort indices and select TopK + relevant_partition = (-this_line_weights).argpartition(self.TopK - 1)[0:self.TopK] + relevant_partition_sorting = np.argsort(-this_line_weights[relevant_partition]) + top_k_idx = relevant_partition[relevant_partition_sorting] + neigh.append(top_k_idx) + + # Incrementally build sparse matrix, do not add zeros + notZerosMask = this_line_weights[top_k_idx] != 0.0 + numNotZeros = np.sum(notZerosMask) + + values.extend(this_line_weights[top_k_idx][notZerosMask]) + if self.method == "user": + rows.extend(np.ones(numNotZeros) * Index) + cols.extend(top_k_idx[notZerosMask]) + else: + rows.extend(top_k_idx[notZerosMask]) + cols.extend(np.ones(numNotZeros) * Index) + + start_block += block_size + + # End while + if self.method == "user": + W_sparse = sp.csr_matrix( + (values, (rows, cols)), + shape=(self.n_rows, self.n_rows), + dtype=np.float32, + ) + else: + W_sparse = sp.csr_matrix( + (values, (rows, cols)), + shape=(self.n_columns, self.n_columns), + dtype=np.float32, + ) + return neigh, W_sparse.tocsc() + + +class AsymKNN(GeneralRecommender): + r"""AsymKNN: A traditional recommender model based on asymmetric cosine similarity and score prediction. + + AsymKNN computes user-item recommendations by leveraging asymmetric cosine similarity + over the interaction matrix. This model allows for flexible adjustment of similarity + calculations and scoring normalization via several tunable parameters. + + Config: + k (int): Number of neighbors to consider in the similarity calculation. + method (str): Specifies whether to calculate similarities based on users or items. + Valid options are 'user' or 'item'. + alpha (float): Weight parameter for asymmetric cosine similarity, controlling + the importance of the interaction matrix in the similarity computation. + Must be in the range [0, 1]. + q (int): Exponent for adjusting the 'locality of scoring function' after similarity computation. + beta (float): Parameter for controlling the balance between factors in the + final score normalization. Must be in the range [0, 1]. + + Reference: + Aiolli,F et al. Efficient top-n recommendation for very large scale binary rated datasets. + In Proceedings of the 7th ACM conference on Recommender systems (pp. 273-280). ACM. + """ + + input_type = InputType.POINTWISE + type = ModelType.TRADITIONAL + + def __init__(self, config, dataset): + super(AsymKNN, self).__init__(config, dataset) + + # load parameters info + self.k = config["k"] # Size of neighborhood for cosine + self.method = config["knn_method"] # Caculate the similarity of users if method is 'user', otherwise, calculate the similarity of items. + self.alpha = config['alpha'] if 'alpha' in config else 0.5 # Asymmetric cosine parameter + self.q = config['q'] if 'q' in config else 1.0 # Weight adjustment exponent + self.beta = config['beta'] if 'beta' in config else 0.5 # Beta for final score normalization + + assert 0 <= self.alpha <= 1, f"The asymmetric parameter 'alpha' must be value between in [0,1], but got {self.alpha}" + assert 0 <= self.beta <= 1, f"The asymmetric parameter 'beta' must be value between [0,1], but got {self.beta}" + assert isinstance(self.k, int), f"The neighborhood parameter 'k' must be an integer, but got {self.k}" + assert isinstance(self.q, int), f"The exponent parameter 'q' must be an integer, but got {self.q}" + + self.interaction_matrix = dataset.inter_matrix(form="csr").astype(np.float32) + shape = self.interaction_matrix.shape + assert self.n_users == shape[0] and self.n_items == shape[1] + _, self.w = ComputeSimilarity( + self.interaction_matrix, topk=self.k, alpha=self.alpha, method=self.method + ).compute_asym_similarity() + + if self.method == "user": + nominator = self.w.dot(self.interaction_matrix) + factor1 = np.power(np.sqrt(self.w.power(2).sum(axis=1)),2*self.beta) + factor2 = np.power(np.sqrt(self.interaction_matrix.power(2).sum(axis=0)),2*(1-self.beta)) + denominator = factor1.dot(factor2) + 1e-6 + else: + nominator = self.interaction_matrix.dot(self.w) + factor1 = np.power(np.sqrt(self.interaction_matrix.power(2).sum(axis=1)),2*self.beta) + factor2 = np.power(np.sqrt(self.w.power(2).sum(axis=1)),2*(1-self.beta)) + denominator = factor1.dot(factor2.T) + 1e-6 + + self.pred_mat = (nominator / denominator).tolil() + + # Apply 'locality of scoring function' via q: f(w) = w^q + self.pred_mat = self.pred_mat.power(self.q) + + self.fake_loss = torch.nn.Parameter(torch.zeros(1)) + self.other_parameter_name = ["w", "pred_mat"] + + def forward(self, user, item): + pass + + def calculate_loss(self, interaction): + return torch.nn.Parameter(torch.zeros(1)) + + def predict(self, interaction): + user = interaction[self.USER_ID] + item = interaction[self.ITEM_ID] + user = user.cpu().numpy().astype(int) + item = item.cpu().numpy().astype(int) + result = [] + + for index in range(len(user)): + uid = user[index] + iid = item[index] + score = self.pred_mat[uid, iid] + result.append(score) + result = torch.from_numpy(np.array(result)).to(self.device) + return result + + def full_sort_predict(self, interaction): + user = interaction[self.USER_ID] + user = user.cpu().numpy() + + score = self.pred_mat[user, :].toarray().flatten() + result = torch.from_numpy(score).to(self.device) + + return result \ No newline at end of file From deb0972b64a8c45608188f6ba11d5dfa6698580b Mon Sep 17 00:00:00 2001 From: mkhe93 Date: Sun, 1 Dec 2024 11:09:26 +0100 Subject: [PATCH 5/5] Documentation and tests for AsymKNN --- asset/model_list.json | 14 +++ ...bole.model.general_recommender.asymknn.rst | 4 + .../recbole.model.general_recommender.rst | 1 + .../user_guide/model/general/asymknn.rst | 88 +++++++++++++++++++ docs/source/user_guide/model_intro.rst | 1 + recbole/model/general_recommender/__init__.py | 1 + recbole/properties/model/AsymKNN.yaml | 5 ++ tests/model/test_model_auto.py | 14 +++ 8 files changed, 128 insertions(+) create mode 100644 docs/source/recbole/recbole.model.general_recommender.asymknn.rst create mode 100644 docs/source/user_guide/model/general/asymknn.rst create mode 100644 recbole/properties/model/AsymKNN.yaml diff --git a/asset/model_list.json b/asset/model_list.json index b28b66ded..a2ba8cf85 100644 --- a/asset/model_list.json +++ b/asset/model_list.json @@ -154,6 +154,20 @@ "repository": "RecBole", "repo_link": "https://github.com/RUCAIBox/RecBole" }, + { + "category": "General Recommendation", + "cate_link": "/docs/user_guide/model_intro.html#general-recommendation", + "year": "2013", + "pub": "RecSys'13", + "model": "AsymKNN", + "model_link": "/docs/user_guide/model/general/asymknn.html", + "paper": "Efficient Top-N Recommendation for Very Large Scale Binary Rated Datasets", + "paper_link": "https://doi.org/10.1145/2507157.2507189", + "authors": "Fabio Aiolli", + "ref_code": "", + "repository": "RecBole", + "repo_link": "https://github.com/RUCAIBox/RecBole" + }, { "category": "General Recommendation", "cate_link": "/docs/user_guide/model_intro.html#general-recommendation", diff --git a/docs/source/recbole/recbole.model.general_recommender.asymknn.rst b/docs/source/recbole/recbole.model.general_recommender.asymknn.rst new file mode 100644 index 000000000..55f1cbce0 --- /dev/null +++ b/docs/source/recbole/recbole.model.general_recommender.asymknn.rst @@ -0,0 +1,4 @@ +.. automodule:: recbole.model.general_recommender.asymknn + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/recbole/recbole.model.general_recommender.rst b/docs/source/recbole/recbole.model.general_recommender.rst index 9436371ad..2089ced97 100644 --- a/docs/source/recbole/recbole.model.general_recommender.rst +++ b/docs/source/recbole/recbole.model.general_recommender.rst @@ -4,6 +4,7 @@ recbole.model.general\_recommender .. toctree:: :maxdepth: 4 + recbole.model.general_recommender.asymknn recbole.model.general_recommender.admmslim recbole.model.general_recommender.bpr recbole.model.general_recommender.cdae diff --git a/docs/source/user_guide/model/general/asymknn.rst b/docs/source/user_guide/model/general/asymknn.rst new file mode 100644 index 000000000..97844143e --- /dev/null +++ b/docs/source/user_guide/model/general/asymknn.rst @@ -0,0 +1,88 @@ +AsymKNN +=========== + +Introduction +--------------------- + +`[paper] `_ + +**Title:** Efficient Top-N Recommendation for Very Large Scale Binary Rated Datasets + +**Authors:** Fabio Aiolli + +**Abstract:** We present a simple and scalable algorithm for top-N recommendation able to deal with very large datasets and (binary rated) implicit feedback. We focus on memory-based collaborative filtering +algorithms similar to the well known neighboor based technique for explicit feedback. The major difference, that makes the algorithm particularly scalable, is that it uses positive feedback only +and no explicit computation of the complete (user-by-user or itemby-item) similarity matrix needs to be performed. +The study of the proposed algorithm has been conducted on data from the Million Songs Dataset (MSD) challenge whose task was to suggest a set of songs (out of more than 380k available songs) to more than 100k users given half of the user listening history and +complete listening history of other 1 million people. +In particular, we investigate on the entire recommendation pipeline, starting from the definition of suitable similarity and scoring functions and suggestions on how to aggregate multiple ranking strategies to define the overall recommendation. The technique we are +proposing extends and improves the one that already won the MSD challenge last year. + +In this article, we introduce a versatile class of recommendation algorithms that calculate either user-to-user or item-to-item similarities as the foundation for generating recommendations. This approach enables the flexibility to switch between UserKNN and ItemKNN models depending on the desired application. + +A distinguishing feature of this class of algorithms, exemplified by AsymKNN, is its use of asymmetric cosine similarity, which generalizes the traditional cosine similarity. Specifically, when the asymmetry parameter +``alpha = 0.5``, the method reduces to the standard cosine similarity, while other values of ``alpha`` allow for tailored emphasis on specific aspects of the interaction data. Furthermore, setting the parameter +``beta = 1.0`` ensures a traditional UserKNN or ItemKNN, as the final scores are only divided by a fixed positive constant, preserving the same order of recommendations. + +Running with RecBole +------------------------- + +**Model Hyper-Parameters:** + +- ``k (int)`` : The neighborhood size. Defaults to ``100``. + +- ``alpha (float)`` : Weight parameter for asymmetric cosine similarity. Defaults to ``0.5``. + +- ``beta (float)`` : Parameter for controlling the balance between factors in the final score normalization. Defaults to ``1.0``. + +- ``q (int)`` : The 'locality of scoring function' parameter. Defaults to ``1``. + +**Additional Parameters:** + +- ``knn_method (str)`` : Calculate the similarity of users if method is 'user', otherwise, calculate the similarity of items.. Defaults to ``item``. + + +**A Running Example:** + +Write the following code to a python file, such as `run.py` + +.. code:: python + + from recbole.quick_start import run_recbole + + run_recbole(model='AsymKNN', dataset='ml-100k') + +And then: + +.. code:: bash + + python run.py + +Tuning Hyper Parameters +------------------------- + +If you want to use ``HyperTuning`` to tune hyper parameters of this model, you can copy the following settings and name it as ``hyper.test``. + +.. code:: bash + + k choice [10,50,100,200,250,300,400,500,1000,1500,2000,2500] + alpha choice [0.0,0.2,0.5,0.8,1.0] + beta choice [0.0,0.2,0.5,0.8,1.0] + q choice [1,2,3,4,5,6] + +Note that we just provide these hyper parameter ranges for reference only, and we can not guarantee that they are the optimal range of this model. + +Then, with the source code of RecBole (you can download it from GitHub), you can run the ``run_hyper.py`` to tuning: + +.. code:: bash + + python run_hyper.py --model=[model_name] --dataset=[dataset_name] --config_files=[config_files_path] --params_file=hyper.test + +For more details about Parameter Tuning, refer to :doc:`../../../user_guide/usage/parameter_tuning`. + +If you want to change parameters, dataset or evaluation settings, take a look at + +- :doc:`../../../user_guide/config_settings` +- :doc:`../../../user_guide/data_intro` +- :doc:`../../../user_guide/train_eval_intro` +- :doc:`../../../user_guide/usage` \ No newline at end of file diff --git a/docs/source/user_guide/model_intro.rst b/docs/source/user_guide/model_intro.rst index 8b4c59d78..7de3f59e6 100644 --- a/docs/source/user_guide/model_intro.rst +++ b/docs/source/user_guide/model_intro.rst @@ -13,6 +13,7 @@ task of top-n recommendation. All the collaborative filter(CF) based models are .. toctree:: :maxdepth: 1 + model/general/asymknn model/general/pop model/general/itemknn model/general/bpr diff --git a/recbole/model/general_recommender/__init__.py b/recbole/model/general_recommender/__init__.py index e71f2b4ec..d5ec68e23 100644 --- a/recbole/model/general_recommender/__init__.py +++ b/recbole/model/general_recommender/__init__.py @@ -1,3 +1,4 @@ +from recbole.model.general_recommender.asymknn import AsymKNN from recbole.model.general_recommender.bpr import BPR from recbole.model.general_recommender.cdae import CDAE from recbole.model.general_recommender.convncf import ConvNCF diff --git a/recbole/properties/model/AsymKNN.yaml b/recbole/properties/model/AsymKNN.yaml new file mode 100644 index 000000000..f711d6860 --- /dev/null +++ b/recbole/properties/model/AsymKNN.yaml @@ -0,0 +1,5 @@ +k: 100 # Number of neighbors to consider in the similarity calculation. +q: 1 # Exponent for adjusting the 'locality of scoring function' after similarity computation. +beta: 1.0 # Parameter for controlling the balance between factors in the final score normalization. +alpha: 0.5 # Weight parameter for asymmetric cosine similarity +knn_method: 'item' # Calculate the similarity of users if method is 'user', otherwise, calculate the similarity of items. \ No newline at end of file diff --git a/tests/model/test_model_auto.py b/tests/model/test_model_auto.py index 965dc728e..fe809ee32 100644 --- a/tests/model/test_model_auto.py +++ b/tests/model/test_model_auto.py @@ -50,6 +50,20 @@ def test_userknn(self): } quick_test(config_dict) + def test_asymitemknn(self): + config_dict = { + "model": "AsymKNN", + "knn_method": "item" + } + quick_test(config_dict) + + def test_asymuserknn(self): + config_dict = { + "model": "AsymKNN", + "knn_method": "user" + } + quick_test(config_dict) + def test_bpr(self): config_dict = { "model": "BPR",