forked from BUPT-GAMMA/OpenHGNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhetgnn_trainer.py
134 lines (112 loc) · 5.26 KB
/
hetgnn_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import dgl
import torch as th
from tqdm import tqdm
import torch.nn.functional as F
from torch.utils.data import DataLoader
from ..models import build_model
from . import BaseFlow, register_flow
from ..sampler.HetGNN_sampler import SkipGramBatchSampler, HetGNNCollator, NeighborSampler, hetgnn_graph
from ..utils import EarlyStopping
@register_flow("hetgnntrainer")
class HetGNNTrainer(BaseFlow):
    """SkipGram flows.

    Supported Model: HetGNN
    Supported Dataset: Academic4HetGNN

    Dataset description can be found in HetGNN paper.
    The trainerflow supports node classification and author link prediction.
    """

    def __init__(self, args):
        """Build the model and its Adam optimizer from ``args``.

        Parameters
        ----------
        args
            Configuration object. This method reads ``args.lr`` and
            ``args.weight_decay``; the remaining attributes are consumed by
            the base class and by the other methods of this trainer.
        """
        super().__init__(args)
        self.category = self.task.dataset.category
        # ``self.model`` is whatever the base class stored (passed to
        # ``build_model`` as-is) and is replaced here by the built model.
        self.model = build_model(self.model).build_model_from_args(self.args, self.hg).to(self.device)
        self.optimizer = th.optim.Adam(
            self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    def preprocess(self):
        """Prepare the sampling graph and the mini-batch data iterator.

        Only does work when ``args.mini_batch_flag`` is set and
        ``args.model == 'HetGNN'``; otherwise it is a no-op.  On that path it
        sets ``self.het_graph`` and ``self.dataloader_it`` and leaves both
        ``self.hg`` and ``self.het_graph`` on ``args.device``.
        """
        if self.args.mini_batch_flag:
            if self.args.model == 'HetGNN':
                hetg = hetgnn_graph(self.hg, self.args.dataset)
                # Graph construction and sampler setup are done on CPU; the
                # graphs are moved to the target device at the end.
                self.hg = self.hg.to('cpu')
                self.het_graph = hetg.get_hetgnn_graph(
                    self.args.rw_length, self.args.rw_walks, self.args.rwr_prob).to('cpu')

                batch_sampler = SkipGramBatchSampler(self.hg, self.args.batch_size, self.args.window_size)
                neighbor_sampler = NeighborSampler(
                    self.het_graph, self.hg.ntypes, batch_sampler.num_nodes, self.args.device)
                collator = HetGNNCollator(neighbor_sampler, self.hg)
                dataloader = DataLoader(
                    batch_sampler,
                    collate_fn=collator.collate_train,
                    num_workers=self.args.num_workers)
                # A single iterator is kept and advanced with ``next`` across
                # all epochs (see ``_mini_train_step``).
                self.dataloader_it = iter(dataloader)

                self.hg = self.hg.to(self.args.device)
                self.het_graph = self.het_graph.to(self.args.device)
        return

    def train(self):
        """Run the training loop with early stopping.

        Returns
        -------
        The result of ``_test_step()`` evaluated after the checkpointed
        model has been reloaded by the early stopper.
        """
        self.preprocess()
        stopper = EarlyStopping(self.args.patience, self._checkpoint)
        epoch_iter = tqdm(range(self.max_epoch))
        for epoch in epoch_iter:
            if self.args.mini_batch_flag:
                loss = self._mini_train_step()
            else:
                loss = self._full_train_setp()
            # Evaluation and the early-stopping check only happen every
            # ``evaluate_interval`` epochs.
            if epoch % self.evaluate_interval == 0:
                early_stop = stopper.loss_step(loss, self.model)
                metrics = self._test_step()
                self.logger.train_info(f"Epoch{epoch}, train loss:{loss:.4f}" + self.logger.metric2str(metrics))
                if early_stop:
                    self.logger.train_info(f"Early Stop!\tEpoch:{epoch}")
                    break
        stopper.load_model(self.model)
        metrics = self._test_step()
        return metrics

    def _full_train_setp(self):
        """Run one full-graph optimisation step and return the loss as a float.

        NOTE(review): within this class ``self.het_graph`` is only created by
        ``preprocess()`` on the mini-batch HetGNN path, and
        ``construct_negative_graph`` is not defined here -- confirm both are
        provided elsewhere before relying on the full-batch path.
        """
        self.model.train()
        negative_graph = self.construct_negative_graph()
        x = self.model(self.het_graph)[self.category]
        loss = self.loss_calculation(self.ScorePredictor(self.hg, x), self.ScorePredictor(negative_graph, x))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def _mini_train_step(self):
        """Train on ``args.batches_per_epoch`` mini-batches.

        Returns
        -------
        float
            Mean of the per-batch losses.
        """
        self.model.train()
        all_loss = 0
        for _ in range(self.args.batches_per_epoch):
            positive_graph, negative_graph, blocks = next(self.dataloader_it)
            blocks = [b.to(self.device) for b in blocks]
            positive_graph = positive_graph.to(self.device)
            negative_graph = negative_graph.to(self.device)
            # we need extract multi-feature
            x = self.model(blocks[0])
            loss = self.loss_calculation(
                self.ScorePredictor(positive_graph, x), self.ScorePredictor(negative_graph, x))
            all_loss += loss.item()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        return all_loss / self.args.batches_per_epoch

    def loss_calculation(self, pos_score, neg_score):
        """Negative-sampling log-sigmoid loss over positive and negative edges.

        Computes ``-mean(logsigmoid(pos) ++ logsigmoid(-neg))`` where ``++``
        is concatenation over every key of ``pos_score``.

        Parameters
        ----------
        pos_score, neg_score
            Either mappings from a key (the edge type, as produced by
            ``ScorePredictor``) to a score tensor, or bare score tensors.
            Every key of ``pos_score`` must also be present in ``neg_score``.

        Returns
        -------
        torch.Tensor
            Scalar loss.
        """
        # DGL hands back a bare tensor instead of a dict when the graph has a
        # single edge type; wrap it so the loop below treats both cases alike
        # (iterating a tensor here would index it by its own values).
        if th.is_tensor(pos_score):
            pos_score = {None: pos_score}
        if th.is_tensor(neg_score):
            neg_score = {None: neg_score}

        terms = []
        for key in pos_score:
            terms.append(F.logsigmoid(pos_score[key]))
            terms.append(F.logsigmoid(-neg_score[key]))
        return -th.cat(terms).mean()

    def ScorePredictor(self, edge_subgraph, x):
        """Score every edge of ``edge_subgraph`` by the dot product of its endpoints.

        Parameters
        ----------
        edge_subgraph
            DGL graph whose edges are to be scored.
        x
            Node features assigned to ``edge_subgraph.ndata['x']``.

        Returns
        -------
        The graph's ``edata['score']`` view, read before the local scope is
        left so the temporary features are still present.
        """
        with edge_subgraph.local_scope():
            edge_subgraph.ndata['x'] = x
            for etype in edge_subgraph.canonical_etypes:
                edge_subgraph.apply_edges(
                    dgl.function.u_dot_v('x', 'x', 'score'), etype=etype)
            return edge_subgraph.edata['score']

    def _test_step(self, logits=None):
        """Evaluate embeddings of ``self.category`` nodes on the downstream task.

        Parameters
        ----------
        logits : optional
            Precomputed model output indexable by ``self.category``.  When
            ``None`` it is computed from ``self.hg`` / ``self.het_graph``.

        Returns
        -------
        dict or None
            ``{'test': metric}`` for the ``'node_classification'`` and
            ``'link_prediction'`` tasks; ``None`` for any other task.
        """
        self.model.eval()
        with th.no_grad():
            # Explicit ``None`` test: truthiness would also discard an empty
            # container and is ambiguous for tensors.
            if logits is None:
                h = self.model.extract_feature(self.hg, self.hg.ntypes)
                logits = self.model(self.het_graph, h)
            logits = logits[self.category].to('cpu')

        if self.args.task == 'node_classification':
            mode = 'f1_lr'
        elif self.args.task == 'link_prediction':
            mode = 'academic_lp'
        else:
            return None
        metric = self.task.downstream_evaluate(logits, mode)
        return {'test': metric}