# base_flow.py (forked from BUPT-GAMMA/OpenHGNN)

import os
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from ..tasks import build_task
from ..layers.HeteroLinear import HeteroFeature
from ..utils import get_nodes_dict


class BaseFlow(ABC):
candidate_optimizer = {
'Adam': torch.optim.Adam,
'SGD': torch.optim.SGD,
'Adadelta': torch.optim.Adadelta
}
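
    # A concrete flow typically resolves its optimizer from this table, e.g.
    # (a sketch; ``args.optimizer``, ``args.lr`` and ``args.weight_decay`` are
    # assumed to be present on the config):
    #
    #   self.optimizer = self.candidate_optimizer[args.optimizer](
    #       self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)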

    def __init__(self, args):
        """
        Parameters
        ----------
        args
            The configuration object for the trainer flow.

        Attributes
        ----------
        evaluate_interval : int
            Interval (in epochs) at which the model is evaluated on the validation set.
        """
super(BaseFlow, self).__init__()
self.evaluator = None
self.evaluate_interval = 1
        if hasattr(args, '_checkpoint'):
            self._checkpoint = os.path.join(args._checkpoint,
                                            f"{args.model_name}_{args.dataset_name}.pt")
        elif hasattr(args, 'load_from_pretrained'):
            self._checkpoint = os.path.join(args.output_dir,
                                            f"{args.model_name}_{args.dataset_name}_{args.task}.pt")
        else:
            self._checkpoint = None

        # HGB dataset names look like 'HGBn-ACM'; [5:] strips the 'HGBn-' prefix.
        if not hasattr(args, 'HGB_results_path') and args.dataset_name[:3] == 'HGB':
            args.HGB_results_path = os.path.join(
                args.output_dir,
                f"{args.model_name}_{args.dataset_name[5:]}_{args.seed}.txt")
self.args = args
self.logger = self.args.logger
self.model_name = args.model_name
self.model = args.model
self.device = args.device
self.task = build_task(args)
self.hg = self.task.get_graph().to(self.device)
self.args.meta_paths_dict = self.task.dataset.meta_paths_dict
self.patience = args.patience
self.max_epoch = args.max_epoch
self.optimizer = None
self.loss_fn = self.task.get_loss_fn()
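
    # Typical lifecycle (inferred from this class; concrete flows may differ):
    # __init__ builds the task and graph, preprocess() wires up the features,
    # train() runs the loop, and save_checkpoint()/load_from_pretrained()
    # handle persistence.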

    def preprocess(self):
        r"""
        Every trainer flow should call ``preprocess`` if it needs feature preprocessing:
        the parameters of ``input_feature`` are added to the optimizer, and
        ``input_feature`` is registered as a submodule of the model.

        Attributes
        ----------
        input_feature : HeteroFeature
            A callable module; calling it returns the processed node features.
        """
        if hasattr(self.args, 'activation'):
            if hasattr(self.args.activation, 'weight'):
                # Parameterized activations such as PReLU carry learnable weights,
                # so instantiate a fresh module rather than sharing one instance.
                act = nn.PReLU()
            else:
                act = self.args.activation
        else:
            act = None
        # Feature handling mode; default is 0 (use the raw node features).
        if not hasattr(self.args, 'feat'):
            self.args.feat = 0
self.feature_preprocess(act)
self.optimizer.add_param_group({'params': self.input_feature.parameters()})
# for early stop, load the model with input_feature module.
self.model.add_module('input_feature', self.input_feature)
self.load_from_pretrained()
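
    # Once ``preprocess`` has run, a training step can fetch the processed
    # features like this (a sketch; ``self.category`` is set by task-specific flows):
    #
    #   h_dict = self.model.input_feature()              # {ntype: tensor}
    #   logits = self.model(self.hg, h_dict)[self.category]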

    def feature_preprocess(self, act):
        """
        Initialize ``self.input_feature`` according to ``self.args.feat``:

        0. Use the raw node features; node types without features get learnable embeddings.
        1. Preserve only the target node type's features (node classification only);
           every other node type gets learnable embeddings.
        2. Drop all node features and learn embeddings for every node type.

        If the graph carries no features at all, embeddings are assigned regardless of ``feat``.
        """
        if self.hg.ndata.get('h', {}) == {} or self.args.feat == 2:
            if self.hg.ndata.get('h', {}) == {}:
                self.logger.feature_info('Assign embeddings as features, because hg.ndata is empty.')
            else:
                self.logger.feature_info('feat is 2, drop all node features!')
                self.hg.ndata.pop('h')
            self.input_feature = HeteroFeature({}, get_nodes_dict(self.hg), self.args.hidden_dim,
                                               act=act).to(self.device)
        elif self.args.feat == 0:
            self.input_feature = self.init_feature(act)
        elif self.args.feat == 1:
            if self.args.task != 'node_classification':
                self.logger.feature_info('feat is 1, which is only for the node classification task; '
                                         'falling back to feat 0.')
                self.input_feature = self.init_feature(act)
            else:
                # self.category (the target node type) is set by task-specific subclasses.
                h_dict = self.hg.ndata.pop('h')
                self.logger.feature_info('feat is 1, preserve only the target nodes\' features!')
                self.input_feature = HeteroFeature({self.category: h_dict[self.category]},
                                                   get_nodes_dict(self.hg), self.args.hidden_dim,
                                                   act=act).to(self.device)

    def init_feature(self, act):
        self.logger.feature_info("feat is 0, use the raw node features!")
        if isinstance(self.hg.ndata['h'], dict):
            # The heterogeneous graph contains more than one node type.
            input_feature = HeteroFeature(self.hg.ndata['h'], get_nodes_dict(self.hg),
                                          self.args.hidden_dim, act=act).to(self.device)
        elif isinstance(self.hg.ndata['h'], torch.Tensor):
            # The heterogeneous graph contains only one node type.
            input_feature = HeteroFeature({self.hg.ntypes[0]: self.hg.ndata['h']}, get_nodes_dict(self.hg),
                                          self.args.hidden_dim, act=act).to(self.device)
        else:
            raise TypeError("Unexpected type of hg.ndata['h']: {}".format(type(self.hg.ndata['h'])))
        return input_feature
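
    # Role of ``HeteroFeature`` here (a sketch, not the layer's full contract):
    # given a {ntype: tensor} feature dict, the node-count dict and ``hidden_dim``,
    # it projects existing features to ``hidden_dim`` and learns embeddings for
    # featureless node types; calling ``self.input_feature()`` yields the result.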

    @abstractmethod
    def train(self):
        pass

    def _full_train_step(self):
        r"""
        Train on the full graph in a single batch.
        """
        raise NotImplementedError

    def _mini_train_step(self):
        r"""
        Train on mini-batches of subgraphs sampled from seed nodes.
        """
        raise NotImplementedError

    def _full_test_step(self):
        r"""
        Test on the full graph in a single batch.
        """
        raise NotImplementedError

    def _mini_test_step(self):
        r"""
        Test on mini-batches of subgraphs sampled from seed nodes.
        """
        raise NotImplementedError

    def load_from_pretrained(self):
        if hasattr(self.args, 'load_from_pretrained') and self.args.load_from_pretrained:
            try:
                # map_location keeps checkpoints loadable across devices.
                ck_pt = torch.load(self._checkpoint, map_location=self.device)
                self.model.load_state_dict(ck_pt)
                self.logger.info('[Load Model] Load model from pretrained model: ' + self._checkpoint)
            except FileNotFoundError:
                self.logger.info('[Load Model] Do not load the model from pretrained, '
                                 '{} does not exist.'.format(self._checkpoint))

    def save_checkpoint(self):
        if self._checkpoint and hasattr(self.model, "_parameters"):
            torch.save(self.model.state_dict(), self._checkpoint)
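

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of this module's API): a
# concrete trainer flow subclasses BaseFlow, wires up the optimizer and
# feature preprocessing, and implements ``train``. Names such as
# ``self.category``, ``self.train_idx`` and ``self.labels`` are assumptions
# here; real flows obtain them from their task and build ``self.model`` as an
# ``nn.Module`` before creating the optimizer.
#
# class ExampleNodeClassification(BaseFlow):
#     def __init__(self, args):
#         super().__init__(args)
#         self.category = self.task.dataset.category          # target node type
#         self.train_idx, self.valid_idx, self.test_idx = self.task.get_split()
#         self.labels = self.task.get_labels().to(self.device)
#         self.optimizer = self.candidate_optimizer[args.optimizer](
#             self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
#         self.preprocess()
#
#     def train(self):
#         for epoch in range(self.max_epoch):
#             loss = self._full_train_step()
#             if epoch % self.evaluate_interval == 0:
#                 self.logger.info('Epoch {}, loss {:.4f}'.format(epoch, loss))
#
#     def _full_train_step(self):
#         self.model.train()
#         h_dict = self.model.input_feature()                 # processed features
#         logits = self.model(self.hg, h_dict)[self.category]
#         loss = self.loss_fn(logits[self.train_idx], self.labels[self.train_idx])
#         self.optimizer.zero_grad()
#         loss.backward()
#         self.optimizer.step()
#         return loss.item()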