-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathworld.py
91 lines (75 loc) · 2.75 KB
/
world.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
'''
Created on Mar 1, 2020
Pytorch Implementation of LightGCN in
Xiangnan He et al. LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation
@author: Jianbai Ye ([email protected])
'''
import os
from os.path import join
import torch
from enum import Enum
from parse import parse_args
import multiprocessing
# Allow more than one OpenMP runtime to be loaded in the same process
# (KMP_DUPLICATE_LIB_OK is the Intel/LLVM OpenMP escape hatch); presumably
# needed because torch and another native library each ship their own copy.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
args = parse_args()
# Project root. Defaults to the original author's checkout; set the
# LIGHTGCN_ROOT environment variable to run from any other location without
# editing this file.
ROOT_PATH = os.environ.get("LIGHTGCN_ROOT", "/home/cx/PycharmProjects/LightGCN-PyTorch-master")
CODE_PATH = join(ROOT_PATH, 'code')
DATA_PATH = join(ROOT_PATH, 'data')
BOARD_PATH = join(CODE_PATH, 'runs')
FILE_PATH = join(CODE_PATH, 'checkpoints')
import sys
# Make modules placed under <ROOT>/code/sources importable.
sys.path.append(join(CODE_PATH, 'sources'))
# exist_ok=True already turns this into a no-op when the directory exists, so
# a separate os.path.exists() pre-check (which is also racy) is unnecessary.
os.makedirs(FILE_PATH, exist_ok=True)
# Names accepted for --dataset and --model; the run aborts early if the
# requested dataset/model is not one of these.
all_dataset = [
    'lastfm', 'gowalla', 'yelp2018', 'amazon-book', 'yelp_multiclass',
    'ml1m-2', 'ml1m-4', 'amazon', 'my-amazon-2', 'amazon-2', 'amazon-4',
    'yelp-2', 'yelp-4', 'amazon_book',
]
all_models = ['mf', 'lgn']

# Training/evaluation hyper-parameters. Most come straight from the parsed
# command-line arguments; the last two are fixed flags with no CLI option.
config = {
    'bpr_batch_size': args.bpr_batch,
    'latent_dim_rec': args.recdim,
    'method': args.method,
    'lightGCN_n_layers': args.layer,
    'dropout': args.dropout,
    'keep_prob': args.keepprob,
    'A_n_fold': args.a_fold,
    'test_u_batch_size': args.testbatch,
    'multicore': args.multicore,
    'lr': args.lr,
    'decay': args.decay,
    'pretrain': args.pretrain,
    'A_split': False,
    'bigdata': False,
}
import ast

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
GPU = torch.cuda.is_available()
device = torch.device('cuda' if GPU else "cpu")
# Half the logical CPUs, presumably used as a worker/process count elsewhere.
# cpu_count() // 2 is 0 on a single-CPU machine, which is not a valid worker
# count (e.g. multiprocessing.Pool(0) raises), so clamp it to at least 1.
CORES = max(1, multiprocessing.cpu_count() // 2)
seed = args.seed
dataset = args.dataset
model_name = args.model
method = args.method
# Fail fast on unsupported choices before any data or model is built.
if dataset not in all_dataset:
    raise NotImplementedError(f"Haven't supported {dataset} yet!, try {all_dataset}")
if model_name not in all_models:
    raise NotImplementedError(f"Haven't supported {model_name} yet!, try {all_models}")
TRAIN_epochs = args.epochs
LOAD = args.load
PATH = args.path
# --topks is given as a Python literal such as "[20]" or "[10, 20]".
# ast.literal_eval parses literals only, so an arbitrary command-line string
# is never executed (the previous eval() would run any expression).
topks = ast.literal_eval(args.topks)
tensorboard = args.tensorboard
comment = args.comment
# Hide FutureWarning messages process-wide (the original note attributes
# the noise to pandas).
from warnings import simplefilter
simplefilter("ignore", FutureWarning)
def cprint(words: str):
    """Print ``words`` highlighted as black text on a yellow background.

    The text is wrapped in ANSI escape codes; the trailing reset code
    restores the terminal's normal colours after the message.
    """
    highlight, reset = "\033[0;30;43m", "\033[0m"
    print(f"{highlight}{words}{reset}")
# ASCII-art banner (box-drawing characters; appears to spell "LGN" for
# LightGCN), stored as a raw string so its text is taken verbatim.
logo: str = r"""
██╗ ██████╗ ███╗ ██╗
██║ ██╔════╝ ████╗ ██║
██║ ██║ ███╗██╔██╗ ██║
██║ ██║ ██║██║╚██╗██║
███████╗╚██████╔╝██║ ╚████║
╚══════╝ ╚═════╝ ╚═╝ ╚═══╝
"""
# font: ANSI Shadow
# Apparently generated with patorjk.com's TAAG tool; the link opens it with the
# ANSI Shadow font preselected (its sample text "Sampling" is not this banner).
# refer to http://patorjk.com/software/taag/#p=display&f=ANSI%20Shadow&t=Sampling
# The banner is not printed at import time; uncomment the next line to show it.
# print(logo)