diff --git a/CoRML Demonstration.ipynb b/CoRML Demonstration.ipynb
new file mode 100644
index 0000000..5c49b47
--- /dev/null
+++ b/CoRML Demonstration.ipynb
@@ -0,0 +1,597 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "1cN-BpEGaSf5"
+   },
+   "source": [
+    "# CoRML\n",
+    "------------\n",
+    "This is a demonstration of our paper:\n",
+    "\n",
+    "> Tianjun Wei, Jianghong Ma, Tommy W.S. Chow. Collaborative Residual Metric Learning. In SIGIR 2023. [[arXiv](https://arxiv.org/abs/2304.07971)] [[Github](https://github.com/Joinn99/CoRML)]\n",
+    "\n",
+    "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8pxSpTa9ZoTy" + }, + "source": [ + "## Preprocessing" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "cellView": "form", + "id": "FrbhQmJjGF3I" + }, + "outputs": [], + "source": [ + "#@title Environment Setting\n", + "#@markdown CoRML is implemented based on the recommendation toolbox [RecBole](https://recbole.io/).\n", + "%%capture\n", + "! pip install recbole==1.1.1" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "cellView": "form", + "id": "pqa6o51yExY_" + }, + "outputs": [], + "source": [ + "#@title Download Dataset\n", + "#@markdown Here we use four public datasets:
\n",
+    "#@markdown - [ML-20M](http://files.grouplens.org/datasets/movielens/)\n",
+    "#@markdown - [Pinterest](https://github.com/hexiangnan/neural_collaborative_filtering/blob/master/Data/)\n",
+    "#@markdown - [Gowalla](https://github.com/kuandeng/LightGCN/blob/master/Data/gowalla/)\n",
+    "#@markdown - [Yelp2018](https://github.com/kuandeng/LightGCN/blob/master/Data/yelp2018)\n",
+    "\n",
+    "#@markdown We also provide processed datasets that conform to the RecBole data format, which can be downloaded [here](https://github.com/Joinn99/CoRML/tree/torch/Data).\n",
+    "\n",
+    "%%capture\n",
+    "! git clone https://github.com/Joinn99/CoRML\n",
+    "! cd CoRML && unzip -o \"Data/*.zip\" -d /content"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "id": "7NnQ5JnKb-p1"
+   },
+   "outputs": [],
+   "source": [
+    "#@title Configuration\n",
+    "#@markdown Here we list the configurations of the experiment.\n",
+    "#@markdown For a detailed description of the configuration, please refer to [RecBole API](https://recbole.io/docs/user_guide/config_settings.html).\n",
+    "\n",
+    "CONFIG = {\n",
+    "    # General Parameters\n",
+    "    \"seed\": 2026, # Choose in [2022, 2024, 2026, 2028, 2030]\n",
+    "    \"reproducibility\": True,\n",
+    "\n",
+    "    # Training Parameters\n",
+    "    \"train_neg_sample_args\":{\n",
+    "        \"sample_num\": 1\n",
+    "    },\n",
+    "    \"epochs\": 1,\n",
+    "    \"train_batch_size\": 524288,\n",
+    "\n",
+    "\n",
+    "    # Evaluation Parameters\n",
+    "    \"eval_args\":{\n",
+    "        \"split\": {'RS':[0.6, 0.2, 0.2]},\n",
+    "        \"order\": \"RO\",\n",
+    "        \"group_by\": \"user\",\n",
+    "        \"mode\": \"full\"\n",
+    "    },\n",
+    "    \"eval_batch_size\": 1048576,\n",
+    "    \"eval_step\": 0,\n",
+    "    \"stopping_step\": 1,\n",
+    "    \"valid_metric\": \"NDCG@20\",\n",
+    "    \"topk\": [5, 10, 20],\n",
+    "    \"metrics\": [\"NDCG\", \"MRR\"],\n",
+    "    \"metric_decimal_place\": 6,\n",
+    "\n",
+    "    # Dataset Default Parameters\n",
+    "    \"data_path\": \"Data\",\n",
+    "    \"load_col\":{\n",
+    "        \"inter\": [\"user_id\", \"item_id\"],\n",
+    "        \"user\": [\"user_id\"],\n",
+    "        \"item\": [\"item_id\"],\n",
+    "    },\n",
+    "\n",
+    "    \"USER_ID_FIELD\": \"user_id\",\n",
+    "    \"ITEM_ID_FIELD\": \"item_id\",\n",
+    "    \"filter_inter_by_user_or_item\": False,\n",
+    "}\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "f5ytkUIsbn_B"
+   },
+   "source": [
+    "### Hyperparameters\n",
+    "Here we list the hyperparameter settings of CoRML. To evaluate the performance of CoRML with different hyperparameters, modify the values in `CoRML_HYPER` and select `Runtime->Run after`.",
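+    "\n",
+    "For example, a minimal sketch of such a modification (illustrative values only, not tuned settings):\n",
+    "```python\n",
+    "CoRML_HYPER[\"partition_ratio\"] = 1.0   # 1.0 disables item partitioning\n",
+    "CoRML_HYPER[\"sparse_approx\"] = False   # keep the dense item-item weight matrix\n",
+    "CONFIG.update(CoRML_HYPER)             # then rerun the following cells\n",
+    "```"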
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {
+    "id": "k5R7vKwlMPfO"
+   },
+   "outputs": [],
+   "source": [
+    "DATASET = {\n",
+    "    \"dataset\": \"yelp2018\"\n",
+    "}\n",
+    "\n",
+    "CoRML_HYPER = {\n",
+    "    # Model Default Parameters\n",
+    "    \"partition_ratio\": 0.3, # Maximum size ratio of item partition (1.0 means no partitioning)\n",
+    "    \"sparse_approx\": True, # Sparse approximation to reduce storage size of H to compare with MF-based models\n",
+    "    \"eigenvectors\": 256, # Number of eigenvectors used in SVD\n",
+    "    \"admm_iter\": 50, # Number of iterations for ADMM\n",
+    "\n",
+    "    \"lambda\": 0.8, # Weights for H and G in preference scores\n",
+    "    \"dual_step_length\": 5e3, # Dual step length of ADMM\n",
+    "    \"l2_regularization\": 1.0, # L2-regularization for learning weight matrix G\n",
+    "    \"item_degree_norm\": 0.1, # Item degree norm for learning weight matrix G\n",
+    "    \"global_scaling\": -1.0, # Global scaling in approximated ranking weights (in logarithm scale)\n",
+    "    \"user_scaling\": 0.5, # User degree scaling in approximated ranking weights\n",
+    "}\n",
+    "\n",
+    "CONFIG.update(DATASET)\n",
+    "CONFIG.update(CoRML_HYPER)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "7hNBSzOCZszh"
+   },
+   "source": [
+    "## Model Implementation"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {
+    "cellView": "form",
+    "id": "jHuhZcigM0lJ"
+   },
+   "outputs": [],
+   "source": [
+    "r\"\"\"\n",
+    "CoRML (PyTorch Version)\n",
+    "################################################\n",
+    "Author:\n",
+    "    Tianjun Wei (tjwei2-c@my.cityu.edu.hk)\n",
+    "Reference:\n",
+    "    Tianjun Wei et al. \"Collaborative Residual Metric Learning.\" in SIGIR 2023.\n",
+    "Created Date:\n",
+    "    2023/04/10\n",
+    "\"\"\"\n",
+    "import torch\n",
+    "import numpy as np\n",
+    "from tqdm import tqdm\n",
+    "\n",
+    "from recbole.utils import InputType\n",
+    "from recbole.utils.enum_type import ModelType\n",
+    "from recbole.model.abstract_recommender import GeneralRecommender\n",
+    "\n",
+    "\n",
+    "class SpectralInfo(object):\n",
+    "    r\"\"\"A class for producing spectral information of the interaction matrix, including node degrees,\n",
+    "    singular vectors, and partitions produced by spectral graph partitioning.\n",
+    "\n",
+    "    Reference: https://github.com/Joinn99/FPSR/blob/torch/FPSR/model.py\n",
+    "    \"\"\"\n",
+    "    def __init__(self, inter_mat, config) -> None:\n",
+    "        # load parameters info\n",
+    "        self.eigen_dim = config[\"eigenvectors\"] # Number of eigenvectors used in SVD\n",
+    "        self.t_u = config[\"user_scaling\"] # User degree scaling in approximated ranking weights\n",
+    "        self.norm_di = 2 * config[\"item_degree_norm\"]\n",
+    "        self.tau = config[\"partition_ratio\"] # Maximum size ratio of item partition (1.0 means no partitioning)\n",
+    "        self.inter_mat = inter_mat\n",
+    "\n",
+    "    def _degree(self, inter_mat=None, dim=0, exp=-0.5) -> torch.Tensor:\n",
+    "        r\"\"\"Get the degree of users and items.\n",
+    "\n",
+    "        Returns:\n",
+    "            Tensor of the node degrees.\n",
+    "        \"\"\"\n",
+    "        if inter_mat is None:\n",
+    "            inter_mat = self.inter_mat\n",
+    "        d_inv = torch.nan_to_num(\n",
+    "            torch.clip(torch.sparse.sum(inter_mat, dim=dim).to_dense(), min=1.).pow(exp), nan=1., posinf=1., neginf=1.\n",
+    "        )\n",
+    "        return d_inv\n",
+    "\n",
+    "    def _svd(self, mat, k) -> torch.Tensor:\n",
+    "        r\"\"\"Perform truncated singular value decomposition (SVD) on\n",
+    "        the input matrix and return the top-k eigenvectors.\n",
+    "\n",
+    "        Returns:\n",
+    "            Top-k eigenvectors.\n",
+    "        \"\"\"\n",
+    "        _, _, V = torch.svd_lowrank(mat, q=max(4*k, 32), niter=10)\n",
+    "        return V[:, :k]\n",
+    "\n",
+    "    def _norm_adj(self, item_list=None) -> torch.Tensor:\n",
+    "        r\"\"\"Get the normalized item-item adjacency matrix for a group of items.\n",
+    "\n",
+    "        Returns:\n",
+    "            Sparse tensor of the normalized item-item adjacency matrix.\n",
+    "        \"\"\"\n",
+    "        if item_list is None:\n",
+    "            vals = self.inter_mat.values() * self.di_isqr[self.inter_mat.indices()[1]].squeeze()\n",
+    "            return torch.sparse_coo_tensor(\n",
+    "                self.inter_mat.indices(),\n",
+    "                self._degree(dim=1)[self.inter_mat.indices()[0]] * vals,\n",
+    "                size=self.inter_mat.shape, dtype=torch.float\n",
+    "            ).coalesce()\n",
+    "        else:\n",
+    "            inter = self.inter_mat.index_select(dim=1, index=item_list).coalesce()\n",
+    "            vals = inter.values() * self.di_isqr[item_list][inter.indices()[1]].squeeze()\n",
+    "            return torch.sparse_coo_tensor(\n",
+    "                inter.indices(),\n",
+    "                self._degree(inter, dim=1)[inter.indices()[0]] * vals,\n",
+    "                size=inter.shape, dtype=torch.float\n",
+    "            ).coalesce()\n",
+    "\n",
+    "    def run(self):\n",
+    "        r\"\"\"\n",
+    "        Compute spectral information (item degrees, user norms, and top eigenvectors)\n",
+    "        \"\"\"\n",
+    "        self.di_isqr = self._degree(dim=0).reshape(-1, 1)\n",
+    "        self.di_sqr = self._degree(dim=0, exp=0.5).reshape(1, -1)\n",
+    "\n",
+    "        u_norm = self._degree(dim=1, exp=-self.t_u).reshape(-1, 1)\n",
+    "        self.u_norm = u_norm / u_norm.min()\n",
+    "\n",
+    "        self.V_mat = self._svd(self._norm_adj(), self.eigen_dim)\n",
+    "\n",
+    "        return self.di_sqr, self.u_norm, self.V_mat\n",
+    "\n",
+    "    def partitioning(self, V) -> torch.Tensor:\n",
+    "        r\"\"\"\n",
+    "        Graph bipartitioning\n",
+    "        \"\"\"\n",
+    "        split = V[:, 1] >= 0\n",
+    "        if split.sum() == split.shape[0] or split.sum() == 0:\n",
+    "            split = V[:, 1] >= torch.median(V[:, 1])\n",
+    "        return split\n",
+    "\n",
+    "    def get_partition(self, ilist, total_num):\n",
+    "        r\"\"\"\n",
+    "        Get partitions of item-item graph\n",
+    "        \"\"\"\n",
+    "        if ilist.shape[0] <= total_num * self.tau:\n",
+    "            return [ilist]\n",
+    "        else:\n",
+    "            # If the partition size is larger than the size limit,\n",
+    "            # perform graph partitioning on this partition.\n",
+    "            split = self.partitioning(self._svd(self._norm_adj(ilist), 2))\n",
+    "            return self.get_partition(ilist[torch.where(split)[0]], total_num) + \\\n",
+    "                   self.get_partition(ilist[torch.where(~split)[0]], total_num)\n",
+    "\n",
+    "\n",
+    "class CoRML(GeneralRecommender):\n",
+    "    r\"\"\"CoRML is an item-based metric learning model for collaborative filtering.\n",
+    "\n",
+    "    CoRML learns a generalized user-item distance metric to capture user\n",
+    "    preferences in user-item interaction signals by modeling the residuals of\n",
+    "    the general Mahalanobis distance.\n",
+    "    \"\"\"\n",
+    "    input_type = InputType.POINTWISE\n",
+    "    type = ModelType.TRADITIONAL\n",
+    "\n",
+    "    def __init__(self, config, dataset):\n",
+    "        r\"\"\"\n",
+    "        Model initialization and training.\n",
+    "        \"\"\"\n",
+    "        super().__init__(config, dataset)\n",
+    "\n",
+    "        self.lambda_ = config[\"lambda\"] # Weights for H and G in preference scores\n",
+    "        self.rho = config[\"dual_step_length\"] # Dual step length of ADMM\n",
+    "        self.theta = config[\"l2_regularization\"] # L2-regularization for learning weight matrix G\n",
+    "        self.norm_di = 2 * config[\"item_degree_norm\"] # Item degree norm for learning weight matrix G\n",
+    "        self.eps = np.power(10, config[\"global_scaling\"]) # Global scaling in approximated ranking weights (in logarithm scale)\n",
+    "\n",
+    "        self.sparse_approx = config[\"sparse_approx\"] # Sparse approximation to reduce storage size of H\n",
+    "        self.admm_iter = 
config[\"admm_iter\"] # Number of iterations for ADMM\n", + "\n", + " # Dummy pytorch parameters required by Recbole\n", + " self.dummy_param = torch.nn.Parameter(torch.zeros(1))\n", + "\n", + " # User-item interaction matrix\n", + " self.inter_mat = dataset.inter_matrix(form='coo')\n", + " self.inter_mat = torch.sparse_coo_tensor(\n", + " torch.LongTensor(np.array([self.inter_mat.row, self.inter_mat.col])),\n", + " torch.FloatTensor(self.inter_mat.data),\n", + " size=self.inter_mat.shape, dtype=torch.float\n", + " ).coalesce().to(self.device)\n", + "\n", + " # TRAINING PROCESS\n", + " item_list = self.update_G(config)\n", + " self.update_H(item_list)\n", + "\n", + " def DI(self, pow=1., ilist=None):\n", + " r\"\"\"\n", + " Degree of item node\n", + " \"\"\"\n", + " if ilist is not None:\n", + " return torch.pow(self.di_sqr[:, ilist], pow)\n", + " else:\n", + " return torch.pow(self.di_sqr, pow)\n", + "\n", + " def update_G(self, config):\n", + " r\"\"\"\n", + " Update G matrix\n", + " \"\"\"\n", + " G = SpectralInfo(self.inter_mat, config)\n", + " self.di_sqr, self.u_norm, self.V_mat = G.run()\n", + " item_list = G.get_partition(\n", + " torch.arange(self.n_items, device=self.device), self.n_items\n", + " )\n", + " return item_list\n", + "\n", + " def update_H(self, item_list):\n", + " r\"\"\"\n", + " Update H matrix\n", + " \"\"\"\n", + " self.H_indices = []\n", + " self.H_values = []\n", + "\n", + " for ilist in tqdm(item_list, desc=\"Partition\", bar_format=\"{elapsed}\"):\n", + " H_triu = self.update_H_part(ilist)\n", + " H_triu = torch.where(H_triu >= 5e-4, H_triu, 0).to_sparse_coo()\n", + " self.H_indices.append(ilist[H_triu.indices()])\n", + " self.H_values.append(H_triu.values())\n", + " \n", + " H_mat = torch.sparse_coo_tensor(indices=torch.cat(self.H_indices, dim=1),\n", + " values=torch.cat(self.H_values, dim=0),\n", + " size=(self.n_items, self.n_items)).coalesce()\n", + " del self.H_indices, self.H_values\n", + " \n", + " # Sparse approximation\n", + " if self.sparse_approx:\n", + " limit = (self.n_users + self.n_items) * 64 # Embedding size in MF models\n", + " thres = 1e-4\n", + " while (H_mat._nnz() + H_mat.indices().shape[-1] + self.n_items + 1) >= limit:\n", + " mask = torch.where(H_mat.values() > thres)[0]\n", + " H_mat = torch.sparse_coo_tensor(indices=H_mat.indices().index_select(-1, mask),\n", + " values=H_mat.values().index_select(-1, mask),\n", + " size=(self.n_items, self.n_items)).coalesce()\n", + " thres *= 1.25\n", + " self.H_mat = H_mat.T.to_sparse_csr()\n", + "\n", + " def _inner_prod(self, A_mat: torch.Tensor, B_mat: torch.Tensor):\n", + " r\"\"\"\n", + " Small-batch inner product\n", + " \"\"\"\n", + " assert A_mat.shape[-2] == B_mat.shape[-2]\n", + " result = torch.zeros((A_mat.shape[-1], B_mat.shape[-1]), device=self.device)\n", + " for chunk in torch.split(torch.arange(0, A_mat.shape[-2], device=self.device), 10000):\n", + " result += A_mat.index_select(dim=-2, index=chunk).to_dense().T @ \\\n", + " B_mat.index_select(dim=-2, index=chunk).to_dense()\n", + " return result\n", + "\n", + " def update_H_part(self, ilist):\n", + " r\"\"\"\n", + " Learning H in each partition (if any)\n", + " \"\"\"\n", + " R_mat = self.inter_mat.index_select(dim=1, index=ilist) * self.DI(-self.norm_di, ilist)\n", + "\n", + " H_aux = (0.5 / self.lambda_) * self._inner_prod(R_mat, R_mat)\n", + " II_mat = self._inner_prod(R_mat, self.u_norm * R_mat)\n", + " del R_mat\n", + "\n", + " V_mat = self.V_mat[ilist, :]\n", + " diag_vvt = torch.square(V_mat).sum(dim=1).view(-1)\n", + 
"\n", + " G_mat = - self.DI(self.norm_di - 1, ilist).T * \\\n", + " (V_mat @ V_mat.T).clip(0).fill_diagonal_(0) * self.DI(1 - self.norm_di, ilist)\n", + " del V_mat, diag_vvt\n", + "\n", + " H_aux = H_aux + (\n", + " self.eps * ((1 / self.lambda_) - 1) *\n", + " (II_mat @ G_mat)\n", + " )\n", + " del G_mat\n", + "\n", + " II_inv = torch.inverse(\n", + " self.eps * II_mat + torch.diag(\n", + " self.DI(2, ilist).view(-1) * self.theta + self.rho\n", + " )\n", + " )\n", + " del II_mat\n", + "\n", + " H_aux = II_inv @ H_aux\n", + " Phi_mat = torch.zeros_like(H_aux, device=self.device)\n", + " S_mat = torch.zeros_like(H_aux, device=self.device)\n", + "\n", + " for _ in range(self.admm_iter):\n", + " # ADMM Iteration\n", + " H_tilde = H_aux + II_inv @ (self.rho * (S_mat - Phi_mat))\n", + " lag_op = torch.diag(H_tilde) / (torch.diag(II_inv) + 1e-10)\n", + " H_mat = H_tilde - II_inv * lag_op # Update H\n", + " S_mat = H_mat + Phi_mat \n", + " S_mat = torch.clip((S_mat.T + S_mat) / 2, min=0) # Update S\n", + " Phi_mat += H_mat - S_mat # Update Phi\n", + "\n", + " return torch.triu(S_mat)\n", + "\n", + " def forward(self):\n", + " r\"\"\"\n", + " Abstract method of GeneralRecommender in RecBole (not used)\n", + " \"\"\"\n", + " pass\n", + "\n", + " def calculate_loss(self, interaction):\n", + " r\"\"\"\n", + " Abstract method of GeneralRecommender in RecBole (not used)\n", + " \"\"\"\n", + " return torch.nn.Parameter(torch.zeros(1))\n", + "\n", + " def predict(self, interaction):\n", + " r\"\"\"\n", + " Abstract method of GeneralRecommender in RecBole (not used)\n", + " \"\"\"\n", + " raise NotImplementedError\n", + "\n", + " def full_sort_predict(self, interaction):\n", + " r\"\"\"\n", + " Recommend items for the input users\n", + " \"\"\"\n", + " R_mat = self.inter_mat.index_select(dim=0, index=interaction[self.USER_ID]).to_dense()\n", + "\n", + " Y_mat = R_mat * self.DI(-1) @ self.V_mat @ self.V_mat.T * self.di_sqr\n", + " Y_mat = ((1 / self.lambda_) - 1) * \\\n", + " torch.clip(Y_mat - R_mat * torch.square(self.V_mat).sum(dim=1).reshape(1, -1), min=0)\n", + "\n", + " R_mat = R_mat * self.DI(-self.norm_di)\n", + " Y_mat += ((self.H_mat @ R_mat.T).T + R_mat @ self.H_mat) * self.DI(self.norm_di)\n", + "\n", + " return Y_mat\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9spbyRBwb0dn" + }, + "source": [ + "## Running" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "htpZ3NxHYlKP", + "outputId": "04d2acdd-b84a-4367-a776-517143fa5513" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "==========Preprocessing==========\n", + "Configuration loading complete.\n", + "Data loading complete.\n", + "============Training=============\n", + "(This step may last several minutes)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "01:51\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training complete.\n", + "===========Evaluation============\n", + "Evaluation results:\n", + "NDCG@5 : 0.069239\n", + "NDCG@10 : 0.071014\n", + "NDCG@20 : 0.081999\n", + "MRR@5 : 0.144153\n", + "MRR@10 : 0.158800\n", + "MRR@20 : 0.167767\n" + ] + } + ], + "source": [ + "r\"\"\"\n", + "Code reference: https://recbole.io/docs/developer_guide/customize_models.html\n", + "\"\"\"\n", + "\n", + "import warnings\n", + "from recbole.utils import init_seed\n", + "from recbole.trainer import Trainer\n", + "from recbole.config import 
Config\n", + "from recbole.data import create_dataset, data_preparation\n", + "\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "print('='* 10 + 'Preprocessing' + '=' * 10)\n", + "config = Config(model=CoRML, config_dict=CONFIG)\n", + "init_seed(config['seed'], config['reproducibility'])\n", + "print(\"Configuration loading complete.\")\n", + "\n", + "# dataset filtering\n", + "dataset = create_dataset(config)\n", + "train_data, valid_data, test_data = data_preparation(config, dataset)\n", + "print(\"Data loading complete.\")\n", + "\n", + "print('='* 12 + 'Training' + '=' * 13)\n", + "print(\"(This step may last several minutes)\")\n", + "# model loading and initialization\n", + "model = CoRML(config, train_data.dataset).to(config['device'])\n", + "\n", + "# trainer loading and initialization\n", + "trainer = Trainer(config, model)\n", + "\n", + "# model training\n", + "_, _ = trainer.fit(train_data, valid_data)\n", + "\n", + "print(\"Training complete.\")\n", + "\n", + "print('='* 11 + 'Evaluation' + '=' * 12)\n", + "# model evaluation\n", + "test_result = trainer.evaluate(test_data)\n", + "print(\"Evaluation results:\")\n", + "for metric, value in test_result.items():\n", + " print('{:10s}: {:.6f}'.format(metric.upper(), value))" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [] + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "SparseRec", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" + }, + "vscode": { + "interpreter": { + "hash": "933075f399f88ee3a7bc96289e7bedd6322a6307b0e261e2bc2c90799eef1243" + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/CoRML/__init__.py b/CoRML/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/CoRML/model.py b/CoRML/model.py new file mode 100644 index 0000000..19e0970 --- /dev/null +++ b/CoRML/model.py @@ -0,0 +1,294 @@ +r""" +CoRML (PyTorch Version) +################################################ +Author: + Tianjun Wei (tjwei2-c@my.cityu.edu.hk) +Reference: + Tianjun Wei et al. "Collaborative Residual Metric Learning." in SIGIR 2023. +Created Date: + 2023/04/10 +""" +import torch +import numpy as np + +from recbole.utils import InputType +from recbole.utils.enum_type import ModelType +from recbole.model.abstract_recommender import GeneralRecommender + + +class SpectralInfo(object): + r"""A class for producing spectral information of the interaction matrix, including node degrees, + singular vectors, and partitions produced by spectral graph partitioning. + + Reference: https://github.com/Joinn99/FPSR/blob/torch/FPSR/model.py + """ + def __init__(self, inter_mat, config) -> None: + # load parameters info + self.eigen_dim = config["eigenvectors"] # Number of eigenvectors used in SVD + self.t_u = config["user_scaling"] + self.norm_di = 2 * config["item_degree_norm"] + self.partition_ratio = config["partition_ratio"] # Maximum size ratio of item partition (1.0 means no partitioning) + self.inter_mat = inter_mat + + def _degree(self, inter_mat=None, dim=0, exp=-0.5) -> torch.Tensor: + r"""Get the degree of users and items. + + Returns: + Tensor of the node degrees. + """ + if inter_mat is None: + inter_mat = self.inter_mat + d_inv = torch.nan_to_num( + torch.clip(torch.sparse.sum(inter_mat, dim=dim).to_dense(), min=1.).pow(exp), nan=1., posinf=1., neginf=1. 
+        )
+        return d_inv
+
+    def _svd(self, mat, k) -> torch.Tensor:
+        r"""Perform truncated singular value decomposition (SVD) on
+        the input matrix and return the top-k eigenvectors.
+
+        Returns:
+            Top-k eigenvectors.
+        """
+        _, _, V = torch.svd_lowrank(mat, q=max(4*k, 32), niter=10)
+        return V[:, :k]
+
+    def _norm_adj(self, item_list=None) -> torch.Tensor:
+        r"""Get the normalized item-item adjacency matrix for a group of items.
+
+        Returns:
+            Sparse tensor of the normalized item-item adjacency matrix.
+        """
+        if item_list is None:
+            vals = self.inter_mat.values() * self.di_isqr[self.inter_mat.indices()[1]].squeeze()
+            return torch.sparse_coo_tensor(
+                self.inter_mat.indices(),
+                self._degree(dim=1)[self.inter_mat.indices()[0]] * vals,
+                size=self.inter_mat.shape, dtype=torch.float
+            ).coalesce()
+        else:
+            inter = self.inter_mat.index_select(dim=1, index=item_list).coalesce()
+            vals = inter.values() * self.di_isqr[item_list][inter.indices()[1]].squeeze()
+            return torch.sparse_coo_tensor(
+                inter.indices(),
+                self._degree(inter, dim=1)[inter.indices()[0]] * vals,
+                size=inter.shape, dtype=torch.float
+            ).coalesce()
+
+    def run(self):
+        r"""
+        Compute spectral information (item degrees, user norms, and top eigenvectors)
+        """
+        self.di_isqr = self._degree(dim=0).reshape(-1, 1)
+        self.di_sqr = self._degree(dim=0, exp=0.5).reshape(1, -1)
+
+        u_norm = self._degree(dim=1, exp=-self.t_u).reshape(-1, 1)
+        self.u_norm = u_norm / u_norm.min()
+
+        self.V_mat = self._svd(self._norm_adj(), self.eigen_dim)
+
+        return self.di_sqr, self.u_norm, self.V_mat
+
+    def partitioning(self, V) -> torch.Tensor:
+        r"""
+        Graph bipartitioning
+        """
+        split = V[:, 1] >= 0
+        if split.sum() == split.shape[0] or split.sum() == 0:
+            split = V[:, 1] >= torch.median(V[:, 1])
+        return split
+
+    def get_partition(self, ilist, total_num):
+        r"""
+        Get partitions of item-item graph
+        """
+        assert self.partition_ratio > (1 / total_num)
+
+        if ilist.shape[0] <= total_num * self.partition_ratio:
+            return [ilist]
+        else:
+            # If the partition size is larger than the size limit,
+            # perform graph partitioning on this partition.
+            split = self.partitioning(self._svd(self._norm_adj(ilist), 2))
+            return self.get_partition(ilist[torch.where(split)[0]], total_num) + \
+                   self.get_partition(ilist[torch.where(~split)[0]], total_num)
+
+
+class CoRML(GeneralRecommender):
+    r"""CoRML is an item-based metric learning model for collaborative filtering.
+
+    CoRML learns a generalized user-item distance metric to capture user
+    preferences in user-item interaction signals by modeling the residuals of
+    the general Mahalanobis distance.
+    """
+    input_type = InputType.POINTWISE
+    type = ModelType.TRADITIONAL
+
+    def __init__(self, config, dataset):
+        r"""
+        Model initialization and training.
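+
+        As a traditional (non-iterative) model, CoRML does all of its learning
+        here: update_G computes spectral information and item partitions, and
+        update_H learns the item-item weight matrix H with ADMM.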
+ """ + super().__init__(config, dataset) + + self.lambda_ = config["lambda"] # Weights for H and G in preference scores + self.rho = config["dual_step_length"] # Dual step length of ADMM + self.theta = config["l2_regularization"] # L2-regularization for learning weight matrix G + self.norm_di = 2 * config["item_degree_norm"] # Item degree norm for learning weight matrix G + self.eps = np.power(10, config["global_scaling"]) # Global scaling in approximated ranking weights (in logarithm scale) + + self.sparse_approx = config["sparse_approx"] # Sparse approximation to reduce storage size of H + self.admm_iter = config["admm_iter"] # Number of iterations for ADMM + + # Dummy pytorch parameters required by Recbole + self.dummy_param = torch.nn.Parameter(torch.zeros(1)) + + # User-item interaction matrix + self.inter_mat = dataset.inter_matrix(form='coo') + self.inter_mat = torch.sparse_coo_tensor( + torch.LongTensor(np.array([self.inter_mat.row, self.inter_mat.col])), + torch.FloatTensor(self.inter_mat.data), + size=self.inter_mat.shape, dtype=torch.float + ).coalesce().to(self.device) + + # TRAINING PROCESS + item_list = self.update_G(config) + self.update_H(item_list) + + def DI(self, pow=1., ilist=None): + r""" + Degree of item node + """ + if ilist is not None: + return torch.pow(self.di_sqr[:, ilist], pow) + else: + return torch.pow(self.di_sqr, pow) + + def update_G(self, config): + r""" + Update G matrix + """ + G = SpectralInfo(self.inter_mat, config) + self.di_sqr, self.u_norm, self.V_mat = G.run() + item_list = G.get_partition( + torch.arange(self.n_items, device=self.device), self.n_items + ) + return item_list + + def update_H(self, item_list): + r""" + Update H matrix + """ + self.H_indices = [] + self.H_values = [] + + for ilist in item_list: + H_triu = self.update_H_part(ilist) + H_triu = torch.where(H_triu >= 5e-4, H_triu, 0).to_sparse_coo() + self.H_indices.append(ilist[H_triu.indices()]) + self.H_values.append(H_triu.values()) + + H_mat = torch.sparse_coo_tensor(indices=torch.cat(self.H_indices, dim=1), + values=torch.cat(self.H_values, dim=0), + size=(self.n_items, self.n_items)).coalesce() + del self.H_indices, self.H_values + + # Sparse approximation + if self.sparse_approx: + limit = (self.n_users + self.n_items) * 64 # Embedding size in MF models + thres = 1e-4 + while (H_mat._nnz() + H_mat.indices().shape[-1] + self.n_items + 1) >= limit: + mask = torch.where(H_mat.values() > thres)[0] + H_mat = torch.sparse_coo_tensor(indices=H_mat.indices().index_select(-1, mask), + values=H_mat.values().index_select(-1, mask), + size=(self.n_items, self.n_items)).coalesce() + thres *= 1.25 + self.H_mat = H_mat.T.to_sparse_csr() + + def _inner_prod(self, A_mat: torch.Tensor, B_mat: torch.Tensor): + r""" + Small-batch inner product + """ + assert A_mat.shape[-2] == B_mat.shape[-2] + result = torch.zeros((A_mat.shape[-1], B_mat.shape[-1]), device=self.device) + for chunk in torch.split(torch.arange(0, A_mat.shape[-2], device=self.device), 10000): + result += A_mat.index_select(dim=-2, index=chunk).to_dense().T @ \ + B_mat.index_select(dim=-2, index=chunk).to_dense() + return result + + def update_H_part(self, ilist): + r""" + Learning H in each partition (if any) + """ + R_mat = self.inter_mat.index_select(dim=1, index=ilist) * self.DI(-self.norm_di, ilist) + + H_aux = (0.5 / self.lambda_) * self._inner_prod(R_mat, R_mat) + II_mat = self._inner_prod(R_mat, self.u_norm * R_mat) + del R_mat + + V_mat = self.V_mat[ilist, :] + diag_vvt = torch.square(V_mat).sum(dim=1).view(-1) + + G_mat = - 
self.DI(self.norm_di - 1, ilist).T * \ + (V_mat @ V_mat.T).clip(0).fill_diagonal_(0) * self.DI(1 - self.norm_di, ilist) + del V_mat, diag_vvt + + H_aux = H_aux + ( + self.eps * ((1 / self.lambda_) - 1) * + (II_mat @ G_mat) + ) + del G_mat + + II_inv = torch.inverse( + self.eps * II_mat + torch.diag( + self.DI(2, ilist).view(-1) * self.theta + self.rho + ) + ) + del II_mat + + H_aux = II_inv @ H_aux + Phi_mat = torch.zeros_like(H_aux, device=self.device) + S_mat = torch.zeros_like(H_aux, device=self.device) + + for _ in range(self.admm_iter): + # ADMM Iteration + H_tilde = H_aux + II_inv @ (self.rho * (S_mat - Phi_mat)) + lag_op = torch.diag(H_tilde) / (torch.diag(II_inv) + 1e-10) + H_mat = H_tilde - II_inv * lag_op # Update H + S_mat = H_mat + Phi_mat + S_mat = torch.clip((S_mat.T + S_mat) / 2, min=0) # Update S + Phi_mat += H_mat - S_mat # Update Phi + + return torch.triu(S_mat) + + def forward(self): + r""" + Abstract method of GeneralRecommender in RecBole (not used) + """ + pass + + def calculate_loss(self, interaction): + r""" + Abstract method of GeneralRecommender in RecBole (not used) + """ + return torch.nn.Parameter(torch.zeros(1)) + + def predict(self, interaction): + r""" + Abstract method of GeneralRecommender in RecBole (not used) + """ + raise NotImplementedError + + def full_sort_predict(self, interaction): + r""" + Recommend items for the input users + """ + R_mat = self.inter_mat.index_select(dim=0, index=interaction[self.USER_ID]).to_dense() + + Y_mat = R_mat * self.DI(-1) @ self.V_mat @ self.V_mat.T * self.di_sqr + Y_mat = ((1 / self.lambda_) - 1) * \ + torch.clip(Y_mat - R_mat * torch.square(self.V_mat).sum(dim=1).reshape(1, -1), min=0) + + R_mat = R_mat * self.DI(-self.norm_di) + Y_mat += ((self.H_mat @ R_mat.T).T + R_mat @ self.H_mat) * self.DI(self.norm_di) + + return Y_mat diff --git a/CoRML/utils.py b/CoRML/utils.py new file mode 100644 index 0000000..11bf1f7 --- /dev/null +++ b/CoRML/utils.py @@ -0,0 +1,125 @@ +r""" + Name: utils.py + Date: 2023/04/12 + Description: Utility functions. +""" + +import os +import shutil +import warnings +from logging import getLogger +from datetime import datetime +from recbole.config import Config +from recbole.trainer import Trainer +from recbole.data.dataset import Dataset +from recbole.data import data_preparation +from recbole.utils import init_seed +from recbole.utils.enum_type import ModelType + +from CoRML.model import CoRML + +warnings.filterwarnings('ignore') + +class RecTrainer(object): + r""" + Trainer class. 
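+    Wraps the full RecBole pipeline: loads Params/Overall.yaml plus the
+    dataset-specific YAML, prepares the data splits, trains CoRML, and
+    logs the evaluation results.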
+    """
+    def __init__(self, dataset='yelp2018'):
+        self.model_name = CoRML
+        self.dataset = dataset
+        self.dataset_file_list = ['Params/Overall.yaml', 'Params/{:s}.yaml'.format(dataset)]
+        self.preprocessing()
+
+    def preprocessing(self):
+        self.config = Config(model=self.model_name, config_dict=None, config_file_list=self.dataset_file_list)
+        init_seed(self.config['seed'], self.config['reproducibility'])
+        init_logger(self.config)
+        self.logger = getLogger()
+        dataset = Dataset(self.config)
+        self.train_data, self.valid_data, self.test_data = data_preparation(self.config, dataset)
+
+    def train(self, verbose=False):
+        self.logger.info("Start training:")
+        model = self.model_name(self.config, self.train_data.dataset).to(self.config['device'])
+        self.trainer = Trainer(self.config, model)
+        shutil.rmtree(self.trainer.tensorboard.log_dir)
+        cur_time = datetime.now().strftime('%b-%d-%Y_%H-%M-%S')
+        self.trainer.tensorboard = get_tensorboard(cur_time, self.config['model'])
+        _, _ = self.trainer.fit(self.train_data, None, verbose=verbose, show_progress=False)
+        # Traditional (non-iterative) models are evaluated without reloading a checkpoint
+        load_best = self.trainer.model.type != ModelType.TRADITIONAL
+        self.logger.info("Start evaluation:")
+        test_result = self.trainer.evaluate(self.test_data, load_best_model=load_best)
+        os.remove(self.trainer.saved_model_file)
+        shutil.rmtree(os.path.join('Log/{:s}'.format(self.config['model']), cur_time))
+        if os.path.exists('log_tensorboard'):
+            shutil.rmtree(os.path.join('log_tensorboard'))
+        for k, v in test_result.items():
+            self.logger.info("{:12s}: {:.6f}".format(k, v))
+
+from torch.utils.tensorboard import SummaryWriter
+from recbole.utils import ensure_dir
+
+def get_tensorboard(cur_time, model):
+    r"""
+    Modified version of get_tensorboard in RecBole.
+    Source: https://github.com/RUCAIBox/RecBole/blob/5d7df69bcbe9d21b4185946e8ee9a4bd8f041b9d/recbole/utils/utils.py#L206-L230
+    """
+    base_path = 'Log/{:s}'.format(model)
+    if not os.path.exists(base_path):
+        os.makedirs(base_path)
+    dir_name = cur_time
+    dir_path = os.path.join(base_path, dir_name)
+    writer = SummaryWriter(dir_path)
+    return writer
+
+import logging
+import colorlog
+from colorama import init
+from logging import getLogger
+from recbole.utils.logger import log_colors_config, RemoveColorFilter
+
+def init_logger(config):
+    r"""
+    Modified version of init_logger in RecBole.
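+    Logs to both the console (colorized) and a plain-text file at Log/RUN.log.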
+ Source: https://github.com/RUCAIBox/RecBole/blob/5d7df69bcbe9d21b4185946e8ee9a4bd8f041b9d/recbole/utils/logger.py#L60-L118 + """ + init(autoreset=True) + LOGROOT = './Log/' + dir_name = os.path.dirname(LOGROOT) + ensure_dir(dir_name) + + logfilename = 'RUN.log' + + logfilepath = os.path.join(LOGROOT, logfilename) + + filefmt = "%(asctime)-15s %(levelname)s %(message)s" + filedatefmt = "%a %d %b %Y %H:%M:%S" + fileformatter = logging.Formatter(filefmt, filedatefmt) + + sfmt = "%(log_color)s%(asctime)-15s %(levelname)s %(message)s" + sdatefmt = "%d %b %H:%M" + sformatter = colorlog.ColoredFormatter(sfmt, sdatefmt, log_colors=log_colors_config) + if config['state'] is None or config['state'].lower() == 'info': + level = logging.INFO + elif config['state'].lower() == 'debug': + level = logging.DEBUG + elif config['state'].lower() == 'error': + level = logging.ERROR + elif config['state'].lower() == 'warning': + level = logging.WARNING + elif config['state'].lower() == 'critical': + level = logging.CRITICAL + else: + level = logging.INFO + + fh = logging.FileHandler(logfilepath) + fh.setLevel(level) + fh.setFormatter(fileformatter) + remove_color_filter = RemoveColorFilter() + fh.addFilter(remove_color_filter) + + sh = logging.StreamHandler() + sh.setLevel(level) + sh.setFormatter(sformatter) + + logging.basicConfig(level=level, handlers=[sh, fh]) \ No newline at end of file diff --git a/Params/Overall.yaml b/Params/Overall.yaml new file mode 100644 index 0000000..59f50b3 --- /dev/null +++ b/Params/Overall.yaml @@ -0,0 +1,48 @@ +# General Parameters +seed: 2026 #Choose in [2022,2024,2026,2028,2030] +reproducibility: True + +# Training Parameters +train_neg_sample_args: + sample_num: 1 +epochs: 1 +train_batch_size: 524288 + +# Evaluation Parameters +eval_args: + split: {'RS':[0.6, 0.2, 0.2]} + order: RO + group_by: user + mode: full +eval_batch_size: 1048576 +eval_step: 0 +stopping_step: 1 +valid_metric: NDCG@20 +topk: [5, 10, 20] +metrics: ['NDCG', 'MRR'] +metric_decimal_place: 6 + +# Dataset Default Parameters +data_path: ./Data/ +load_col: + inter: [user_id, item_id] + user: [user_id] + item: [item_id] + +USER_ID_FIELD: user_id +ITEM_ID_FIELD: item_id +filter_inter_by_user_or_item: False + +# Model Default Parameters +model: CoRML +partition_ratio: 1.0 # Maximum size ratio of item partition (1.0 means no partitioning) +sparse_approx: True # Sparse approximation to reduce storage size of H to compare with MF-based models +eigenvectors: 256 # Number of eigenvectors used in SVD +admm_iter: 50 # Number of iterations for ADMM + +lambda: 0.75 # Weights for H and G in preference scores +dual_step_length: 5e3 # Dual step length of ADMM +l2_regularization: 1.0 # L2-regularization for learning weight matrix G +item_degree_norm: 0.1 # Item degree norm for learning weight matrix G +global_scaling: 0.0 # Global scaling in approximated ranking weights (in logarithm scale) +user_scaling: 0.0 # User degree scaling in approximated ranking weights diff --git a/Params/Pinterest.yaml b/Params/Pinterest.yaml new file mode 100644 index 0000000..2b227bf --- /dev/null +++ b/Params/Pinterest.yaml @@ -0,0 +1,10 @@ +# Pinterest +dataset: Pinterest + +lambda: 0.7 +dual_step_length: 5e3 +l2_regularization: 1e-3 +global_scaling: -0.5 +user_scaling: 0.5 + +partition_ratio: 1.0 \ No newline at end of file diff --git a/Params/gowalla.yaml b/Params/gowalla.yaml new file mode 100644 index 0000000..7637ceb --- /dev/null +++ b/Params/gowalla.yaml @@ -0,0 +1,10 @@ +# Gowalla +dataset: gowalla + +lambda: 0.8 +dual_step_length: 
5e3
+l2_regularization: 1.0
+global_scaling: 0.0
+user_scaling: 0.0
+
+partition_ratio: 0.35
\ No newline at end of file
diff --git a/Params/ml-20m.yaml b/Params/ml-20m.yaml
new file mode 100644
index 0000000..92cd690
--- /dev/null
+++ b/Params/ml-20m.yaml
@@ -0,0 +1,10 @@
+# ML-20M
+dataset: ml-20m
+
+lambda: 0.8
+dual_step_length: 10.0
+l2_regularization: 0.1
+global_scaling: 0.0
+user_scaling: 0.0
+
+partition_ratio: 1.0
\ No newline at end of file
diff --git a/Params/yelp2018.yaml b/Params/yelp2018.yaml
new file mode 100644
index 0000000..b27978a
--- /dev/null
+++ b/Params/yelp2018.yaml
@@ -0,0 +1,10 @@
+# Yelp2018
+dataset: yelp2018
+
+lambda: 0.8
+dual_step_length: 5e3
+l2_regularization: 1.0
+global_scaling: -1.0
+user_scaling: 0.5
+
+partition_ratio: 0.3
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..110ff62
--- /dev/null
+++ b/README.md
@@ -0,0 +1,114 @@
+# CoRML
+[![RecBole](https://img.shields.io/badge/RecBole-1.1.1-orange)](https://recbole.io/)
+[![arXiv](https://img.shields.io/badge/arXiv-2304.07971-red)](https://arxiv.org/abs/2304.07971)
+[![License](https://img.shields.io/github/license/Joinn99/CoRML)](https://github.com/Joinn99/CoRML/blob/torch/LICENSE.md)
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Joinn99/CoRML/blob/torch/CoRML%20Demonstration.ipynb)
+
+**PyTorch version** (Default) | [**CuPy version**](https://github.com/Joinn99/CoRML/tree/cupy)
+
+This is the official implementation of our paper published at *The 46th International ACM SIGIR Conference on Research and Development in Information Retrieval (SIGIR 2023)*:
+> Tianjun Wei, Jianghong Ma, Tommy W.S. Chow. Collaborative Residual Metric Learning. [[arXiv](https://arxiv.org/abs/2304.07971)]
+
+## Requirements
+The model implementation ensures compatibility with the recommendation toolbox [RecBole](https://recbole.io/) (GitHub: [RecBole](https://github.com/RUCAIBox/RecBole)).
+
+The requirements of the running environment are:
+
+- Python: 3.8+
+- PyTorch: 1.9.0+
+- RecBole: 1.1.1
+
+## Dataset
+Due to storage limits, we only include zipped dataset files in this repository. To use the datasets, run
+```bash
+unzip -o "Data/*.zip"
+```
+to unzip the dataset files.
+
+If you would like to test CoRML on a custom dataset, please place the dataset files in the following path:
+```bash
+.
+|-Data
+| |-[CUSTOM_DATASET_NAME]
+| | |-[CUSTOM_DATASET_NAME].user
+| | |-[CUSTOM_DATASET_NAME].item
+| | |-[CUSTOM_DATASET_NAME].inter
+
+```
+Then create `[CUSTOM_DATASET_NAME].yaml` in `./Params` with the following content:
+```yaml
+dataset: [CUSTOM_DATASET_NAME]
+```
+
+For the format of each dataset file, please refer to [RecBole API](https://recbole.io/docs/user_guide/data/atomic_files.html).
+
+## Hyperparameter
+For each dataset, the optimal hyperparameters are stored in `Params/[DATASET].yaml`. To tune the hyperparameters, modify the corresponding values in the file for each dataset.
+
+The main hyperparameters of CoRML are listed as follows:
+ - `lambda` ($\lambda$): Weights for $\mathbf{H}$ and $\mathbf{G}$ in preference scores
+ - `dual_step_length` ($\rho$): Dual step length of ADMM
+ - `l2_regularization` ($\theta$): $L_2$-regularization for learning weight matrix $\mathbf{H}$
+ - `item_degree_norm` ($t$): Item degree norm for learning weight matrix $\mathbf{H}$
+ - `global_scaling` ($\epsilon$): Global scaling in approximated ranking weights (in logarithm scale)
+ - `user_scaling` ($t_u$): User degree scaling in approximated ranking weights
+
+### Running on small partitions
+For datasets containing a large number of items, calculating and storing the complete item-item matrix may lead to out-of-memory errors in environments with small GPU memory. Therefore, we have added code for spectral graph partitioning, which learns the item-item weight matrix on each small partitioned item set. The code is adapted from our previous work [FPSR](https://github.com/Joinn99/FPSR).
+
+Hyperparameter `partition_ratio` is used to control the maximum ratio of the partitioned set of items relative to the complete set, ranging from 0 to 1. When `partition_ratio` is set to 1, no partitioning will be performed.
+
+### Sparse approximation
+To maintain consistency, we perform a sparse approximation of the derived matrix $\mathbf{H}$ in CoRML by setting the entries to 0 where $\lvert \mathbf{H} \rvert \leq \eta$. The threshold $\eta$ is adjusted so that the storage size of the sparse matrix $\mathbf{H}_{Sparse}$ is smaller than that of embedding-based models with embedding size 64.
+
+Hyperparameter `sparse_approx` is used to control the sparse approximation. When `sparse_approx` is set to `False`, no sparse approximation will be performed.
+
+## Running
+The script `run.py` is used to reproduce the results presented in the paper. To train and evaluate CoRML on a specific dataset, run
+```bash
+python run.py --dataset DATASET_NAME
+```
+
+## Google Colab
+We also provide a Colab notebook version of CoRML. You can click [here](https://colab.research.google.com/github/Joinn99/CoRML/blob/torch/CoRML%20Demonstration.ipynb) to open it in Google Colab, select the runtime type as *GPU*, and run the model.
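+
+## Running with the RecBole API
+The pipeline in `run.py` can also be driven directly through the RecBole API. Below is a minimal sketch mirroring the demonstration notebook; it assumes the repository root as the working directory, the datasets unzipped as above, and the default Yelp2018 configuration files:
+```python
+from recbole.config import Config
+from recbole.data import create_dataset, data_preparation
+from recbole.trainer import Trainer
+from recbole.utils import init_seed
+
+from CoRML.model import CoRML
+
+# Load the shared and dataset-specific configurations
+config = Config(model=CoRML, config_file_list=['Params/Overall.yaml', 'Params/yelp2018.yaml'])
+init_seed(config['seed'], config['reproducibility'])
+
+# Create and split the dataset (random 0.6/0.2/0.2 split, grouped by user)
+dataset = create_dataset(config)
+train_data, valid_data, test_data = data_preparation(config, dataset)
+
+# CoRML is trained inside its constructor (spectral decomposition + ADMM),
+# so the single fit() epoch only performs RecBole's bookkeeping
+model = CoRML(config, train_data.dataset).to(config['device'])
+trainer = Trainer(config, model)
+trainer.fit(train_data, valid_data)
+
+# Evaluate; load_best_model=False because CoRML does not train iteratively
+print(trainer.evaluate(test_data, load_best_model=False))
+```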
+## Citation
+If you find this work useful, please cite the following papers:
+
+```bibtex
+@InProceedings{CoRML,
+  author = {{Wei}, Tianjun and {Ma}, Jianghong and {Chow}, Tommy W.~S.},
+  booktitle = {Proceedings of the 46th International ACM SIGIR Conference on Research and Development in Information Retrieval},
+  title = {Collaborative Residual Metric Learning},
+  year = {2023},
+  address = {New York, NY, USA},
+  publisher = {Association for Computing Machinery},
+  series = {SIGIR '23},
+  doi = {10.1145/3539618.3591649},
+  location = {Taipei, Taiwan},
+  numpages = {10},
+  url = {https://doi.org/10.1145/3539618.3591649},
+}
+
+@InProceedings{FPSR,
+  author = {{Wei}, Tianjun and {Ma}, Jianghong and {Chow}, Tommy W.~S.},
+  booktitle = {Proceedings of the ACM Web Conference 2023},
+  title = {Fine-tuning Partition-aware Item Similarities for Efficient and Scalable Recommendation},
+  year = {2023},
+  address = {New York, NY, USA},
+  publisher = {Association for Computing Machinery},
+  series = {WWW '23},
+  doi = {10.1145/3543507.3583240},
+  location = {Austin, TX, USA},
+  numpages = {11},
+  url = {https://doi.org/10.1145/3543507.3583240},
+}
+```
+## License
+This project is licensed under the terms of the MIT license.
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..4da0868
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,2 @@
+recbole==1.1.1
+torch>=1.9.0
\ No newline at end of file
diff --git a/run.py b/run.py
new file mode 100644
index 0000000..40cba5e
--- /dev/null
+++ b/run.py
@@ -0,0 +1,9 @@
+import argparse
+from CoRML.utils import RecTrainer
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--dataset', '-d', type=str, default='yelp2018', help='Name of the dataset')
+    args, _ = parser.parse_known_args()
+
+    RecTrainer(dataset=args.dataset).train(verbose=False)
\ No newline at end of file