train_graph.py
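"""Train a GCN graph classifier and evaluate it on a held-out test split."""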

import json

import numpy as np
import torch
import torch.nn as nn
import torch_geometric.transforms as T

from datasets.load_datasets import get_dataset, get_dataloader
from models import GCN


def evaluate(dataloader, model, loss_fc):
    """Return mean accuracy and mean loss over a dataloader."""
    acc = []
    loss_list = []
    model.eval()
    with torch.no_grad():
        for data in dataloader:
            logit = model(data)
            loss = loss_fc(logit, data.y)
            prediction = torch.argmax(logit, -1)
            loss_list.append(loss.item())
            # .numpy() assumes CPU tensors, as in the rest of this script
            acc.append((prediction == data.y).numpy())
    return np.concatenate(acc, axis=0).mean(), np.average(loss_list)
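

# Entry point: read the dataset name from configs.json, build the dataset
# and model, then train with early stopping on the eval split.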
if __name__ == '__main__':
    with open("configs.json") as config_file:
        configs = json.load(config_file)
    dataset_name = configs.get("dataset_name").get("graph")

    # Hyperparameters and per-dataset pooling choices
    epochs = 5000
    pooling = {'mutagenicity': ['max', 'mean', 'sum'],
               'ba_2motifs': ['max'],
               'bbbp': ['max', 'mean', 'sum']}
    early_stop = 100  # patience, in epochs
    loop = True  # flag forwarded to GCN (disabled for ba_2motifs)
    if dataset_name == 'ba_2motifs':
        loop = False

    # Load the dataset, normalize node features, and split train/eval/test
    normalize = T.NormalizeFeatures()
    dataset = get_dataset(dataset_dir="./datasets", dataset_name=dataset_name)
    dataset.data.x = dataset.data.x.float()
    dataset.data = normalize(dataset.data)
    data_loader = get_dataloader(dataset, batch_size=32, random_split_flag=True,
                                 data_split_ratio=[0.8, 0.1, 0.1], seed=2)

    model = GCN(n_feat=dataset.num_node_features,
                n_hidden=20,
                n_class=dataset.num_classes,
                pooling=pooling[dataset_name],
                loop=loop)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)
    loss_fc = nn.CrossEntropyLoss()
    model_file = './src/' + dataset_name + '.pt'  # checkpoint path

    early_stop_count = 0
    best_acc = 0
    best_loss = float('inf')
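    # An epoch is "best" if it improves eval accuracy, or ties it with a
    # lower eval loss; stop after `early_stop` epochs without improvement.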
    for epoch in range(epochs):
        acc = []
        loss_list = []
        model.train()
        for data in data_loader['train']:
            logit = model(data)
            loss = loss_fc(logit, data.y)
            optimizer.zero_grad()
            loss.backward()
            # clip gradients to stabilize training
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
            optimizer.step()
            prediction = torch.argmax(logit, -1)
            loss_list.append(loss.item())
            # training metrics (collected but not used for model selection)
            acc.append((prediction == data.y).numpy())

        eval_acc, eval_loss = evaluate(dataloader=data_loader['eval'],
                                       model=model, loss_fc=loss_fc)
        print(epoch, eval_acc, eval_loss)

        is_best = (eval_acc > best_acc) or \
                  (eval_loss < best_loss and eval_acc >= best_acc)
        if is_best:
            best_acc = eval_acc
            best_loss = eval_loss
            early_stop_count = 0
            model.save(model_file)  # keep only the best checkpoint
        else:
            early_stop_count += 1
            if early_stop_count > early_stop:
                break

    # Reload the best checkpoint and report test accuracy
    model.load(model_file)
    model.eval()
    test_acc, test_loss = evaluate(data_loader['test'], model, loss_fc)
    print(test_acc)