-
Notifications
You must be signed in to change notification settings - Fork 3
/
main_simple.py
78 lines (66 loc) · 3.31 KB
/
main_simple.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from trainer_simple import Trainer
from tester import Tester
from dataset import Dataset
import argparse
import time
def get_parameter(argv=None):
    """Parse the command-line options for training and testing the simple model.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default), argparse falls back to ``sys.argv[1:]``, exactly as the
            original zero-argument call did; passing an explicit list makes
            the parser usable from tests or other scripts.

    Returns:
        argparse.Namespace with the attributes ``ne``, ``D_lr``, ``G_lr``,
        ``reg_lambda``, ``dataset``, ``emb_dim``, ``neg_ratio``,
        ``batch_size``, ``save_each`` and ``discriminator_range``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-ne', default=1000, type=int, help="number of epochs")
    parser.add_argument('-D_lr', default=0.1, type=float, help="discriminator learning rate")
    parser.add_argument('-G_lr', default=0.001, type=float, help="generator learning rate")
    parser.add_argument('-reg_lambda', default=0.03, type=float, help="l2 regularization parameter")
    # The default is "Deepddi", so the help text must not claim a wordnet dataset.
    parser.add_argument('-dataset', default="Deepddi", type=str, help="dataset name")
    parser.add_argument('-emb_dim', default=200, type=int, help="embedding dimension")
    parser.add_argument('-neg_ratio', default=1, type=int, help="number of negative examples per positive example")
    parser.add_argument('-batch_size', default=512, type=int, help="batch size")
    parser.add_argument('-save_each', default=100, type=int, help="validate every k epochs")
    parser.add_argument('-discriminator_range', default=1, type=int, help="discriminator_range")
    return parser.parse_args(argv)
if __name__ == '__main__':
args = get_parameter()
dataset = Dataset(args.dataset)
print("~~~~ Training ~~~~")
print('aae_' + 'simple')
trainer = Trainer(dataset, args)
trainer.train()
print("~~~~ Testing on the 100 epoch ~~~~")
model_path = "models/" + args.dataset + "/" + 'simple' + "/" + '100' + ".chkpnt"
tester = Tester(dataset, model_path, "test")
tester.test()
print("~~~~ Testing on the 200 epoch ~~~~")
model_path = "models/" + args.dataset + "/" + 'simple' + "/" + '200' + ".chkpnt"
tester = Tester(dataset, model_path, "test")
tester.test()
print("~~~~ Testing on the 300 epoch ~~~~")
model_path = "models/" + args.dataset + "/" + 'simple' + "/" + '300' + ".chkpnt"
tester = Tester(dataset, model_path, "test")
tester.test()
print("~~~~ Testing on the 400 epoch ~~~~")
model_path = "models/" + args.dataset + "/" + 'simple' + "/" + '400' + ".chkpnt"
tester = Tester(dataset, model_path, "test")
tester.test()
print("~~~~ Testing on the 500 epoch ~~~~")
model_path = "models/" + args.dataset + "/" + 'simple' + "/" + '500' + ".chkpnt"
tester = Tester(dataset, model_path, "test")
tester.test()
print("~~~~ Testing on the 600 epoch ~~~~")
model_path = "models/" + args.dataset + "/" + 'simple' + "/" + '600' + ".chkpnt"
tester = Tester(dataset, model_path, "test")
tester.test()
print("~~~~ Testing on the 700 epoch ~~~~")
model_path = "models/" + args.dataset + "/" + 'simple' + "/" + '700' + ".chkpnt"
tester = Tester(dataset, model_path, "test")
tester.test()
print("~~~~ Testing on the 800 epoch ~~~~")
model_path = "models/" + args.dataset + "/" + 'simple' + "/" + '800' + ".chkpnt"
tester = Tester(dataset, model_path, "test")
tester.test()
print("~~~~ Testing on the 900 epoch ~~~~")
model_path = "models/" + args.dataset + "/" + 'simple' + "/" + '900' + ".chkpnt"
tester = Tester(dataset, model_path, "test")
tester.test()
print("~~~~ Testing on the 1000 epoch ~~~~")
model_path = "models/" + args.dataset + "/" + 'simple' + "/" + '1000' + ".chkpnt"
tester = Tester(dataset, model_path, "test")
tester.test()