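"""tune.py: Optuna-based hyperparameter tuning entry point.

Builds a dataset and model from static and Optuna-suggested options, trains
one model per trial with early stopping and median pruning, and records
validation/test metrics in an SQLite-backed Optuna study.
"""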
import argparse
import optuna
import optuna.study
import os
import numpy as np
import random
import torch as th
import torch.nn.functional as F
import time
from opts.tune.public_hypers import public_hypers_default
from opts.tune.public_hypers import convert_dict_to_optuna_suggested
from opts.tune.public_static_settings import public_static_opts
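# The star imports expose the per-model `*_static_opts` / `*_opts` dicts that
# initialize_args() looks up via globals()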
from opts.tune.private_static_settings import *
from opts.tune.private_hypers import *
from utils.optuna_utils import _ckpt_fname
from utils.optuna_utils import _get_complete_and_pruned_trial_nums
from utils.optuna_utils import _pruneDuplicate, _CkptsAndHandlersClearerCallBack
from utils.exp_utils import _prepare_optuna_cache_dir, _record_tuning_opts_in_optuna_cache_dir
from utils.data_utils import build_dataset
from utils.model_utils import build_model, build_optimizers
from utils.random_utils import reset_random_seeds
from utils.grading_logger import _set_logger
from utils.model_utils import bce_with_logits_loss
from utils.stopper import EarlyStopping
from utils.rocauc_eval import fast_auc_th, acc
from torch_geometric.nn.conv.gcn_conv import gcn_norm


def evaluate(logits, labels, mask, evaluator):
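    """Apply `evaluator` to the masked logits and labels.

    If the model returned a tuple, its first element is taken as the logits.
    """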
if not th.is_tensor(logits):
logits = logits[0]
logits = logits[mask]
labels = labels[mask]
metric = evaluator(logits, labels)
return metric


def run(args, logger, trial,
# edge_index, data, norm_A, features, labels,
data,
model_seed
):
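    """Train the model for a single Optuna trial and return (val_acc, test_acc)."""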
dur = []
edge_index = data.edge_index
norm_A = data.norm_A
features = data.features
labels = data.labels
    # Split the dataset for this run
    if args.dataset in ['twitch-gamer', 'Penn94', 'genius', 'tolokers', 'minesweeper', 'roman-empire']:
        # These datasets come with fixed splits; use them as-is
        data.load_mask()
    else:
        # Otherwise draw random 60/20/20 train/val/test splits
        data.load_mask(p=(0.6, 0.2, 0.2))
reset_random_seeds(model_seed)
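    # These datasets are scored with BCE loss and ROC-AUC; the rest use
    # cross-entropy and accuracy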
if args.dataset in ['genius', 'minesweeper', 'tolokers', 'questions']:
loss_fcn = bce_with_logits_loss
evaluator = fast_auc_th
else:
loss_fcn = F.cross_entropy # input: logits (N, C) and labels (N,)
evaluator = acc
data.in_feats = features.shape[-1]
model = build_model(args,
edge_index,
edge_weights=norm_A,
in_feats=data.in_feats,
n_classes=data.n_classes
)
optimizers = build_optimizers(args, model)
if args.early_stop:
stopper = EarlyStopping(patience=args.patience, store_path=args.es_ckpt+'.pt')
stopper_step = stopper.step
for epoch in range(args.n_epochs):
t0 = time.time()
# train
model.train()
        for opt in optimizers:
            opt.zero_grad()
logits = model(features)
loss_train = loss_fcn(logits[data.train_mask], labels[data.train_mask])
loss_train.backward()
        for opt in optimizers:
            opt.step()
# eval
model.eval()
with th.no_grad():
logits = model(features)
loss_val = loss_fcn(logits[data.val_mask], labels[data.val_mask])
acc_val = evaluate(logits, labels, data.val_mask, evaluator)
trial.report(acc_val, epoch)
if trial.should_prune():
raise optuna.exceptions.TrialPruned()
acc_train = evaluate(logits, labels, data.train_mask, evaluator)
# end of the epoch
dur.append(time.time() - t0)
        if (epoch + 1) % 20 == 0:
            logger.debug("Epoch {:05d} | Time(s) {:.4f} | Val Loss {:.4f} | Val Acc {:.4f} | Train Acc {:.4f} | "
                         "ETputs(KTEPS) {:.2f}".format(epoch + 1, np.mean(dur), loss_val.item(),
                                                       acc_val, acc_train,
                                                       data.n_edges / np.mean(dur) / 1000))
        if args.early_stop:
            if stopper_step(acc_val, model):
                break
# end training
if args.early_stop:
model.load_state_dict(th.load(stopper.store_path))
        logger.debug('Loaded the model checkpoint saved by the early stopper')
# report the final model
model.eval()
with th.no_grad():
logits = model(features)
loss_val = loss_fcn(logits[data.val_mask], labels[data.val_mask])
loss_test = loss_fcn(logits[data.test_mask], labels[data.test_mask])
acc_val = evaluate(logits, labels, data.val_mask, evaluator)
acc_test = evaluate(logits, labels, data.test_mask, evaluator)
logger.warning("[FINAL MODEL] Val accuracy {:.2%} \Val loss: {:.2}".format(acc_val, loss_val))
logger.warning("[FINAL MODEL] Test accuracy {:.2%} \Test loss: {:.2}".format(acc_test, loss_test))
return acc_val, acc_test


def main(args, logger, trial):
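    """Build the dataset, normalize the graph, and run one tuning trial."""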
reset_random_seeds(args.seed)
data = build_dataset(args)
data.set_split_seeds()
model_seeds = [random.randint(0,10000)]
logger.info('Split_seeds:{:s}'.format(str(data.seeds)))
logger.info('Model_seeds:{:s}'.format(str(model_seeds)))
edge_index = data.edge_index
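    # Build edge weights: symmetric GCN normalization (D^-1/2 A D^-1/2, no
    # self-loops) or plain unit weights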
    if args.graph_norm == 'sym':
        _, data.norm_A = gcn_norm(edge_index, add_self_loops=False)
    elif args.graph_norm == 'none':
        data.norm_A = th.ones(data.n_edges, device=data.device)
    else:
        raise NotImplementedError("graph_norm '{}' (e.g. rescaled Laplacian) is not implemented".format(args.graph_norm))
    # During tuning, only the first cross-validation split is used
cv_id = 0
val_acc, test_acc = run(args, logger, trial,
# edge_index, data, norm_A, features, labels,
data,
model_seed=model_seeds[cv_id]
)
# logger.info("Acc on the first split (Validation Set): {:.4f}".format(val_acc))
# logger.info("Acc on the first split (Test Set): {:.4f}".format(test_acc))
return val_acc.item(), test_acc.item()


def initialize_args():
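    """Collect static options and build an Optuna hyperparameter suggestor."""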
# 1. Set static args: Define the static arguments shared across tasks and models
## 1.1 Static options shared by all tasks
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=42)
parser.add_argument("--model", type=str, default='OptBasisGNN')
parser.add_argument("--gpu", type=int, default=0)
parser.add_argument("--dataset", type=str, default="cora")
# `--lcc` option: put into `public_static_opts`, not here
# `--graph-norm` option: also put into `public_static_opts`, not here
## Log options
# parser.add_argument("--log-ignore-steps", action='store_true', default=False, help='Do not report intermediate results for each 50/20 epochs (default: False)')
# parser.add_argument("--logging", action='store_true', default=False, help="Enable logging to files (default: False)")
# parser.add_argument("--log-id", type=int, default=0, help="ID for log directory (default: 0)")
# parser.add_argument("--log-detailed-console", action='store_true', default=False, help="Enable detailed logging in the console (default: False)")
parser.add_argument("--file-logging", action='store_true', default=False, help="Enable logging to files (default: False)")
parser.add_argument("--file-log-id", type=int, default=0, help="ID for log directory (default: 0)")
parser.add_argument("--detailed-console-logging", action='store_true', default=False, help="Enable detailed logging in the console (default: False)")
## Tuning and training options
parser.add_argument("--optuna-n-trials", type=int, default=202)
parser.add_argument("--n-epochs", type=int, default=2000)
parser.add_argument("--study-kw", type=str, required=True, help="Keyword for the study; for better recording")
static_args = parser.parse_args()
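    # A negative --gpu value selects the CPU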
if static_args.gpu < 0:
static_args.gpu = 'cpu'
    ## 1.2 Additional static options shared by all tasks
    dargs = vars(static_args)
    dargs.update(public_static_opts)
    ## 1.3 Static options specific to the model
    if f'{static_args.model}_static_opts' in globals().keys():
        k = f'{static_args.model}_static_opts'
        dargs.update(globals()[k])
    # 2. Args to be tuned: hyperparameters searched by Optuna.
    # `convert_dict_to_optuna_suggested` wraps the dict assembled below into a
    # `suggestor` that draws one concrete set of hyperparameters per trial.
    # Most of these (e.g. learning rates and weight decays) are shared across models.
    ## 2.1 Public hyperparameter defaults
    to_tune = dict(public_hypers_default)  # copy, so the shared default dict is not mutated below
## 2.2 Private hyperparameters specific to the model
if f'{static_args.model}_opts' in globals().keys():
k = f'{static_args.model}_opts'
to_tune.update(globals()[k])
else:
model_ = static_args.model.split('_')[0]
if f'{model_}_opts' in globals().keys():
k = f'{model_}_opts'
to_tune.update(globals()[k])
# End of 2. : Get a suggestor!
suggestor = convert_dict_to_optuna_suggested(to_tune, static_args.model)
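    # `suggestor` maps an optuna Trial to a dict of concrete hyperparameter
    # values, e.g. {'lr': 0.01, 'weight_decay': 5e-4} (illustrative values; the
    # actual search space is defined under opts/tune/)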
return static_args, suggestor


def objective(trial):
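    """Optuna objective: merge static and suggested args, run one trial, and return val_acc."""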
# arguments
suggested_args = suggestor(trial)
    args = argparse.Namespace()  # empty namespace that collects all option sources
dargs = vars(args)
dargs.update(vars(static_args))
dargs.update(suggested_args)
dargs.update({'es_ckpt': _ckpt_fname(trial.study, trial)})
# logger
logger = _set_logger(args)
logger.info(args)
    # If the same set of params has already been tested, the trial is pruned:
    # _pruneDuplicate raises an exception that Optuna handles internally
_pruneDuplicate(trial)
    # run the trial
val_acc, test_acc = main(args, logger, trial)
trial.set_user_attr("val_acc", val_acc)
trial.set_user_attr("test_acc", test_acc)
return val_acc


if __name__ == '__main__':
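    # static_args and suggestor are module globals consumed by objective()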
static_args, suggestor = initialize_args()
# Create an optuna study
dataset = static_args.dataset
db_name = f'{static_args.model}-{dataset}'
dir_name, opts_rec_path = _prepare_optuna_cache_dir(static_args)
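    # The MedianPruner stops a trial whose intermediate val_acc falls below the
    # median of earlier trials at the same epoch (after the warmup steps)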
    study = optuna.create_study(
        study_name=db_name,
        direction="maximize",
        storage=optuna.storages.RDBStorage(
            url='sqlite:///{}/{}.db'.format(dir_name, db_name),
            engine_kwargs={"connect_args": {"timeout": 10000}}),
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=15,
                                           interval_steps=1, n_min_trials=5),
        load_if_exists=True
    )
study.set_user_attr('kw', static_args.study_kw)
# Run trials
n_trials = static_args.optuna_n_trials
num_completed, num_pruned = _get_complete_and_pruned_trial_nums(study)
rec_flag = False
while num_completed + num_pruned < n_trials:
print('=='*10 + '\n{} trials to go!'.format(n_trials - num_completed - num_pruned), flush=True )
# One trial each time
study.optimize(objective,
n_trials=1,
catch=(RuntimeError,),
callbacks=(_CkptsAndHandlersClearerCallBack(),)
)
num_completed, num_pruned = _get_complete_and_pruned_trial_nums(study)
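        # Safety valve: stop when pruning dominates, e.g. because _pruneDuplicate
        # keeps hitting already-tested parameter sets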
if num_pruned > 1000:
break
if rec_flag is False:
_record_tuning_opts_in_optuna_cache_dir(study, opts_rec_path)
rec_flag = True
# Report results
print("Study statistics this: ")
print(" Number of finished trials: ", len(study.trials))
print(" Number of pruned trials: ", num_pruned)
print(" Number of complete trials: ", num_completed)
print("Best trial:")
trial = study.best_trial
print(" Value: ", trial.value)
print(" Params: ")
for key, value in trial.params.items():
print(" {}: {}".format(key, value))
print(" Information for this best trial: ")
for key, value in trial.user_attrs.items():
print(" {}: {}".format(key, value))
print("REMARK: These are only results for **one run**! "
"\nDO NOT report the test acc here as the final result in your paper. "
"\nInsteat, use `train.py` to conduct further duplicative experiments."
"\n")