forked from ssu-dmlab/DGRec-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
34 lines (27 loc) · 845 Bytes
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os
import fire
import sys
import json
from dotmap import DotMap
from src.main_trainer import main
# Switch into the ``src`` directory that sits next to this script; the
# config path built in ``main_wrapper`` ('../hyperparameter/...') is written
# relative to it. Anchoring on this file's location rather than on the
# caller's working directory lets the script be launched from anywhere.
# The absolute path is computed before the chdir takes effect, so a relative
# ``__file__`` is still interpreted against the original working directory.
os.chdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src'))
def main_wrapper(data_name='bookdata'):
    """Load the hyperparameters for ``data_name`` and start a run of ``main``.

    Reads ``../hyperparameter/<data_name>/param.json``, resolved against the
    current working directory (this script changes into ``src`` at import
    time), and forwards the settings it contains to ``main`` as keyword
    arguments.

    Args:
        data_name: Name of the sub-directory under ``hyperparameter`` whose
            ``param.json`` is loaded.

    Raises:
        FileNotFoundError: If the parameter file does not exist.
        json.JSONDecodeError: If the parameter file is not valid JSON.
    """
    param_path = f'../hyperparameter/{data_name}/param.json'

    # JSON is UTF-8 by specification; naming the encoding keeps parsing
    # independent of the platform's locale default.
    with open(param_path, 'r', encoding='utf-8') as in_file:
        param = DotMap(json.load(in_file))

    # NOTE(review): DotMap appears to create missing keys on attribute access
    # instead of raising -- confirm every key below exists in param.json.
    # The ``data_name`` handed to ``main`` is the one stored in the file,
    # not the argument used to locate the file.
    main(model=param.model,
         data_name=param.data_name,
         seed=param.seed,
         epochs=param.epochs,
         act=param.act,
         batch_size=param.batch_size,
         learning_rate=param.learning_rate,
         embedding_size=param.embedding_size,
         max_length=param.max_length,
         samples_1=param.samples_1,
         samples_2=param.samples_2,
         dropout=param.dropout,
         decay_rate=param.decay_rate,
         )
if __name__ == "__main__":
    # Expose main_wrapper on the command line through Fire, then hand
    # whatever Fire returns to sys.exit as the process exit status.
    exit_status = fire.Fire(main_wrapper)
    sys.exit(exit_status)