[feature] ALS Matrix Factorization using External Library (implicit) #2124

Open · wants to merge 2 commits into base: master
4 changes: 4 additions & 0 deletions docs/source/recbole/recbole.model.general_recommender.als.rst
@@ -0,0 +1,4 @@
.. automodule:: recbole.model.general_recommender.als
:members:
:undoc-members:
:show-inheritance:
86 changes: 86 additions & 0 deletions docs/source/user_guide/model/general/als.rst
@@ -0,0 +1,86 @@
ALS (External algorithm library)
================================

Introduction
---------------------

`[ALS (implicit)] <https://benfred.github.io/implicit/api/models/cpu/als.html>`_

**ALS (AlternatingLeastSquares)** from the ``implicit`` library is a recommendation model based on the algorithm proposed by Hu, Koren, and Volinsky in `Collaborative Filtering for Implicit Feedback Datasets <http://yifanhu.net/PUB/cf.pdf>`_.
It furthermore leverages the findings of `Applications of the Conjugate Gradient Method for Implicit Feedback Collaborative Filtering <https://dl.acm.org/doi/pdf/10.1145/2043932.2043987>`_ to speed up training.
`Implicit <https://benfred.github.io/implicit/index.html>`_ provides several models for implicit feedback recommendation.
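
For orientation, here is a minimal standalone sketch of the ``implicit`` API that this model wraps. The toy interaction matrix is an illustrative assumption; inside RecBole the real matrix is built from the dataset and passed in by ``ALSTrainer``:

.. code:: python

    import numpy as np
    import scipy.sparse as sp
    from implicit.als import AlternatingLeastSquares

    # toy user-item interaction matrix (3 users x 4 items) in CSR form, as implicit expects
    user_items = sp.csr_matrix(np.array([
        [1, 0, 2, 0],
        [0, 3, 0, 1],
        [1, 1, 0, 0],
    ], dtype=np.float32))

    model = AlternatingLeastSquares(factors=64, regularization=0.01, alpha=1.0)
    model.fit(user_items)

    # top-2 recommendations for user 0 (implicit >= 0.5 returns (ids, scores) arrays)
    ids, scores = model.recommend(0, user_items[0], N=2)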

`[paper] <http://yifanhu.net/PUB/cf.pdf>`_

**Title:** Collaborative Filtering for Implicit Feedback Datasets

**Authors:** Hu, Yifan and Koren, Yehuda and Volinsky, Chris

**Abstract:** A common task of recommender systems is to improve customer experience through
personalized recommendations based on prior implicit feedback. These systems passively track
different sorts of user behavior, such as purchase history, watching habits and browsing activity,
in order to model user preferences. Unlike the much more extensively researched explicit feedback,
we do not have any direct input from the users regarding their preferences. In particular, we lack
substantial evidence on which products consumers dislike. In this work we identify unique properties
of implicit feedback datasets. We propose treating the data as indication of positive and negative
preference associated with vastly varying confidence levels. This leads to a factor model which is
especially tailored for implicit feedback recommenders. We also suggest a scalable optimization
procedure, which scales linearly with the data size. The algorithm is used successfully within a
recommender system for television shows. It compares favorably with well tuned implementations of
other known methods. In addition, we offer a novel way to give explanations to recommendations
given by this factor model.
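
Concretely, the paper converts each raw observation :math:`r_{ui}` into a binary preference :math:`p_{ui}` (1 if :math:`r_{ui} > 0`, else 0) with confidence :math:`c_{ui} = 1 + \alpha r_{ui}`, and alternates least-squares solves for the user factors :math:`x_u` and item factors :math:`y_i` to minimize the confidence-weighted objective below; this is where the ``alpha`` hyper-parameter comes from:

.. math::

    \min_{x_*, y_*} \sum_{u,i} c_{ui}\left(p_{ui} - x_u^\top y_i\right)^2
    + \lambda\left(\sum_u \lVert x_u \rVert^2 + \sum_i \lVert y_i \rVert^2\right)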

Running with RecBole
-------------------------

**Model Hyper-Parameters:**

- ``embedding_size (int)`` : The number of latent factors to compute. Defaults to ``64``.
- ``regularization (float)`` : The regularization factor to use. Defaults to ``0.01``.
- ``alpha (float)`` : The weight to give to positive examples. Defaults to ``1.0``.

Please refer to the `implicit Python package <https://benfred.github.io/implicit/index.html>`_ for more details.

**A Running Example:**

Write the following code into a Python file, such as ``run.py``:

.. code:: python

from recbole.quick_start import run_recbole

run_recbole(model='ALS', dataset='ml-100k')

And then:

.. code:: bash

python run.py
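
To override the default hyper-parameters without editing a config file, you can also pass a ``config_dict`` to ``run_recbole`` (a sketch; the keys mirror ``ALS.yaml``):

.. code:: python

    from recbole.quick_start import run_recbole

    run_recbole(
        model='ALS',
        dataset='ml-100k',
        config_dict={
            'embedding_size': 128,
            'regularization': 0.05,
            'alpha': 1.0,
        },
    )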

Tuning Hyper Parameters
-------------------------

If you want to use ``HyperTuning`` to tune hyper parameters of this model, you can copy the following settings and name it as ``hyper.test``.

.. code:: bash

regularization choice [0.01, 0.03, 0.05, 0.1]
embedding_size choice [32, 64, 96, 128, 256]
alpha choice [0.5, 0.7, 1.0, 1.3, 1.5]

Note that these hyper-parameter ranges are provided for reference only; we cannot guarantee that they are the optimal ranges for this model.

Then, with the source code of RecBole (you can download it from GitHub), you can run ``run_hyper.py`` to start tuning:

.. code:: bash

python run_hyper.py --model=[model_name] --dataset=[dataset_name] --config_files=[config_files_path] --params_file=hyper.test
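
For example, for this model on ml-100k (assuming the default configuration, in which case ``--config_files`` can be omitted):

.. code:: bash

    python run_hyper.py --model=ALS --dataset=ml-100k --params_file=hyper.test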

For more details about Parameter Tuning, refer to :doc:`../../../user_guide/usage/parameter_tuning`.


If you want to change parameters, dataset or evaluation settings, take a look at

- :doc:`../../../user_guide/config_settings`
- :doc:`../../../user_guide/data_intro`
- :doc:`../../../user_guide/train_eval_intro`
- :doc:`../../../user_guide/usage`
93 changes: 93 additions & 0 deletions recbole/model/general_recommender/als.py
@@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
# @Time : 2024/12/01
# @Author : Markus Hoefling
# @Email : [email protected]

r"""
ALS
################################################
Reference 1:
Hu, Y., Koren, Y., & Volinsky, C. (2008). "Collaborative Filtering for Implicit Feedback Datasets." In ICDM 2008.

Reference 2:
Frederickson, Ben, "Implicit 0.7.2", code: https://github.com/benfred/implicit, readthedocs: https://benfred.github.io/implicit/
"""

import numpy as np
import threadpoolctl
import torch
from implicit.als import AlternatingLeastSquares
from recbole.model.abstract_recommender import GeneralRecommender
from recbole.utils import InputType, ModelType
threadpoolctl.threadpool_limits(1, "blas")  # limit BLAS to one thread; implicit warns about thread oversubscription otherwise

class ALS(GeneralRecommender):
r"""
ALS is a matrix factorization model implemented using the Alternating Least Squares (ALS) method
from the `implicit` library (https://benfred.github.io/implicit/).
This model optimizes the embeddings through the Alternating Least Squares algorithm.
"""

input_type = InputType.POINTWISE
type = ModelType.GENERAL

def __init__(self, config, dataset):
super(ALS, self).__init__(config, dataset)

# load parameters info
self.embedding_size = config['embedding_size']
self.regularization = config['regularization']
self.alpha = config['alpha']
self.iterations = config['epochs']

# define model
self.model = AlternatingLeastSquares(
factors=self.embedding_size,
regularization=self.regularization,
alpha=self.alpha,
iterations=1, # iterations are done by the ALSTrainer via 'epochs'
use_cg=True,
calculate_training_loss=True,
num_threads=0,
random_state=42
)

# initialize embeddings
self.user_embeddings = np.random.rand(self.n_users, self.embedding_size)
self.item_embeddings = np.random.rand(self.n_items, self.embedding_size)

        # dummy torch parameter so RecBole's optimizer has something to initialize
        # (the real factors live in numpy arrays managed by implicit)
self.fake_parameter = torch.nn.Parameter(torch.zeros(1))

def get_user_embedding(self, user):
return torch.tensor(self.user_embeddings[user])

def get_item_embedding(self, item):
return torch.tensor(self.item_embeddings[item])

def forward(self, user, item):
user_e = self.get_user_embedding(user)
item_e = self.get_item_embedding(item)
return user_e, item_e

    def _callback(self, iteration, time, loss):
        # store the training loss reported by implicit after each fit iteration
        self._loss = loss

    def calculate_loss(self, interactions):
        # one call to fit() runs a single ALS iteration (iterations=1);
        # the epoch loop is driven by ALSTrainer via 'epochs'
        self.model.fit(interactions, show_progress=False, callback=self._callback)
self.user_embeddings = self.model.user_factors
self.item_embeddings = self.model.item_factors
return self._loss

def predict(self, interaction):
user = interaction[self.USER_ID]
item = interaction[self.ITEM_ID]
user_e, item_e = self.forward(user, item)
        # torch.dot only accepts 1-D tensors, but the interaction holds a batch of ids
        return torch.mul(user_e, item_e).sum(dim=-1)

def full_sort_predict(self, interaction):
user = interaction[self.USER_ID]
user_e = self.get_user_embedding(user)
all_item_e = torch.tensor(self.model.item_factors)
score = torch.matmul(user_e, all_item_e.transpose(0, 1))
return score.view(-1)
3 changes: 3 additions & 0 deletions recbole/properties/model/ALS.yaml
@@ -0,0 +1,3 @@
regularization: 0.01    # The regularization factor to use
embedding_size: 64      # The number of latent factors to compute
alpha: 1.0              # The weight to give to positive examples
123 changes: 122 additions & 1 deletion recbole/trainer/trainer.py
@@ -28,6 +28,7 @@
from torch.nn.utils.clip_grad import clip_grad_norm_
from tqdm import tqdm
import torch.cuda.amp as amp
import scipy.sparse as sp

from recbole.data.interaction import Interaction
from recbole.data.dataloader import FullSortEvalDataLoader
@@ -92,7 +93,6 @@ def sync_grad_loss(self):
sync_loss += torch.sum(params) * 0
return sync_loss


class Trainer(AbstractTrainer):
r"""The basic Trainer for basic training and evaluation strategies in recommender systems. This class defines common
functions for training and evaluation processes of most recommender system models, including fit(), evaluate(),
@@ -671,6 +671,127 @@ def _spilt_predict(self, interaction, batch_size):
result_list.append(result)
return torch.cat(result_list, dim=0)

class ALSTrainer(Trainer):
r"""ALSTrainer is designed for the ALS model of the implicit library: https://benfred.github.io/implicit"""

def __init__(self, config, model):
super(ALSTrainer, self).__init__(config, model)

def fit(
self,
train_data,
valid_data=None,
verbose=True,
saved=True,
show_progress=False,
callback_fn=None,
):
r"""Train the model based on the train data and the valid data.

Args:
train_data (DataLoader): the train data
valid_data (DataLoader, optional): the valid data, default: None.
If it's None, the early_stopping is invalid.
verbose (bool, optional): whether to write training and evaluation information to logger, default: True
saved (bool, optional): whether to save the model parameters, default: True
show_progress (bool): Show the progress of training epoch and evaluate epoch. Defaults to ``False``.
callback_fn (callable): Optional callback function executed at end of epoch.
Includes (epoch_idx, valid_score) input arguments.

Returns:
(float, dict): best valid score and best valid result. If valid_data is None, it returns (-1, None)
"""
if saved and self.start_epoch >= self.epochs:
self._save_checkpoint(-1, verbose=verbose)

self.eval_collector.data_collect(train_data)
if self.config["train_neg_sample_args"].get("dynamic", False):
train_data.get_model(self.model)
valid_step = 0

for epoch_idx in range(self.start_epoch, self.epochs):
# train
training_start_time = time()
            # implicit's fit() expects the full user-item interaction matrix in CSR form
            # (https://benfred.github.io/implicit), so pass the entire training set at once
            train_loss = self.model.calculate_loss(train_data._dataset.inter_matrix(form='csr'))
self.train_loss_dict[epoch_idx] = (
sum(train_loss) if isinstance(train_loss, tuple) else train_loss
)
training_end_time = time()
train_loss_output = self._generate_train_loss_output(
epoch_idx, training_start_time, training_end_time, train_loss
)
if verbose:
self.logger.info(train_loss_output)
self._add_train_loss_to_tensorboard(epoch_idx, train_loss)
self.wandblogger.log_metrics(
{"epoch": epoch_idx, "train_loss": train_loss, "train_step": epoch_idx},
head="train",
)

# eval
if self.eval_step <= 0 or not valid_data:
if saved:
self._save_checkpoint(epoch_idx, verbose=verbose)
continue
if (epoch_idx + 1) % self.eval_step == 0:
valid_start_time = time()
valid_score, valid_result = self._valid_epoch(
valid_data, show_progress=show_progress
)

(
self.best_valid_score,
self.cur_step,
stop_flag,
update_flag,
) = early_stopping(
valid_score,
self.best_valid_score,
self.cur_step,
max_step=self.stopping_step,
bigger=self.valid_metric_bigger,
)
valid_end_time = time()
valid_score_output = (
set_color("epoch %d evaluating", "green")
+ " ["
+ set_color("time", "blue")
+ ": %.2fs, "
+ set_color("valid_score", "blue")
+ ": %f]"
) % (epoch_idx, valid_end_time - valid_start_time, valid_score)
valid_result_output = (
set_color("valid result", "blue") + ": \n" + dict2str(valid_result)
)
if verbose:
self.logger.info(valid_score_output)
self.logger.info(valid_result_output)
                self.tensorboard.add_scalar("Valid_score", valid_score, epoch_idx)
self.wandblogger.log_metrics(
{**valid_result, "valid_step": valid_step}, head="valid"
)

if update_flag:
if saved:
self._save_checkpoint(epoch_idx, verbose=verbose)
self.best_valid_result = valid_result

if callback_fn:
callback_fn(epoch_idx, valid_score)

if stop_flag:
stop_output = "Finished training, best eval result in epoch %d" % (
epoch_idx - self.cur_step * self.eval_step
)
if verbose:
self.logger.info(stop_output)
break

valid_step += 1

self._add_hparam_to_tensorboard(self.best_valid_score)
return self.best_valid_score, self.best_valid_result

class KGTrainer(Trainer):
r"""KGTrainer is designed for Knowledge-aware recommendation methods. Some of these models need to train the