diff --git a/CoRML Demonstration.ipynb b/CoRML Demonstration.ipynb
new file mode 100644
index 0000000..5c49b47
--- /dev/null
+++ b/CoRML Demonstration.ipynb
@@ -0,0 +1,597 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1cN-BpEGaSf5"
+ },
+ "source": [
+ "# CoRML\n",
+ "------------\n",
+ "This is an demonstration of our paper:\n",
+ "\n",
+ "> Tianjun Wei, Jianghong Ma, Tommy W.S. Chow. Collaborative Residual Metric Learning. In SIGIR 2023. [[arxiv](https://arxiv.org/abs/2304.07971)] [[Github](https://github.com/Joinn99/CoRML)]\n",
+ "\n",
+ "
\n",
+ " \n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8pxSpTa9ZoTy"
+ },
+ "source": [
+ "## Preprocessing"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "cellView": "form",
+ "id": "FrbhQmJjGF3I"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Environment Setting\n",
+ "#@markdown CoRML is implemented based on the recommendation toolbox [RecBole](https://recbole.io/).\n",
+ "%%capture\n",
+ "! pip install recbole==1.1.1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "cellView": "form",
+ "id": "pqa6o51yExY_"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Download Dataset\n",
+ "#@markdown Here we use four public datasets:
\n",
+ "#@markdown - [ML-20M](http://files.grouplens.org/datasets/movielens/)\n",
+ "#@markdown - [Pinterest](https://github.com/hexiangnan/neural_collaborative_filtering/blob/master/Data/)\n",
+ "#@markdown - [Gowalla](https://github.com/kuandeng/LightGCN/blob/master/Data/gowalla/)\n",
+ "#@markdown - [Yelp2018](https://github.com/kuandeng/LightGCN/blob/master/Data/yelp2018)\n",
+ "\n",
+ "#@markdown We also provide processed datasets that conform to the RecBole data format, which can be downloaded [here](https://github.com/Joinn99/CoRML/tree/torch/Data).\n",
+ "\n",
+ "%%capture\n",
+ "! git clone https://github.com/Joinn99/CoRML\n",
+ "! cd CoRML && unzip -o \"Data/*.zip\" -d /content"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "id": "7NnQ5JnKb-p1"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Configuation\n",
+ "#@markdown Here we list the configurations of the experiment.\n",
+ "#@markdown For a detailed description of the configuration, please refer to [RecBole API](https://recbole.io/docs/user_guide/config_settings.html).\n",
+ "\n",
+ "CONFIG = {\n",
+ " # General Parameters\n",
+ " \"seed\": 2026, #Choose in [2022,2024,2026,2028,2030]\n",
+ " \"reproducibility\": True,\n",
+ "\n",
+ " # Training Parameters\n",
+ " \"train_neg_sample_args\":{\n",
+ " \"sample_num\": 1\n",
+ " },\n",
+ " \"epochs\": 1,\n",
+ " \"train_batch_size\": 524288,\n",
+ "\n",
+ "\n",
+ " # Evaluation Parameters\n",
+ " \"eval_args\":{\n",
+ " \"split\": {'RS':[0.6, 0.2, 0.2]},\n",
+ " \"order\": \"RO\",\n",
+ " \"group_by\": \"user\",\n",
+ " \"mode\": \"full\"\n",
+ " },\n",
+ " \"eval_batch_size\": 1048576,\n",
+ " \"eval_step\": 0,\n",
+ " \"stopping_step\": 1,\n",
+ " \"valid_metric\": \"NDCG@20\",\n",
+ " \"topk\": [5, 10, 20],\n",
+ " \"metrics\": [\"NDCG\", \"MRR\"] ,\n",
+ " \"metric_decimal_place\": 6,\n",
+ "\n",
+ " # Dataset Default Parameters\n",
+ " \"data_path\": \"Data\",\n",
+ " \"load_col\":{\n",
+ " \"inter\": [\"user_id\", \"item_id\"],\n",
+ " \"user\": [\"user_id\"],\n",
+ " \"item\": [\"item_id\"],\n",
+ " },\n",
+ "\n",
+ " \"USER_ID_FIELD\": \"user_id\",\n",
+ " \"ITEM_ID_FIELD\": \"item_id\",\n",
+ " \"filter_inter_by_user_or_item\": False,\n",
+ "}\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "f5ytkUIsbn_B"
+ },
+ "source": [
+ "### Hyperparameters\n",
+ "Here we list the hyperparameter setting of FPSR. To evaluate the performance of FPSR with different hyperparameters, modify the values in `FPSR_HYPER` and select `Runtime->Run after`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "id": "k5R7vKwlMPfO"
+ },
+ "outputs": [],
+ "source": [
+ "DATASET = {\n",
+ " \"dataset\": \"yelp2018\"\n",
+ "}\n",
+ "\n",
+ "CoRML_HYPER = {\n",
+ " # Model Default Parameters\n",
+ " \"partition_ratio\": 0.3, # Maximum size ratio of item partition (1.0 means no partitioning)\n",
+ " \"sparse_approx\": True, # Sparse approximation to reduce storage size of H to compare with MF-based models\n",
+ " \"eigenvectors\": 256, # Number of eigenvectors used in SVD\n",
+ " \"admm_iter\": 50,\n",
+ "\n",
+ " \"lambda\": 0.8, # Weights for H and G in preference scores\n",
+ " \"dual_step_length\": 5e3, # Dual step length of ADMM\n",
+ " \"l2_regularization\": 1.0, # L2-regularization for learning weight matrix G\n",
+ " \"item_degree_norm\": 0.1, # Item degree norm for learning weight matrix G\n",
+ " \"global_scaling\": -1.0, # Global scaling in approximated ranking weights (in logarithm scale)\n",
+ " \"user_scaling\": 0.5, # User degree scaling in approximated ranking weights\n",
+ "}\n",
+ "\n",
+ "CONFIG.update(DATASET)\n",
+ "CONFIG.update(CoRML_HYPER)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "7hNBSzOCZszh"
+ },
+ "source": [
+ "## Model Implementation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "cellView": "form",
+ "id": "jHuhZcigM0lJ"
+ },
+ "outputs": [],
+ "source": [
+ "r\"\"\"\n",
+ "CoRML (PyTorch Version)\n",
+ "################################################\n",
+ "Author:\n",
+ " Tianjun Wei (tjwei2-c@my.cityu.edu.hk)\n",
+ "Reference:\n",
+ " Tianjun Wei et al. \"Collaborative Residual Metric Learning.\" in SIGIR 2023.\n",
+ "Created Date:\n",
+ " 2023/04/10\n",
+ "\"\"\"\n",
+ "import torch\n",
+ "import numpy as np\n",
+ "from tqdm import tqdm\n",
+ "\n",
+ "from recbole.utils import InputType\n",
+ "from recbole.utils.enum_type import ModelType\n",
+ "from recbole.model.abstract_recommender import GeneralRecommender\n",
+ "\n",
+ "\n",
+ "class SpectralInfo(object):\n",
+ " r\"\"\"A class for producing spectral information of the interaction matrix, including node degrees,\n",
+ " singular vectors, and partitions produced by spectral graph partitioning.\n",
+ "\n",
+ " Reference: https://github.com/Joinn99/FPSR/blob/torch/FPSR/model.py\n",
+ " \"\"\"\n",
+ " def __init__(self, inter_mat, config) -> None:\n",
+ " # load parameters info\n",
+ " self.eigen_dim = config[\"eigenvectors\"] # Number of eigenvectors used in SVD\n",
+ " self.t_u = config[\"user_scaling\"]\n",
+ " self.norm_di = 2 * config[\"item_degree_norm\"]\n",
+ " self.tau = config[\"partition_ratio\"] # Maximum size ratio of item partition (1.0 means no partitioning)\n",
+ " self.inter_mat = inter_mat\n",
+ "\n",
+ " def _degree(self, inter_mat=None, dim=0, exp=-0.5) -> torch.Tensor:\n",
+ " r\"\"\"Get the degree of users and items.\n",
+ " \n",
+ " Returns:\n",
+ " Tensor of the node degrees.\n",
+ " \"\"\"\n",
+ " if inter_mat is None:\n",
+ " inter_mat = self.inter_mat\n",
+ " d_inv = torch.nan_to_num(\n",
+ " torch.clip(torch.sparse.sum(inter_mat, dim=dim).to_dense(), min=1.).pow(exp), nan=1., posinf=1., neginf=1.\n",
+ " )\n",
+ " return d_inv\n",
+ "\n",
+ " def _svd(self, mat, k) -> torch.Tensor:\n",
+ " r\"\"\"Perform Truncated singular value decomposition (SVD) on\n",
+ " the input matrix, return top-k eigenvectors.\n",
+ " \n",
+ " Returns:\n",
+ " Tok-k eigenvectors.\n",
+ " \"\"\"\n",
+ " _, _, V = torch.svd_lowrank(mat, q=max(4*k, 32), niter=10)\n",
+ " return V[:, :k]\n",
+ "\n",
+ " def _norm_adj(self, item_list=None) -> torch.Tensor:\n",
+ " r\"\"\"Get the normalized item-item adjacency matrix for a group of items.\n",
+ " \n",
+ " Returns:\n",
+ " Sparse tensor of the normalized item-item adjacency matrix.\n",
+ " \"\"\"\n",
+ " if item_list is None:\n",
+ " vals = self.inter_mat.values() * self.di_isqr[self.inter_mat.indices()[1]].squeeze()\n",
+ " return torch.sparse_coo_tensor(\n",
+ " self.inter_mat.indices(),\n",
+ " self._degree(dim=1)[self.inter_mat.indices()[0]] * vals,\n",
+ " size=self.inter_mat.shape, dtype=torch.float\n",
+ " ).coalesce()\n",
+ " else:\n",
+ " inter = self.inter_mat.index_select(dim=1, index=item_list).coalesce()\n",
+ " vals = inter.values() * self.di_isqr[item_list][inter.indices()[1]].squeeze()\n",
+ " return torch.sparse_coo_tensor(\n",
+ " inter.indices(),\n",
+ " self._degree(inter, dim=1)[inter.indices()[0]] * vals,\n",
+ " size=inter.shape, dtype=torch.float\n",
+ " ).coalesce()\n",
+ "\n",
+ " def run(self):\n",
+ " r\"\"\"\n",
+ " Spectral information\n",
+ " \"\"\"\n",
+ " self.di_isqr = self._degree(dim=0).reshape(-1, 1)\n",
+ " self.di_sqr = self._degree(dim=0, exp=0.5).reshape(1, -1)\n",
+ "\n",
+ " u_norm = self._degree(dim=1, exp=-self.t_u).reshape(-1, 1)\n",
+ " self.u_norm = u_norm / u_norm.min()\n",
+ "\n",
+ " self.V_mat = self._svd(self._norm_adj(), self.eigen_dim)\n",
+ "\n",
+ " return self.di_sqr, self.u_norm, self.V_mat\n",
+ "\n",
+ " def partitioning(self, V) -> torch.Tensor:\n",
+ " r\"\"\"\n",
+ " Graph bipartitioning\n",
+ " \"\"\"\n",
+ " split = V[:, 1] >= 0\n",
+ " if split.sum() == split.shape[0] or split.sum() == 0:\n",
+ " split = V[:, 1] >= torch.median(V[:, 1])\n",
+ " return split\n",
+ "\n",
+ " def get_partition(self, ilist, total_num):\n",
+ " r\"\"\"\n",
+ " Get partitions of item-item graph\n",
+ " \"\"\"\n",
+ " if ilist.shape[0] <= total_num * self.tau:\n",
+ " return [ilist]\n",
+ " else:\n",
+ " # If the partition size is larger than size limit,\n",
+ " # perform graph partitioning on this partition.\n",
+ " split = self.partitioning(self._svd(self._norm_adj(ilist), 2))\n",
+ " return self.get_partition(ilist[torch.where(split)[0]], total_num) + \\\n",
+ " self.get_partition(ilist[torch.where(~split)[0]], total_num)\n",
+ "\n",
+ "\n",
+ "class CoRML(GeneralRecommender):\n",
+ " r\"\"\"CoRML is an item-based metric learning model for collaborative filtering.\n",
+ "\n",
+ " CoRML learn a generalized distance user-item distance metric to capture user\n",
+ " preference in user-item interaction signals by modeling the residuals of general\n",
+ " Mahalanobis distance.\n",
+ " \"\"\"\n",
+ " input_type = InputType.POINTWISE\n",
+ " type = ModelType.TRADITIONAL\n",
+ "\n",
+ " def __init__(self, config, dataset):\n",
+ " r\"\"\"\n",
+ " Model initialization and training.\n",
+ " \"\"\"\n",
+ " super().__init__(config, dataset)\n",
+ "\n",
+ " self.lambda_ = config[\"lambda\"] # Weights for H and G in preference scores\n",
+ " self.rho = config[\"dual_step_length\"] # Dual step length of ADMM\n",
+ " self.theta = config[\"l2_regularization\"] # L2-regularization for learning weight matrix G\n",
+ " self.norm_di = 2 * config[\"item_degree_norm\"] # Item degree norm for learning weight matrix G\n",
+ " self.eps = np.power(10, config[\"global_scaling\"]) # Global scaling in approximated ranking weights (in logarithm scale)\n",
+ "\n",
+ " self.sparse_approx = config[\"sparse_approx\"] # Sparse approximation to reduce storage size of H\n",
+ " self.admm_iter = config[\"admm_iter\"] # Number of iterations for ADMM\n",
+ "\n",
+ " # Dummy pytorch parameters required by Recbole\n",
+ " self.dummy_param = torch.nn.Parameter(torch.zeros(1))\n",
+ "\n",
+ " # User-item interaction matrix\n",
+ " self.inter_mat = dataset.inter_matrix(form='coo')\n",
+ " self.inter_mat = torch.sparse_coo_tensor(\n",
+ " torch.LongTensor(np.array([self.inter_mat.row, self.inter_mat.col])),\n",
+ " torch.FloatTensor(self.inter_mat.data),\n",
+ " size=self.inter_mat.shape, dtype=torch.float\n",
+ " ).coalesce().to(self.device)\n",
+ "\n",
+ " # TRAINING PROCESS\n",
+ " item_list = self.update_G(config)\n",
+ " self.update_H(item_list)\n",
+ "\n",
+ " def DI(self, pow=1., ilist=None):\n",
+ " r\"\"\"\n",
+ " Degree of item node\n",
+ " \"\"\"\n",
+ " if ilist is not None:\n",
+ " return torch.pow(self.di_sqr[:, ilist], pow)\n",
+ " else:\n",
+ " return torch.pow(self.di_sqr, pow)\n",
+ "\n",
+ " def update_G(self, config):\n",
+ " r\"\"\"\n",
+ " Update G matrix\n",
+ " \"\"\"\n",
+ " G = SpectralInfo(self.inter_mat, config)\n",
+ " self.di_sqr, self.u_norm, self.V_mat = G.run()\n",
+ " item_list = G.get_partition(\n",
+ " torch.arange(self.n_items, device=self.device), self.n_items\n",
+ " )\n",
+ " return item_list\n",
+ "\n",
+ " def update_H(self, item_list):\n",
+ " r\"\"\"\n",
+ " Update H matrix\n",
+ " \"\"\"\n",
+ " self.H_indices = []\n",
+ " self.H_values = []\n",
+ "\n",
+ " for ilist in tqdm(item_list, desc=\"Partition\", bar_format=\"{elapsed}\"):\n",
+ " H_triu = self.update_H_part(ilist)\n",
+ " H_triu = torch.where(H_triu >= 5e-4, H_triu, 0).to_sparse_coo()\n",
+ " self.H_indices.append(ilist[H_triu.indices()])\n",
+ " self.H_values.append(H_triu.values())\n",
+ " \n",
+ " H_mat = torch.sparse_coo_tensor(indices=torch.cat(self.H_indices, dim=1),\n",
+ " values=torch.cat(self.H_values, dim=0),\n",
+ " size=(self.n_items, self.n_items)).coalesce()\n",
+ " del self.H_indices, self.H_values\n",
+ " \n",
+ " # Sparse approximation\n",
+ " if self.sparse_approx:\n",
+ " limit = (self.n_users + self.n_items) * 64 # Embedding size in MF models\n",
+ " thres = 1e-4\n",
+ " while (H_mat._nnz() + H_mat.indices().shape[-1] + self.n_items + 1) >= limit:\n",
+ " mask = torch.where(H_mat.values() > thres)[0]\n",
+ " H_mat = torch.sparse_coo_tensor(indices=H_mat.indices().index_select(-1, mask),\n",
+ " values=H_mat.values().index_select(-1, mask),\n",
+ " size=(self.n_items, self.n_items)).coalesce()\n",
+ " thres *= 1.25\n",
+ " self.H_mat = H_mat.T.to_sparse_csr()\n",
+ "\n",
+ " def _inner_prod(self, A_mat: torch.Tensor, B_mat: torch.Tensor):\n",
+ " r\"\"\"\n",
+ " Small-batch inner product\n",
+ " \"\"\"\n",
+ " assert A_mat.shape[-2] == B_mat.shape[-2]\n",
+ " result = torch.zeros((A_mat.shape[-1], B_mat.shape[-1]), device=self.device)\n",
+ " for chunk in torch.split(torch.arange(0, A_mat.shape[-2], device=self.device), 10000):\n",
+ " result += A_mat.index_select(dim=-2, index=chunk).to_dense().T @ \\\n",
+ " B_mat.index_select(dim=-2, index=chunk).to_dense()\n",
+ " return result\n",
+ "\n",
+ " def update_H_part(self, ilist):\n",
+ " r\"\"\"\n",
+ " Learning H in each partition (if any)\n",
+ " \"\"\"\n",
+ " R_mat = self.inter_mat.index_select(dim=1, index=ilist) * self.DI(-self.norm_di, ilist)\n",
+ "\n",
+ " H_aux = (0.5 / self.lambda_) * self._inner_prod(R_mat, R_mat)\n",
+ " II_mat = self._inner_prod(R_mat, self.u_norm * R_mat)\n",
+ " del R_mat\n",
+ "\n",
+ " V_mat = self.V_mat[ilist, :]\n",
+ " diag_vvt = torch.square(V_mat).sum(dim=1).view(-1)\n",
+ "\n",
+ " G_mat = - self.DI(self.norm_di - 1, ilist).T * \\\n",
+ " (V_mat @ V_mat.T).clip(0).fill_diagonal_(0) * self.DI(1 - self.norm_di, ilist)\n",
+ " del V_mat, diag_vvt\n",
+ "\n",
+ " H_aux = H_aux + (\n",
+ " self.eps * ((1 / self.lambda_) - 1) *\n",
+ " (II_mat @ G_mat)\n",
+ " )\n",
+ " del G_mat\n",
+ "\n",
+ " II_inv = torch.inverse(\n",
+ " self.eps * II_mat + torch.diag(\n",
+ " self.DI(2, ilist).view(-1) * self.theta + self.rho\n",
+ " )\n",
+ " )\n",
+ " del II_mat\n",
+ "\n",
+ " H_aux = II_inv @ H_aux\n",
+ " Phi_mat = torch.zeros_like(H_aux, device=self.device)\n",
+ " S_mat = torch.zeros_like(H_aux, device=self.device)\n",
+ "\n",
+ " for _ in range(self.admm_iter):\n",
+ " # ADMM Iteration\n",
+ " H_tilde = H_aux + II_inv @ (self.rho * (S_mat - Phi_mat))\n",
+ " lag_op = torch.diag(H_tilde) / (torch.diag(II_inv) + 1e-10)\n",
+ " H_mat = H_tilde - II_inv * lag_op # Update H\n",
+ " S_mat = H_mat + Phi_mat \n",
+ " S_mat = torch.clip((S_mat.T + S_mat) / 2, min=0) # Update S\n",
+ " Phi_mat += H_mat - S_mat # Update Phi\n",
+ "\n",
+ " return torch.triu(S_mat)\n",
+ "\n",
+ " def forward(self):\n",
+ " r\"\"\"\n",
+ " Abstract method of GeneralRecommender in RecBole (not used)\n",
+ " \"\"\"\n",
+ " pass\n",
+ "\n",
+ " def calculate_loss(self, interaction):\n",
+ " r\"\"\"\n",
+ " Abstract method of GeneralRecommender in RecBole (not used)\n",
+ " \"\"\"\n",
+ " return torch.nn.Parameter(torch.zeros(1))\n",
+ "\n",
+ " def predict(self, interaction):\n",
+ " r\"\"\"\n",
+ " Abstract method of GeneralRecommender in RecBole (not used)\n",
+ " \"\"\"\n",
+ " raise NotImplementedError\n",
+ "\n",
+ " def full_sort_predict(self, interaction):\n",
+ " r\"\"\"\n",
+ " Recommend items for the input users\n",
+ " \"\"\"\n",
+ " R_mat = self.inter_mat.index_select(dim=0, index=interaction[self.USER_ID]).to_dense()\n",
+ "\n",
+ " Y_mat = R_mat * self.DI(-1) @ self.V_mat @ self.V_mat.T * self.di_sqr\n",
+ " Y_mat = ((1 / self.lambda_) - 1) * \\\n",
+ " torch.clip(Y_mat - R_mat * torch.square(self.V_mat).sum(dim=1).reshape(1, -1), min=0)\n",
+ "\n",
+ " R_mat = R_mat * self.DI(-self.norm_di)\n",
+ " Y_mat += ((self.H_mat @ R_mat.T).T + R_mat @ self.H_mat) * self.DI(self.norm_di)\n",
+ "\n",
+ " return Y_mat\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "9spbyRBwb0dn"
+ },
+ "source": [
+ "## Running"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "htpZ3NxHYlKP",
+ "outputId": "04d2acdd-b84a-4367-a776-517143fa5513"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "==========Preprocessing==========\n",
+ "Configuration loading complete.\n",
+ "Data loading complete.\n",
+ "============Training=============\n",
+ "(This step may last several minutes)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "01:51\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training complete.\n",
+ "===========Evaluation============\n",
+ "Evaluation results:\n",
+ "NDCG@5 : 0.069239\n",
+ "NDCG@10 : 0.071014\n",
+ "NDCG@20 : 0.081999\n",
+ "MRR@5 : 0.144153\n",
+ "MRR@10 : 0.158800\n",
+ "MRR@20 : 0.167767\n"
+ ]
+ }
+ ],
+ "source": [
+ "r\"\"\"\n",
+ "Code reference: https://recbole.io/docs/developer_guide/customize_models.html\n",
+ "\"\"\"\n",
+ "\n",
+ "import warnings\n",
+ "from recbole.utils import init_seed\n",
+ "from recbole.trainer import Trainer\n",
+ "from recbole.config import Config\n",
+ "from recbole.data import create_dataset, data_preparation\n",
+ "\n",
+ "warnings.filterwarnings('ignore')\n",
+ "\n",
+ "print('='* 10 + 'Preprocessing' + '=' * 10)\n",
+ "config = Config(model=CoRML, config_dict=CONFIG)\n",
+ "init_seed(config['seed'], config['reproducibility'])\n",
+ "print(\"Configuration loading complete.\")\n",
+ "\n",
+ "# dataset filtering\n",
+ "dataset = create_dataset(config)\n",
+ "train_data, valid_data, test_data = data_preparation(config, dataset)\n",
+ "print(\"Data loading complete.\")\n",
+ "\n",
+ "print('='* 12 + 'Training' + '=' * 13)\n",
+ "print(\"(This step may last several minutes)\")\n",
+ "# model loading and initialization\n",
+ "model = CoRML(config, train_data.dataset).to(config['device'])\n",
+ "\n",
+ "# trainer loading and initialization\n",
+ "trainer = Trainer(config, model)\n",
+ "\n",
+ "# model training\n",
+ "_, _ = trainer.fit(train_data, valid_data)\n",
+ "\n",
+ "print(\"Training complete.\")\n",
+ "\n",
+ "print('='* 11 + 'Evaluation' + '=' * 12)\n",
+ "# model evaluation\n",
+ "test_result = trainer.evaluate(test_data)\n",
+ "print(\"Evaluation results:\")\n",
+ "for metric, value in test_result.items():\n",
+ " print('{:10s}: {:.6f}'.format(metric.upper(), value))"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "provenance": []
+ },
+ "gpuClass": "standard",
+ "kernelspec": {
+ "display_name": "SparseRec",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.11.0"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "933075f399f88ee3a7bc96289e7bedd6322a6307b0e261e2bc2c90799eef1243"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/CoRML/__init__.py b/CoRML/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/CoRML/model.py b/CoRML/model.py
new file mode 100644
index 0000000..19e0970
--- /dev/null
+++ b/CoRML/model.py
@@ -0,0 +1,294 @@
+r"""
+CoRML (PyTorch Version)
+################################################
+Author:
+ Tianjun Wei (tjwei2-c@my.cityu.edu.hk)
+Reference:
+ Tianjun Wei et al. "Collaborative Residual Metric Learning." in SIGIR 2023.
+Created Date:
+ 2023/04/10
+"""
+import torch
+import numpy as np
+
+from recbole.utils import InputType
+from recbole.utils.enum_type import ModelType
+from recbole.model.abstract_recommender import GeneralRecommender
+
+
+class SpectralInfo(object):
+ r"""A class for producing spectral information of the interaction matrix, including node degrees,
+ singular vectors, and partitions produced by spectral graph partitioning.
+
+ Reference: https://github.com/Joinn99/FPSR/blob/torch/FPSR/model.py
+ """
+ def __init__(self, inter_mat, config) -> None:
+ # load parameters info
+ self.eigen_dim = config["eigenvectors"] # Number of eigenvectors used in SVD
+ self.t_u = config["user_scaling"]
+ self.norm_di = 2 * config["item_degree_norm"]
+ self.partition_ratio = config["partition_ratio"] # Maximum size ratio of item partition (1.0 means no partitioning)
+ self.inter_mat = inter_mat
+
+ def _degree(self, inter_mat=None, dim=0, exp=-0.5) -> torch.Tensor:
+ r"""Get the degree of users and items.
+
+ Returns:
+ Tensor of the node degrees.
+ """
+ if inter_mat is None:
+ inter_mat = self.inter_mat
+ d_inv = torch.nan_to_num(
+ torch.clip(torch.sparse.sum(inter_mat, dim=dim).to_dense(), min=1.).pow(exp), nan=1., posinf=1., neginf=1.
+ )
+ return d_inv
+
+ def _svd(self, mat, k) -> torch.Tensor:
+ r"""Perform Truncated singular value decomposition (SVD) on
+ the input matrix, return top-k eigenvectors.
+
+ Returns:
+ Tok-k eigenvectors.
+ """
+ _, _, V = torch.svd_lowrank(mat, q=max(4*k, 32), niter=10)
+ return V[:, :k]
+
+ def _norm_adj(self, item_list=None) -> torch.Tensor:
+ r"""Get the normalized item-item adjacency matrix for a group of items.
+
+ Returns:
+ Sparse tensor of the normalized item-item adjacency matrix.
+ """
+ if item_list is None:
+ vals = self.inter_mat.values() * self.di_isqr[self.inter_mat.indices()[1]].squeeze()
+ return torch.sparse_coo_tensor(
+ self.inter_mat.indices(),
+ self._degree(dim=1)[self.inter_mat.indices()[0]] * vals,
+ size=self.inter_mat.shape, dtype=torch.float
+ ).coalesce()
+ else:
+ inter = self.inter_mat.index_select(dim=1, index=item_list).coalesce()
+ vals = inter.values() * self.di_isqr[item_list][inter.indices()[1]].squeeze()
+ return torch.sparse_coo_tensor(
+ inter.indices(),
+ self._degree(inter, dim=1)[inter.indices()[0]] * vals,
+ size=inter.shape, dtype=torch.float
+ ).coalesce()
+
+ def run(self):
+ r"""
+ Spectral information
+ """
+ self.di_isqr = self._degree(dim=0).reshape(-1, 1)
+ self.di_sqr = self._degree(dim=0, exp=0.5).reshape(1, -1)
+
+ u_norm = self._degree(dim=1, exp=-self.t_u).reshape(-1, 1)
+ self.u_norm = u_norm / u_norm.min()
+
+ self.V_mat = self._svd(self._norm_adj(), self.eigen_dim)
+
+ return self.di_sqr, self.u_norm, self.V_mat
+
+ def partitioning(self, V) -> torch.Tensor:
+ r"""
+ Graph bipartitioning
+ """
+ split = V[:, 1] >= 0
+ if split.sum() == split.shape[0] or split.sum() == 0:
+ split = V[:, 1] >= torch.median(V[:, 1])
+ return split
+
+ def get_partition(self, ilist, total_num):
+ r"""
+ Get partitions of item-item graph
+ """
+ assert self.partition_ratio > (1 / total_num)
+
+ if ilist.shape[0] <= total_num * self.partition_ratio:
+ return [ilist]
+ else:
+ # If the partition size is larger than size limit,
+ # perform graph partitioning on this partition.
+ split = self.partitioning(self._svd(self._norm_adj(ilist), 2))
+ return self.get_partition(ilist[torch.where(split)[0]], total_num) + \
+ self.get_partition(ilist[torch.where(~split)[0]], total_num)
+
+
+class CoRML(GeneralRecommender):
+ r"""CoRML is an item-based metric learning model for collaborative filtering.
+
+    CoRML learns a generalized user-item distance metric to capture user
+    preference in user-item interaction signals by modeling the residuals of
+    the general Mahalanobis distance.
+ """
+ input_type = InputType.POINTWISE
+ type = ModelType.TRADITIONAL
+
+ def __init__(self, config, dataset):
+ r"""
+ Model initialization and training.
+ """
+ super().__init__(config, dataset)
+
+ self.lambda_ = config["lambda"] # Weights for H and G in preference scores
+ self.rho = config["dual_step_length"] # Dual step length of ADMM
+ self.theta = config["l2_regularization"] # L2-regularization for learning weight matrix G
+ self.norm_di = 2 * config["item_degree_norm"] # Item degree norm for learning weight matrix G
+ self.eps = np.power(10, config["global_scaling"]) # Global scaling in approximated ranking weights (in logarithm scale)
+
+ self.sparse_approx = config["sparse_approx"] # Sparse approximation to reduce storage size of H
+ self.admm_iter = config["admm_iter"] # Number of iterations for ADMM
+
+ # Dummy pytorch parameters required by Recbole
+ self.dummy_param = torch.nn.Parameter(torch.zeros(1))
+
+ # User-item interaction matrix
+ self.inter_mat = dataset.inter_matrix(form='coo')
+ self.inter_mat = torch.sparse_coo_tensor(
+ torch.LongTensor(np.array([self.inter_mat.row, self.inter_mat.col])),
+ torch.FloatTensor(self.inter_mat.data),
+ size=self.inter_mat.shape, dtype=torch.float
+ ).coalesce().to(self.device)
+
+ # TRAINING PROCESS
+ item_list = self.update_G(config)
+ self.update_H(item_list)
+
+ def DI(self, pow=1., ilist=None):
+ r"""
+ Degree of item node
+ """
+ if ilist is not None:
+ return torch.pow(self.di_sqr[:, ilist], pow)
+ else:
+ return torch.pow(self.di_sqr, pow)
+
+ def update_G(self, config):
+ r"""
+ Update G matrix
+ """
+ G = SpectralInfo(self.inter_mat, config)
+ self.di_sqr, self.u_norm, self.V_mat = G.run()
+ item_list = G.get_partition(
+ torch.arange(self.n_items, device=self.device), self.n_items
+ )
+ return item_list
+
+ def update_H(self, item_list):
+ r"""
+ Update H matrix
+ """
+ self.H_indices = []
+ self.H_values = []
+
+ for ilist in item_list:
+ H_triu = self.update_H_part(ilist)
+ H_triu = torch.where(H_triu >= 5e-4, H_triu, 0).to_sparse_coo()
+ self.H_indices.append(ilist[H_triu.indices()])
+ self.H_values.append(H_triu.values())
+
+ H_mat = torch.sparse_coo_tensor(indices=torch.cat(self.H_indices, dim=1),
+ values=torch.cat(self.H_values, dim=0),
+ size=(self.n_items, self.n_items)).coalesce()
+ del self.H_indices, self.H_values
+
+ # Sparse approximation
+ if self.sparse_approx:
+ limit = (self.n_users + self.n_items) * 64 # Embedding size in MF models
+ thres = 1e-4
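+            # Raise the threshold until values, indices, and row pointers
+            # of H fit within the budget of size-64 embedding tables.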
+ while (H_mat._nnz() + H_mat.indices().shape[-1] + self.n_items + 1) >= limit:
+ mask = torch.where(H_mat.values() > thres)[0]
+ H_mat = torch.sparse_coo_tensor(indices=H_mat.indices().index_select(-1, mask),
+ values=H_mat.values().index_select(-1, mask),
+ size=(self.n_items, self.n_items)).coalesce()
+ thres *= 1.25
+ self.H_mat = H_mat.T.to_sparse_csr()
+
+ def _inner_prod(self, A_mat: torch.Tensor, B_mat: torch.Tensor):
+ r"""
+ Small-batch inner product
+ """
+ assert A_mat.shape[-2] == B_mat.shape[-2]
+ result = torch.zeros((A_mat.shape[-1], B_mat.shape[-1]), device=self.device)
+ for chunk in torch.split(torch.arange(0, A_mat.shape[-2], device=self.device), 10000):
+ result += A_mat.index_select(dim=-2, index=chunk).to_dense().T @ \
+ B_mat.index_select(dim=-2, index=chunk).to_dense()
+ return result
+
+ def update_H_part(self, ilist):
+ r"""
+ Learning H in each partition (if any)
+ """
+ R_mat = self.inter_mat.index_select(dim=1, index=ilist) * self.DI(-self.norm_di, ilist)
+
+ H_aux = (0.5 / self.lambda_) * self._inner_prod(R_mat, R_mat)
+ II_mat = self._inner_prod(R_mat, self.u_norm * R_mat)
+ del R_mat
+
+ V_mat = self.V_mat[ilist, :]
+ diag_vvt = torch.square(V_mat).sum(dim=1).view(-1)
+
+ G_mat = - self.DI(self.norm_di - 1, ilist).T * \
+ (V_mat @ V_mat.T).clip(0).fill_diagonal_(0) * self.DI(1 - self.norm_di, ilist)
+ del V_mat, diag_vvt
+
+ H_aux = H_aux + (
+ self.eps * ((1 / self.lambda_) - 1) *
+ (II_mat @ G_mat)
+ )
+ del G_mat
+
+ II_inv = torch.inverse(
+ self.eps * II_mat + torch.diag(
+ self.DI(2, ilist).view(-1) * self.theta + self.rho
+ )
+ )
+ del II_mat
+
+ H_aux = II_inv @ H_aux
+ Phi_mat = torch.zeros_like(H_aux, device=self.device)
+ S_mat = torch.zeros_like(H_aux, device=self.device)
+
+ for _ in range(self.admm_iter):
+ # ADMM Iteration
+ H_tilde = H_aux + II_inv @ (self.rho * (S_mat - Phi_mat))
+ lag_op = torch.diag(H_tilde) / (torch.diag(II_inv) + 1e-10)
+ H_mat = H_tilde - II_inv * lag_op # Update H
+ S_mat = H_mat + Phi_mat
+ S_mat = torch.clip((S_mat.T + S_mat) / 2, min=0) # Update S
+ Phi_mat += H_mat - S_mat # Update Phi
+
+ return torch.triu(S_mat)
+
+ def forward(self):
+ r"""
+ Abstract method of GeneralRecommender in RecBole (not used)
+ """
+ pass
+
+ def calculate_loss(self, interaction):
+ r"""
+ Abstract method of GeneralRecommender in RecBole (not used)
+ """
+ return torch.nn.Parameter(torch.zeros(1))
+
+ def predict(self, interaction):
+ r"""
+ Abstract method of GeneralRecommender in RecBole (not used)
+ """
+ raise NotImplementedError
+
+ def full_sort_predict(self, interaction):
+ r"""
+ Recommend items for the input users
+ """
+ R_mat = self.inter_mat.index_select(dim=0, index=interaction[self.USER_ID]).to_dense()
+
+ Y_mat = R_mat * self.DI(-1) @ self.V_mat @ self.V_mat.T * self.di_sqr
+ Y_mat = ((1 / self.lambda_) - 1) * \
+ torch.clip(Y_mat - R_mat * torch.square(self.V_mat).sum(dim=1).reshape(1, -1), min=0)
+
+ R_mat = R_mat * self.DI(-self.norm_di)
+ Y_mat += ((self.H_mat @ R_mat.T).T + R_mat @ self.H_mat) * self.DI(self.norm_di)
+
+ return Y_mat
diff --git a/CoRML/utils.py b/CoRML/utils.py
new file mode 100644
index 0000000..11bf1f7
--- /dev/null
+++ b/CoRML/utils.py
@@ -0,0 +1,125 @@
+r"""
+ Name: utils.py
+ Date: 2023/04/12
+ Description: Utility functions.
+"""
+
+import os
+import shutil
+import warnings
+from logging import getLogger
+from datetime import datetime
+from recbole.config import Config
+from recbole.trainer import Trainer
+from recbole.data.dataset import Dataset
+from recbole.data import data_preparation
+from recbole.utils import init_seed
+from recbole.utils.enum_type import ModelType
+
+from CoRML.model import CoRML
+
+warnings.filterwarnings('ignore')
+
+class RecTrainer(object):
+ r"""
+ Trainer class.
+ """
+ def __init__(self, dataset='yelp2018'):
+ self.model_name = CoRML
+ self.dataset = dataset
+ self.dataset_file_list = ['Params/Overall.yaml', 'Params/{:s}.yaml'.format(dataset)]
+ self.preprocessing()
+
+ def preprocessing(self):
+ self.config = Config(model=self.model_name, config_dict=None, config_file_list=self.dataset_file_list)
+ init_seed(self.config['seed'], self.config['reproducibility'])
+ init_logger(self.config)
+ self.logger = getLogger()
+ dataset = Dataset(self.config)
+ self.train_data, self.valid_data, self.test_data = data_preparation(self.config, dataset)
+
+ def train(self, verbose=False):
+ self.logger.info("Start training:")
+ model = self.model_name(self.config, self.train_data.dataset).to(self.config['device'])
+ self.trainer = Trainer(self.config, model)
+ shutil.rmtree(self.trainer.tensorboard.log_dir)
+ cur_time = datetime.now().strftime('%b-%d-%Y_%H-%M-%S')
+ self.trainer.tensorboard = get_tensorboard(cur_time, self.config['model'])
+ _, _ = self.trainer.fit(self.train_data, None, verbose=verbose, show_progress=False)
+ load_best = False if self.trainer.model.type == ModelType.TRADITIONAL else True
+ self.logger.info("Start evalutaion:")
+ test_result = self.trainer.evaluate(self.test_data, load_best_model=load_best)
+ os.remove(self.trainer.saved_model_file)
+ shutil.rmtree(os.path.join('Log/{:s}'.format(self.config['model']), cur_time))
+ if os.path.exists('log_tensorboard'):
+ shutil.rmtree(os.path.join('log_tensorboard'))
+ for k, v in test_result.items():
+ self.logger.info("{:12s}: {:.6f}".format(k, v))
+
+from torch.utils.tensorboard import SummaryWriter
+from recbole.utils import ensure_dir
+
+def get_tensorboard(cur_time, model):
+ r"""
+ Modified version of get_tensorboard in Recbole.
+ Source: https://github.com/RUCAIBox/RecBole/blob/5d7df69bcbe9d21b4185946e8ee9a4bd8f041b9d/recbole/utils/utils.py#L206-L230
+ """
+ base_path = 'Log/{:s}'.format(model)
+ if not os.path.exists(base_path):
+ os.makedirs(base_path)
+ dir_name = cur_time
+ dir_path = os.path.join(base_path, dir_name)
+ writer = SummaryWriter(dir_path)
+ return writer
+
+import logging
+import colorlog
+from colorama import init
+from logging import getLogger
+from recbole.utils.logger import log_colors_config, RemoveColorFilter
+
+def init_logger(config):
+ r"""
+ Modified version of init_logger in Recbole.
+ Source: https://github.com/RUCAIBox/RecBole/blob/5d7df69bcbe9d21b4185946e8ee9a4bd8f041b9d/recbole/utils/logger.py#L60-L118
+ """
+ init(autoreset=True)
+ LOGROOT = './Log/'
+ dir_name = os.path.dirname(LOGROOT)
+ ensure_dir(dir_name)
+
+ logfilename = 'RUN.log'
+
+ logfilepath = os.path.join(LOGROOT, logfilename)
+
+ filefmt = "%(asctime)-15s %(levelname)s %(message)s"
+ filedatefmt = "%a %d %b %Y %H:%M:%S"
+ fileformatter = logging.Formatter(filefmt, filedatefmt)
+
+ sfmt = "%(log_color)s%(asctime)-15s %(levelname)s %(message)s"
+ sdatefmt = "%d %b %H:%M"
+ sformatter = colorlog.ColoredFormatter(sfmt, sdatefmt, log_colors=log_colors_config)
+ if config['state'] is None or config['state'].lower() == 'info':
+ level = logging.INFO
+ elif config['state'].lower() == 'debug':
+ level = logging.DEBUG
+ elif config['state'].lower() == 'error':
+ level = logging.ERROR
+ elif config['state'].lower() == 'warning':
+ level = logging.WARNING
+ elif config['state'].lower() == 'critical':
+ level = logging.CRITICAL
+ else:
+ level = logging.INFO
+
+ fh = logging.FileHandler(logfilepath)
+ fh.setLevel(level)
+ fh.setFormatter(fileformatter)
+ remove_color_filter = RemoveColorFilter()
+ fh.addFilter(remove_color_filter)
+
+ sh = logging.StreamHandler()
+ sh.setLevel(level)
+ sh.setFormatter(sformatter)
+
+ logging.basicConfig(level=level, handlers=[sh, fh])
\ No newline at end of file
diff --git a/Params/Overall.yaml b/Params/Overall.yaml
new file mode 100644
index 0000000..59f50b3
--- /dev/null
+++ b/Params/Overall.yaml
@@ -0,0 +1,48 @@
+# General Parameters
+seed: 2026 #Choose in [2022,2024,2026,2028,2030]
+reproducibility: True
+
+# Training Parameters
+train_neg_sample_args:
+ sample_num: 1
+epochs: 1
+train_batch_size: 524288
+
+# Evaluation Parameters
+eval_args:
+ split: {'RS':[0.6, 0.2, 0.2]}
+ order: RO
+ group_by: user
+ mode: full
+eval_batch_size: 1048576
+eval_step: 0
+stopping_step: 1
+valid_metric: NDCG@20
+topk: [5, 10, 20]
+metrics: ['NDCG', 'MRR']
+metric_decimal_place: 6
+
+# Dataset Default Parameters
+data_path: ./Data/
+load_col:
+ inter: [user_id, item_id]
+ user: [user_id]
+ item: [item_id]
+
+USER_ID_FIELD: user_id
+ITEM_ID_FIELD: item_id
+filter_inter_by_user_or_item: False
+
+# Model Default Parameters
+model: CoRML
+partition_ratio: 1.0 # Maximum size ratio of item partition (1.0 means no partitioning)
+sparse_approx: True # Sparse approximation to reduce storage size of H to compare with MF-based models
+eigenvectors: 256 # Number of eigenvectors used in SVD
+admm_iter: 50 # Number of iterations for ADMM
+
+lambda: 0.75 # Weights for H and G in preference scores
+dual_step_length: 5e3 # Dual step length of ADMM
+l2_regularization: 1.0 # L2-regularization for learning weight matrix G
+item_degree_norm: 0.1 # Item degree norm for learning weight matrix G
+global_scaling: 0.0 # Global scaling in approximated ranking weights (in logarithm scale)
+user_scaling: 0.0 # User degree scaling in approximated ranking weights
diff --git a/Params/Pinterest.yaml b/Params/Pinterest.yaml
new file mode 100644
index 0000000..2b227bf
--- /dev/null
+++ b/Params/Pinterest.yaml
@@ -0,0 +1,10 @@
+# Pinterest
+dataset: Pinterest
+
+lambda: 0.7
+dual_step_length: 5e3
+l2_regularization: 1e-3
+global_scaling: -0.5
+user_scaling: 0.5
+
+partition_ratio: 1.0
\ No newline at end of file
diff --git a/Params/gowalla.yaml b/Params/gowalla.yaml
new file mode 100644
index 0000000..7637ceb
--- /dev/null
+++ b/Params/gowalla.yaml
@@ -0,0 +1,10 @@
+# Gowalla
+dataset: gowalla
+
+lambda: 0.8
+dual_step_length: 5e3
+l2_regularization: 1.0
+global_scaling: 0.0
+user_scaling: 0.0
+
+partition_ratio: 0.35
\ No newline at end of file
diff --git a/Params/ml-20m.yaml b/Params/ml-20m.yaml
new file mode 100644
index 0000000..92cd690
--- /dev/null
+++ b/Params/ml-20m.yaml
@@ -0,0 +1,10 @@
+# ML-20M
+dataset: ml-20m
+
+lambda: 0.8
+dual_step_length: 10.0
+l2_regularization: 0.1
+global_scaling: 0.0
+user_scaling: 0.0
+
+partition_ratio: 1.0
\ No newline at end of file
diff --git a/Params/yelp2018.yaml b/Params/yelp2018.yaml
new file mode 100644
index 0000000..b27978a
--- /dev/null
+++ b/Params/yelp2018.yaml
@@ -0,0 +1,10 @@
+# Yelp2018
+dataset: yelp2018
+
+lambda: 0.8
+dual_step_length: 5e3
+l2_regularization: 1.0
+global_scaling: -1.0
+user_scaling: 0.5
+
+partition_ratio: 0.3
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..110ff62
--- /dev/null
+++ b/README.md
@@ -0,0 +1,114 @@
+# CoRML
+[![RecBole](https://img.shields.io/badge/RecBole-1.1.1-orange)](https://recbole.io/)
+[![arXiv](https://img.shields.io/badge/arXiv-2304.07971-red)](https://arxiv.org/abs/2304.07971)
+[![License](https://img.shields.io/github/license/Joinn99/CoRML)](https://github.com/Joinn99/CoRML/blob/torch/LICENSE.md)
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Joinn99/CoRML/blob/torch/CoRML%20Demonstration.ipynb)
+
+**PyTorch version** (Default) | [**CuPy version**](https://github.com/Joinn99/CoRML/tree/cupy)
+
+This is the official implementation of our paper published at *the 46th International ACM SIGIR Conference on Research and Development in Information Retrieval (SIGIR 2023)*:
+> Tianjun Wei, Jianghong Ma, Tommy W.S. Chow. Collaborative Residual Metric Learning. [[arXiv](https://arxiv.org/abs/2304.07971)]
+
+## Requirements
+The model implementation is compatible with the recommendation toolbox [RecBole](https://recbole.io/) (GitHub: [RecBole](https://github.com/RUCAIBox/RecBole)).
+
+The requirements of the running environment:
+
+- Python: 3.8+
+- PyTorch: 1.9.0+
+- RecBole: 1.1.1
+
+## Dataset
+Here we only put zipped dataset files in the repository due to storage limits. To use the datasets, run
+```bash
+unzip -o "Data/*.zip"
+```
+to unzip the dataset files.
+
+If you would like to test CoRML on a custom dataset, please place the dataset files in the following path:
+```bash
+.
+|-Data
+| |-[CUSTOM_DATASET_NAME]
+| | |-[CUSTOM_DATASET_NAME].user
+| | |-[CUSTOM_DATASET_NAME].item
+| | |-[CUSTOM_DATASET_NAME].inter
+
+```
+And create `[CUSTOM_DATASET_NAME].yaml` in `./Params` with the following content:
+```yaml
+dataset: [CUSTOM_DATASET_NAME]
+```
+
+For the format of each dataset file, please refer to [RecBole API](https://recbole.io/docs/user_guide/data/atomic_files.html).
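+
+For reference, a minimal `[CUSTOM_DATASET_NAME].inter` file is a tab-separated table whose header names each field and its type, matching the `load_col` setting in `Params/Overall.yaml` (the IDs below are illustrative):
+```text
+user_id:token	item_id:token
+0	241
+0	1002
+1	87
+```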
+
+## Hyperparameter
+For each dataset, the optimal hyperparameters are stored in `Params/[DATASET].yaml`. To tune the hyperparameters, modify the corresponding values in the file for each dataset.
+
+The main hyperparameters of CoRML are listed as follows:
+ - `lambda` ($\lambda$): Weights for $\mathbf{H}$ and $\mathbf{G}$ in preference scores
+ - `dual_step_length` ($\rho$): Dual step length of ADMM
+ - `l2_regularization` ($\theta$): $L_2$ regularization for learning weight matrix $\mathbf{H}$
+ - `item_degree_norm` ($t$): Item degree norm for learning weight matrix $\mathbf{H}$
+ - `global_scaling` ($\epsilon$): Global scaling in approximated ranking weights (in logarithm scale)
+ - `user_scaling` ($t_u$): User degree scaling in approximated ranking weights
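+
+For example, to try a different score weighting on Yelp2018, override the corresponding keys in `Params/yelp2018.yaml` (the values below are illustrative, not the tuned settings):
+```yaml
+lambda: 0.7           # weight of H and G in preference scores
+global_scaling: -0.5  # ranking-weight scaling, in log10 scale
+```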
+
+### Running on small partitions
+For datasets containing a large number of items, calculating and storing the complete item-item matrix may lead to out-of-memory problems in environments with limited GPU memory. Therefore, we have added code for graph spectral partitioning to learn the item-item weight matrix on each small partitioned item set. The code is adapted from our previous work [FPSR](https://github.com/Joinn99/FPSR).
+
+Hyperparameter `partition_ratio` is used to control the maximum ratio of the partitioned set of items relative to the complete set, ranging from 0 to 1. When `partition_ratio` is set to 1, no partitioning will be performed.
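+
+As a rough, back-of-the-envelope sketch (not part of the released code), the dominant memory cost per partition is a set of dense float32 blocks whose side length is capped at `partition_ratio * n_items`:
+```python
+def dense_block_gib(n_items: int, partition_ratio: float) -> float:
+    """Approximate size (GiB) of one dense float32 item-item block
+    when the largest partition reaches the size cap."""
+    side = int(n_items * partition_ratio)
+    return 4 * side * side / 1024 ** 3  # 4 bytes per float32 entry
+
+# Yelp2018 has ~38k items: partition_ratio=0.3 gives ~0.5 GiB per block,
+# and the ADMM solver keeps several such blocks (H, S, Phi) in memory.
+print(dense_block_gib(38048, 0.3))
+```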
+
+### Sparse approximation
+To maintain consistency, we perform a sparse approximation of the derived matrix $\mathbf{H}$ in CoRML by setting entries to 0 where $\lvert \mathbf{H}_{ij} \rvert \leq \eta$. The threshold $\eta$ is adjusted so that the storage size of the sparse matrix $\mathbf{H}_{Sparse}$ is smaller than that of embedding-based models with embedding size 64.
+
+Hyperparameter `sparse_approx` is used to control the sparse approximation. When `sparse_approx` is set to `False`, no sparse approximation will be performed.
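+
+A condensed sketch of this procedure, mirroring the loop in `CoRML.update_H` (`CoRML/model.py`):
+```python
+import torch
+
+def sparse_approx(H: torch.Tensor, n_users: int, n_items: int,
+                  eta: float = 1e-4) -> torch.Tensor:
+    """Drop small entries of the sparse COO matrix H, raising eta until
+    its CSR storage (values + column indices + row pointers) fits the
+    footprint of 64-dimensional embedding tables."""
+    limit = (n_users + n_items) * 64
+    H = H.coalesce()
+    while H._nnz() * 2 + n_items + 1 >= limit:
+        keep = torch.where(H.values() > eta)[0]
+        H = torch.sparse_coo_tensor(
+            H.indices().index_select(-1, keep),
+            H.values().index_select(-1, keep),
+            size=(n_items, n_items),
+        ).coalesce()
+        eta *= 1.25
+    return H
+```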
+
+## Running
+The script `run.py` is used to reproduce the results presented in the paper. To train and evaluate CoRML on a specific dataset, run
+```bash
+python run.py --dataset DATASET_NAME
+```
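+
+where `DATASET_NAME` matches one of the YAML files in `Params/` (`ml-20m`, `Pinterest`, `gowalla`, or `yelp2018`), or a custom dataset configured as described above.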
+
+## Google Colab
+We also provide a Colab notebook version of CoRML. You can click [here](https://colab.research.google.com/github/Joinn99/CoRML/blob/torch/CoRML%20Demonstration.ipynb) to open it in Google Colab, select the runtime type as *GPU*, and run the model.
+
+## Citation
+If you find our work useful, please cite the following papers:
+
+```bibtex
+
+@InProceedings{CoRML,
+ author = {{Wei}, Tianjun and {Ma}, Jianghong and {Chow}, Tommy W.~S.},
+ booktitle = {Proceedings of the 46th International ACM SIGIR Conference on Research and Development in Information Retrieval},
+ title = {Collaborative Residual Metric Learning},
+ year = {2023},
+ address = {New York, NY, USA},
+ publisher = {Association for Computing Machinery},
+ series = {SIGIR '23},
+ doi = {10.1145/3539618.3591649},
+ location = {Taipei, Taiwan},
+ numpages = {10},
+ url = {https://doi.org/10.1145/3539618.3591649},
+}
+
+@InProceedings{FPSR,
+ author = {{Wei}, Tianjun and {Ma}, Jianghong and {Chow}, Tommy W.~S.},
+ booktitle = {Proceedings of the ACM Web Conference 2023},
+ title = {Fine-tuning Partition-aware Item Similarities for Efficient and Scalable Recommendation},
+ year = {2023},
+ address = {New York, NY, USA},
+ publisher = {Association for Computing Machinery},
+ series = {WWW '23},
+ doi = {10.1145/3543507.3583240},
+ location = {Austin, TX, USA},
+ numpages = {11},
+ url = {https://doi.org/10.1145/3543507.3583240},
+}
+```
+## License
+This project is licensed under the terms of the MIT license.
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..4da0868
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,2 @@
+recbole==1.1.1
+torch>=1.9.0
\ No newline at end of file
diff --git a/run.py b/run.py
new file mode 100644
index 0000000..40cba5e
--- /dev/null
+++ b/run.py
@@ -0,0 +1,9 @@
+import argparse
+from CoRML.utils import RecTrainer
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--dataset', '-d', type=str, default='yelp2018', help='Name of Datasets')
+ args, _ = parser.parse_known_args()
+
+ RecTrainer(dataset=args.dataset).train(verbose=False)
\ No newline at end of file