Skip to content

Commit

Permalink
Merge branch 'clingingsai-gss_local' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
lazishu2000 committed Jul 19, 2023
2 parents b440998 + eb8dbd0 commit e57edde
Show file tree
Hide file tree
Showing 11 changed files with 750 additions and 5 deletions.
12 changes: 11 additions & 1 deletion openhgnn/config.ini
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,8 @@ n_neighbor = 8
aggregate = SUM
n_relation = 60
n_user = 1872
epoch_iter = 100
# epoch_iter = 100
max_epoch = 100
mini_batch_flag = True

[HeGAN]
Expand Down Expand Up @@ -877,3 +878,12 @@ ssl_beta = 0.32
rank = 3
Layers = 2
reg = 0.043
[lightGCN]
lr = 0.001
weight_decay = 0.0001
max_epoch = 1000
batch_size = 1024
embedding_size = 64
num_layers = 3
test_u_batch_size = 100
topks = 20
16 changes: 14 additions & 2 deletions openhgnn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,8 @@ def __init__(self, file_path, model, dataset, task, gpu):
self.aggregate = conf.get("KGCN", "aggregate")
self.n_item = conf.getint("KGCN", "n_relation")
self.n_user = conf.getint("KGCN", "n_user")
self.epoch_iter = conf.getint("KGCN", "epoch_iter")
# self.epoch_iter = conf.getint("KGCN", "epoch_iter")
self.max_epoch = conf.getint("KGCN", "max_epoch")

elif self.model_name == 'general_HGNN':
self.lr = conf.getfloat("general_HGNN", "lr")
Expand Down Expand Up @@ -803,7 +804,7 @@ def __init__(self, file_path, model, dataset, task, gpu):
# self.use_norm = conf.get("DiffMG", "use_norm")
# self.out_nl = conf.get("DiffMG", "out_nl")

elif model == 'MeiREC':
elif self.model_name == 'MeiREC':
self.lr = conf.getfloat("MeiREC", "lr")
self.weight_decay = conf.getfloat("MeiREC", "weight_decay")
self.vocab = conf.getint("MeiREC", "vocab_size")
Expand Down Expand Up @@ -909,6 +910,17 @@ def __init__(self, file_path, model, dataset, task, gpu):
self.rank = conf.getint("HGCL", "rank")
self.Layers = conf.getint("HGCL", "Layers")

elif self.model_name == 'lightGCN':
self.lr = conf.getfloat("lightGCN", "lr")
self.weight_decay = conf.getfloat("lightGCN", "weight_decay")
self.max_epoch = conf.getint("lightGCN", "max_epoch")
self.batch_size = conf.getint("lightGCN", "batch_size")
self.embedding_size = conf.getint("lightGCN", "embedding_size")
self.num_layers = conf.getint("lightGCN", "num_layers")
self.test_u_batch_size = conf.getint("lightGCN", "test_u_batch_size")
self.topks = conf.getint("lightGCN", "topks")
# self.alpha = conf.getfloat("lightGCN", "alpha")

if hasattr(self, 'device'):
self.device = th.device(self.device)
elif gpu == -1:
Expand Down
178 changes: 178 additions & 0 deletions openhgnn/dataset/RecommendationDataset.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os
import dgl
import torch as th
import numpy as np
from . import BaseDataset, register_dataset
from dgl.data.utils import download,load_graphs,save_graphs,save_info,load_info
from scipy.sparse import csr_matrix
import scipy.sparse as sp
from .multigraph import MultiGraphDataset
from ..sampler.negative_sampler import Uniform_exclusive
from . import AcademicDataset
Expand Down Expand Up @@ -72,9 +75,184 @@ def get_split(self, validation=True):
def get_train_data(self):
pass

<<<<<<< HEAD
def get_labels(self):
return self.label

=======
@register_dataset('lightGCN_recommendation')
class lightGCN_Recommendation(RecommendationDataset):

def __init__(self, dataset_name, *args, **kwargs):
super(RecommendationDataset, self).__init__(*args, **kwargs)

if dataset_name not in ['gowalla','yelp2018','amazon-book']:
raise KeyError('Dataset {} is not supported!'.format(dataset_name))
self.dataset_name=dataset_name

self.data_path=f'openhgnn/dataset/{self.dataset_name}'

if not os.path.exists(f"{self.data_path}/train.txt"):
self.download()

# test
self.mode_dict = {'train': 0, "test": 1}
self.mode = self.mode_dict['train']
self.n_user = 0
self.m_item = 0
path = './openhgnn/dataset/' + dataset_name
train_file = path + '/train.txt'
test_file = path + '/test.txt'
self.path = path
trainUniqueUsers, trainItem, trainUser = [], [], []
testUniqueUsers, testItem, testUser = [], [], []
self.traindataSize = 0
self.testDataSize = 0

with open(train_file) as f:
for l in f.readlines():
if len(l) > 0:
l = l.strip('\n').split(' ')
items = [int(i) for i in l[1:]]
uid = int(l[0])
trainUniqueUsers.append(uid)
trainUser.extend([uid] * len(items))
trainItem.extend(items)

self.m_item = max(self.m_item, max(items))
self.n_user = max(self.n_user, uid)
self.traindataSize += len(items)
self.trainUniqueUsers = np.array(trainUniqueUsers)
self.trainUser = np.array(trainUser)
self.trainItem = np.array(trainItem)

with open(test_file) as f:
for l in f.readlines():
if len(l) > 0:
l = l.strip('\n').split(' ')
items = [int(i) for i in l[1:]]
uid = int(l[0])
testUniqueUsers.append(uid)
testUser.extend([uid] * len(items))
testItem.extend(items)
self.m_item = max(self.m_item, max(items))
self.n_user = max(self.n_user, uid)
self.testDataSize += len(items)
self.m_item += 1
self.n_user += 1
self.testUniqueUsers = np.array(testUniqueUsers)
self.testUser = np.array(testUser)
self.testItem = np.array(testItem)

self.Graph = None

# (users,items), bipartite graph
self.UserItemNet = csr_matrix((np.ones(len(self.trainUser)), (self.trainUser, self.trainItem)),
shape=(self.n_user, self.m_item))
self.users_D = np.array(self.UserItemNet.sum(axis=1)).squeeze()
self.users_D[self.users_D == 0.] = 1
self.items_D = np.array(self.UserItemNet.sum(axis=0)).squeeze()
self.items_D[self.items_D == 0.] = 1.
# pre-calculate
self.allPos = self.getUserPosItems(list(range(self.n_user)))
self.testDict = self.__build_test()

self.g = self.getSparseGraph()

def get_split(self):
return self.g, [], []

def __build_test(self):
"""
return:
dict: {user: [items]}
"""
test_data = {}
for i, item in enumerate(self.testItem):
user = self.testUser[i]
if test_data.get(user):
test_data[user].append(item)
else:
test_data[user] = [item]
return test_data

def getUserPosItems(self, users):
posItems = []
for user in users:
posItems.append(self.UserItemNet[user].nonzero()[1])
return posItems

def _convert_sp_mat_to_sp_tensor(self, X):
coo = X.tocoo().astype(np.float32)
row = th.Tensor(coo.row).long()
col = th.Tensor(coo.col).long()
index = th.stack([row, col])
data = th.FloatTensor(coo.data)
return th.sparse.FloatTensor(index, data, th.Size(coo.shape))

def getSparseGraph(self):
print("loading adjacency matrix")
if self.Graph is None:
try:
pre_adj_mat = sp.load_npz(self.path + '/s_pre_adj_mat.npz')
print("successfully loaded...")
norm_adj = pre_adj_mat
except:
print("generating adjacency matrix")
# s = time()
adj_mat = sp.dok_matrix((self.n_user + self.m_item, self.n_user + self.m_item), dtype=np.float32)
adj_mat = adj_mat.tolil()
R = self.UserItemNet.tolil()
adj_mat[:self.n_user, self.n_user:] = R
adj_mat[self.n_user:, :self.n_user] = R.T
adj_mat = adj_mat.todok()
# adj_mat = adj_mat + sp.eye(adj_mat.shape[0])

rowsum = np.array(adj_mat.sum(axis=1))
d_inv = np.power(rowsum, -0.5).flatten()
d_inv[np.isinf(d_inv)] = 0.
d_mat = sp.diags(d_inv)

norm_adj = d_mat.dot(adj_mat)
norm_adj = norm_adj.dot(d_mat)
norm_adj = norm_adj.tocsr()
# end = time()
# print(f"costing {end - s}s, saved norm_mat...")
sp.save_npz(self.path + '/s_pre_adj_mat.npz', norm_adj)

# if self.split == True:
# self.Graph = self._split_A_hat(norm_adj)
# print("done split matrix")
# else:
self.Graph = self._convert_sp_mat_to_sp_tensor(norm_adj)
# self.Graph = self.Graph.coalesce().to(self.device)
self.Graph = self.Graph.coalesce()
print("don't split the matrix")
return self.Graph

def download(self):
prefix = 'https://raw.githubusercontent.com/gusye1234/LightGCN-PyTorch/master/data'

required_file = ['train.txt', 'test.txt']

for filename in required_file:
url = f"{prefix}/{self.dataset_name}/{filename}"
file_path = f"{self.data_path}/{filename}"
if not os.path.exists(file_path):
try:
download(url, file_path)

except BaseException as e:
print("\n",e)
print("\nNote! --- If you want to download the file, vpn is required ---")
print("If you don't have a vpn, please download the dataset from here: https://github.com/gusye1234/LightGCN-PyTorch")
print("\nAfter downloading the dataset, you need to store the files in the following path: ")
print(f"{os.getcwd()}\openhgnn\dataset\{self.dataset_name}\\train.txt")
print(f"{os.getcwd()}\openhgnn\dataset\{self.dataset_name}\\test.txt")
exit()


>>>>>>> 6a501a7d28f6ee992cc1de3a1e28ba5289513609
@register_dataset('hin_recommendation')
class HINRecommendation(RecommendationDataset):
def __init__(self, dataset_name, *args, **kwargs):
Expand Down
2 changes: 2 additions & 0 deletions openhgnn/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ def build_dataset(dataset, task, *args, **kwargs):
_dataset = 'kg_link_prediction'
elif dataset in ['LastFM4KGCN']:
_dataset = 'kgcn_recommendation'
elif dataset in ['gowalla', 'yelp2018', 'amazon-book']:
_dataset = 'lightGCN_recommendation'
elif dataset in ['yelp4rec']:
_dataset = 'hin_' + task
elif dataset in ['Epinions', 'CiaoDVD', 'Yelp']:
Expand Down
1 change: 1 addition & 0 deletions openhgnn/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class Experiment(object):
'KGAT': 'KGAT_trainer'
'SHGP': 'SHGP_trainer'
'HGCL': 'hgcltrainer',
'lightGCN': 'lightGCN_trainer',
}
immutable_params = ['model', 'dataset', 'task']

Expand Down
9 changes: 9 additions & 0 deletions openhgnn/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def build_model_from_args(args, hg):
'SHGP': 'openhgnn.models.ATT_HGCN',
'DSSL': 'openhgnn.models.DSSL'
'HGCL': 'openhgnn.models.HGCL'
'lightGCN': 'openhgnn.models.lightGCN',
}

from .HGCL import HGCL
Expand Down Expand Up @@ -149,8 +150,12 @@ def build_model_from_args(args, hg):
from .DiffMG import DiffMG
from .MeiREC import MeiREC
from .HGNN_AC import HGNN_AC
<<<<<<< HEAD
from .KGAT import KGAT
from .DSSL import DSSL
=======
from .lightGCN import lightGCN
>>>>>>> 6a501a7d28f6ee992cc1de3a1e28ba5289513609

__all__ = [
'BaseModel',
Expand Down Expand Up @@ -186,6 +191,7 @@ def build_model_from_args(args, hg):
'DiffMG',
'MeiREC',
<<<<<<< HEAD
<<<<<<< HEAD
<<<<<<< HEAD
'KGAT'
=======
Expand All @@ -195,5 +201,8 @@ def build_model_from_args(args, hg):
'KGAT',
'DSSL'
>>>>>>> e1d95c140eccfdb60975128ed8168dfc5ca6ec1f
=======
'lightGCN',
>>>>>>> 6a501a7d28f6ee992cc1de3a1e28ba5289513609
]
classes = __all__
Loading

0 comments on commit e57edde

Please sign in to comment.