forked from BUPT-GAMMA/OpenHGNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhde_trainer.py
107 lines (98 loc) · 4.7 KB
/
hde_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch as th
from torch import nn
from tqdm import tqdm
import torch
from . import BaseFlow, register_flow
from ..models import build_model
from ..utils import extract_embed, EarlyStopping, get_nodes_dict
import torch.optim as optim
from sklearn.metrics import roc_auc_score
from ..sampler import hde_sampler
@register_flow("hde_trainer")
class hde_trainer(BaseFlow):
"""
HDE trainer flow.
Supported Model: HDE
Supported Dataset:imdb4hde
The trainerflow supports link prediction task. It first calculates HDE for every node in the graph,
and uses the HDE as a part of the initial feature for each node.
And then it performs standard message passing and link prediction operations.
Please notice that for different target node set, the HDE for each node can be different.
For more details, please refer to the original paper: http://www.shichuan.org/doc/116.pdf
"""
def __init__(self, args):
super(hde_trainer, self).__init__(args)
self.target_link = self.task.dataset.target_link
self.loss_fn = self.task.get_loss_fn()
self.args.out_node_type = self.task.dataset.out_ntypes
self.type2idx = {'A': 0, 'B': 1}
self.node_type = len(self.type2idx)
self.num_fea = (self.node_type * (args.max_dist + 1)) * 2 + self.node_type
self.sample_size = self.num_fea * (1 + args.num_neighbor + args.num_neighbor * args.num_neighbor)
self.args.patience = 10
args.input_dim = self.num_fea
args.output_dim = args.emb_dim // 2
self.model = build_model(self.model).build_model_from_args(self.args, self.hg).to(self.device)
# initialization
for m in self.model.modules():
if isinstance(m, (nn.Conv2d, nn.Linear)):
nn.init.kaiming_normal_(m.weight, mode='fan_in')
self.loss_fn = nn.CrossEntropyLoss().to(self.device)
self.optimizer = optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=1e-2)
self.evaluator = roc_auc_score
self.HDE_sampler = hde_sampler.HDESampler(self)
self.HDE_sampler.dgl2nx()
self.data_A_train, self.data_B_train, self.data_y_train, \
self.val_batch_A_fea, self.val_batch_B_fea, self.val_batch_y, \
self.test_batch_A_fea, self.test_batch_B_fea, self.test_batch_y = self.HDE_sampler.compute_hde(args)
def train(self):
epoch_iter = tqdm(range(self.max_epoch))
stopper = EarlyStopping(self.args.patience, self._checkpoint)
for epoch in tqdm(range(self.max_epoch), ncols=80):
loss = self._mini_train_step()
if epoch % 2 == 0:
val_metric = self._test_step('valid')
epoch_iter.set_description(
f"Epoch: {epoch:03d}, roc_auc: {val_metric:.4f}, Loss:{loss:.4f}"
)
early_stop = stopper.step_score(val_metric, self.model)
if early_stop:
print('Early Stop!\tEpoch:' + str(epoch))
break
print(f"Valid_score_ = {stopper.best_score: .4f}")
stopper.load_model(self.model)
test_auc = self._test_step(split="test")
val_auc = self._test_step(split="valid")
print(f"Test roc_auc = {test_auc:.4f}")
return dict(Test_mrr=test_auc, Val_mrr=val_auc)
def _mini_train_step(self,):
self.model.train()
all_loss = 0
for (train_batch_A_fea, train_batch_B_fea, train_batch_y) in zip(self.data_A_train, self.data_B_train, self.data_y_train):
# train
self.model.train()
train_batch_A_fea = torch.FloatTensor(train_batch_A_fea).to(self.device)
train_batch_B_fea = torch.FloatTensor(train_batch_B_fea).to(self.device)
train_batch_y = torch.LongTensor(train_batch_y).to(self.device)
logits = self.model(train_batch_A_fea, train_batch_B_fea)
loss = self.loss_fn(logits, train_batch_y.squeeze())
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
all_loss += loss.item()
return all_loss
def _test_step(self, split=None, logits=None):
self.model.eval()
with th.no_grad():
if split == 'valid':
data_A_eval = self.val_batch_A_fea
data_B_eval = self.val_batch_B_fea
data_y_eval = self.val_batch_y
elif split == 'test':
data_A_eval = self.test_batch_A_fea
data_B_eval = self.test_batch_B_fea
data_y_eval = self.test_batch_y
logits = self.model(data_A_eval, data_B_eval)
pred = logits.argmax(dim=1)
metric = self.evaluator(data_y_eval.cpu().numpy(), pred.cpu().numpy())
return metric