-
Notifications
You must be signed in to change notification settings - Fork 0
/
launcher_scalability.py
113 lines (94 loc) · 4.02 KB
/
launcher_scalability.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import config
from utils_exp import generate_base_command, generate_run_commands, hash_dict
from config import RESULT_DIR
import run_phasedgp
import argparse
import numpy as np
import copy
import os
import itertools
# Flags accepted by each component of the experiment. Every flag listed here
# must have either a fixed value in `default_configs` or a search spec in
# `search_ranges` (checked by the assert at the bottom of this section).
applicable_configs = {
    'GNN-UCB': ['exploration_coef', 'pretrain_steps', 'alg_lambda', 'neuron_per_layer', 'net', 't_intersect'],
    'Dataset': ['num_nodes', 'feat_dim', 'edge_prob', 'num_actions'],
}

# Fixed flag values forwarded to every launched run.
default_configs = {
    # Dataset
    'num_nodes': 5,  # alternatives noted by the author: 20 or 100
    'edge_prob': 0.05,  # alternatives noted by the author: 0.2 or 0.95
    'feat_dim': 10,  # alternative noted by the author: 100
    'num_actions': 200,  # author's note: any number below 10000 works.
    # GNN-UCB: first entry of the corresponding 'GNN_US' lists in config
    'pretrain_steps': config.alg_pretrain_steps['GNN_US'][0],
    'neuron_per_layer': 2048,
    'exploration_coef': config.alg_betas['GNN_US'][0],
    'alg_lambda': config.alg_lambdas['GNN_US'][0],
    't_intersect': config.alg_intersect['GNN_US'][0],
    'net': 'GNN',
}

# Flags to draw randomly via sample_flag; empty, so every run uses the defaults.
search_ranges = {}

# Consistency check: the applicable flags are exactly those that are configured.
assert set().union(*applicable_configs.values()) == set(default_configs) | set(search_ranges)
def sample_flag(sample_spec, rds=None):
    """Draw one hyper-parameter value from a ``(sample_type, spec_range)`` spec.

    Supported ``sample_type`` values:
      * ``'loguniform'`` -- ``10 ** rds.uniform(low, high)``; ``spec_range`` has 2 items.
      * ``'uniform'``    -- ``rds.uniform(low, high)``; ``spec_range`` has 2 items.
      * ``'choice'``     -- ``rds.choice(spec_range)``.
      * ``'intuniform'`` -- ``rds.randint(*spec_range)``.

    Args:
        sample_spec: two-element sequence ``(sample_type, spec_range)``.
        rds: random source exposing ``uniform``/``choice``/``randint``
            (e.g. ``np.random.RandomState``); the global ``np.random`` is
            used when ``None``.

    Returns:
        The sampled value.

    Raises:
        AssertionError: if ``sample_spec`` (or a 2-item range) has the wrong length.
        NotImplementedError: if ``sample_type`` is not one of the above.
    """
    if rds is None:
        rds = np.random
    assert len(sample_spec) == 2
    # `spec_range` rather than `range`, so the builtin is not shadowed.
    sample_type, spec_range = sample_spec
    if sample_type == 'loguniform':
        assert len(spec_range) == 2
        return 10 ** rds.uniform(*spec_range)
    if sample_type == 'uniform':
        assert len(spec_range) == 2
        return rds.uniform(*spec_range)
    if sample_type == 'choice':
        return rds.choice(spec_range)
    if sample_type == 'intuniform':
        return rds.randint(*spec_range)
    raise NotImplementedError('unknown sample_type %r' % (sample_type,))
def main(args):
    """Generate and submit one ``run_phasedgp`` command per (hparam setting, seed).

    For each of ``args.num_hparam_samples`` settings, the parsed CLI flags
    (minus launcher-only ones) are combined with ``default_configs``.
    Flags present in ``search_ranges`` are sampled instead. All seeds of one
    setting share a result folder under ``RESULT_DIR/<exp_name>/`` that is
    named by ``hash_dict`` of the setting's flags.

    Args:
        args: ``argparse.Namespace`` produced by this script's parser.

    Raises:
        ValueError: if ``args.num_seeds_per_hparam`` exceeds the supported
            maximum of 100 seeds per setting.
    """
    max_seeds_per_hparam = 100
    if args.num_seeds_per_hparam > max_seeds_per_hparam:
        raise ValueError(
            'num_seeds_per_hparam must be at most %d, got %d'
            % (max_seeds_per_hparam, args.num_seeds_per_hparam))

    rds = np.random.RandomState(args.seed)
    # Draw the seed pool before any hparam sampling so the RNG stream (and
    # hence the sampled settings) stays the same for a given --seed.
    init_seeds = list(rds.randint(0, 10**6, size=(max_seeds_per_hparam + 1,)))

    exp_path = os.path.join(RESULT_DIR, args.exp_name)

    # Launcher-only arguments that are not forwarded to run_phasedgp.
    launcher_keys = ('seed', 'num_hparam_samples', 'num_seeds_per_hparam', 'exp_name', 'num_cpus')
    # Visit every configured flag, including ones that appear only in
    # search_ranges; looping over default_configs alone would silently skip them.
    hparam_names = list({**default_configs, **search_ranges})

    command_list = []
    for _ in range(args.num_hparam_samples):
        flags = copy.deepcopy(vars(args))
        for key in launcher_keys:
            flags.pop(key)

        for flag in hparam_names:
            if flag in search_ranges:
                flags[flag] = sample_flag(sample_spec=search_ranges[flag], rds=rds)
            else:
                flags[flag] = default_configs[flag]

        # The hash is taken before exp_result_folder is replaced, so any CLI
        # value of --exp_result_folder is part of the hashed flags.
        flags_hash = hash_dict(flags)
        flags['exp_result_folder'] = os.path.join(exp_path, flags_hash)

        for j in range(args.num_seeds_per_hparam):
            cmd = generate_base_command(run_phasedgp, flags={**flags, 'seed': init_seeds[j]})
            command_list.append(cmd)

    # submit jobs ('promt' is the keyword name expected by generate_run_commands)
    generate_run_commands(command_list, num_cpus=args.num_cpus, mode='euler', promt=True)
def str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including 'False'), so boolean flags need this
    explicit converter.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Scalability UCB')

    # experiment parameters
    parser.add_argument('--exp_name', type=str, default='scalability_us_new')
    parser.add_argument('--num_cpus', type=int, default=8)
    parser.add_argument('--num_hparam_samples', type=int, default=1)
    parser.add_argument('--num_seeds_per_hparam', type=int, default=20)
    # NOTE: main() replaces this with a per-setting folder before launching runs.
    parser.add_argument('--exp_result_folder', type=str, default=None)
    parser.add_argument('--data', type=str, default='synthetic_data', help='dataset type')
    parser.add_argument('--seed', type=int, default=864, help='random number generator seed')

    # model arguments: algorithm settings that rarely need changing
    parser.add_argument('--nn_aggr_feat', type=str2bool, default=True)
    parser.add_argument('--nn_init_lazy', type=str2bool, default=True)
    parser.add_argument('--batch_size', type=int, default=20)
    parser.add_argument('--T', type=int, default=500)  # TODO: author's note says "change to 1500"
    parser.add_argument('--T0', type=int, default=100)

    args = parser.parse_args()
    main(args)