-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsurrogate_model.py
178 lines (144 loc) · 7.07 KB
/
surrogate_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import json
import logging
import os
import sys
from abc import ABC, abstractmethod
import numpy as np
import pathvalidate
#from nasbench301.surrogate_models import utils
import utils
class SurrogateModel(ABC):
    """Abstract base class for surrogate models over a configuration space.

    Handles seeding, configuration-space loading, logging setup and the
    loading/shuffling of the train/val/test result-path splits. Concrete
    subclasses implement ``train``/``validate``/``test``/``save``/``load``/
    ``query``.
    """

    def __init__(self, data_root, log_dir, seed, model_config, data_config):
        """
        Args:
            data_root: Root directory containing the result json files.
            log_dir: Directory for log files and config dumps, or ``None``
                to disable all file output (console logging still happens).
            seed: Random seed applied to numpy and to the data splits.
            model_config: Dict of model hyperparameters; dumped to log_dir.
            data_config: Dict describing the data splits; dumped to log_dir.
                Entries whose value is a dict are treated as split specs
                (see ``_load_data``).
        """
        self.data_root = data_root
        self.log_dir = log_dir
        self.model_config = model_config
        self.data_config = data_config
        self.seed = seed

        # Seeding
        np.random.seed(seed)

        # NOTE: Update to use absolute path, also moved configspace to
        # be included in the installed package
        current_dir = os.path.dirname(os.path.abspath(__file__))
        configspace_path = os.path.join(current_dir, 'configurationspaces', 'effnet_configspace.json')

        # Create config loader
        self.config_loader = utils.ConfigLoader(configspace_path)

        # Console logging is always set up; the file handler only when a
        # log_dir was given.
        log_format = '%(asctime)s %(message)s'
        logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                            format=log_format, datefmt='%m/%d %I:%M:%S %p')
        if log_dir is not None:
            os.makedirs(log_dir, exist_ok=True)
            # BUG FIX: the original created the FileHandler (and dumped the
            # configs below) unconditionally, crashing with TypeError when
            # log_dir is None even though makedirs was guarded.
            fh = logging.FileHandler(os.path.join(log_dir, 'log.txt'))
            fh.setFormatter(logging.Formatter(log_format))
            logging.getLogger().addHandler(fh)

        # Record the seed with the rest of the data config before dumping it.
        self.data_config['seed'] = seed
        logging.info('MODEL CONFIG: {}'.format(model_config))
        logging.info('DATA CONFIG: {}'.format(data_config))

        self._load_data()
        logging.info(
            'DATA: No. train data {}, No. val data {}, No. test data {}'.format(
                len(self.train_paths), len(self.val_paths), len(self.test_paths)))

        # Dump the config of the run to log_dir (guarded, see BUG FIX above).
        if log_dir is not None:
            with open(os.path.join(log_dir, 'model_config.json'), 'w') as fp:
                json.dump(model_config, fp)
            with open(os.path.join(log_dir, 'data_config.json'), 'w') as fp:
                json.dump(data_config, fp)

    def _load_data(self):
        """Load the train/val/test result-path splits described by data_config.

        Each dict-valued entry of ``self.data_config`` is treated as one
        split spec and loaded via ``utils.ResultLoader``; the resulting
        splits are concatenated across entries, persisted to log_dir (when
        set) and deterministically shuffled.
        """
        train_paths = []
        val_paths = []
        test_paths = []
        for key, split_config in self.data_config.items():
            # Skip scalar entries such as the 'seed' injected in __init__.
            if not isinstance(split_config, dict):
                continue
            result_loader = utils.ResultLoader(
                self.data_root, filepath_regex=split_config['filepath_regex'],
                train_val_test_split=split_config, seed=self.seed)
            train_val_test_split = result_loader.return_train_val_test()

            # Persist the chosen splits so a run is reproducible/inspectable.
            if self.log_dir is not None:
                for paths, filename in zip(train_val_test_split,
                                           ['train_paths', 'val_paths', 'test_paths']):
                    file_path = os.path.join(
                        self.log_dir,
                        pathvalidate.sanitize_filename('{}_{}.json'.format(key, filename)))
                    # BUG FIX: use a context manager instead of leaking the
                    # file handle from json.dump(paths, open(..., 'w')).
                    with open(file_path, 'w') as fp:
                        json.dump(paths, fp)

            train_paths.extend(train_val_test_split[0])
            val_paths.extend(train_val_test_split[1])
            test_paths.extend(train_val_test_split[2])

        # Shuffle the combined file paths with a fixed RNG so the final order
        # is deterministic across runs (independent of self.seed).
        rng = np.random.RandomState(6)
        rng.shuffle(train_paths)
        rng.shuffle(val_paths)
        rng.shuffle(test_paths)

        self.train_paths = train_paths
        self.val_paths = val_paths
        self.test_paths = test_paths

    def _get_labels_and_preds(self, result_paths):
        """Return (labels, preds) for the result json files in result_paths.

        Labels are the true validation accuracies read via the config loader;
        preds are the surrogate's ``query`` outputs for the same configs.
        """
        labels = []
        preds = []
        for result_path in result_paths:
            config_space_instance, val_accuracy_true, test_accuracy_true, _ = self.config_loader[result_path]
            val_pred = self.query(config_space_instance.get_dictionary())
            labels.append(val_accuracy_true)
            preds.append(val_pred)
        return labels, preds

    def _log_predictions(self, result_paths, labels, preds, identifier):
        """Dump paths, labels and predictions for one split to
        ``<log_dir>/<identifier>_preds.json``."""
        # Some models return a per-sample sequence; keep only the scalar.
        # BUG FIX: guard preds[0] against an empty split.
        if preds and not isinstance(preds[0], float):
            preds = [p[0] for p in preds]
        out_path = os.path.join(self.log_dir, identifier + "_preds.json")
        dump_dict = {"paths": result_paths, "labels": labels, "predictions": preds}
        with open(out_path, "w") as f:
            json.dump(dump_dict, f)

    def log_dataset_predictions(self):
        """Log paths, labels and predictions for the train/val/test splits."""
        data_splits = {"train": self.train_paths, "val": self.val_paths, "test": self.test_paths}
        for split_identifier, result_paths in data_splits.items():
            print("==> Logging predictions of %s split" % split_identifier)
            labels, preds = self._get_labels_and_preds(result_paths)
            self._log_predictions(result_paths, labels, preds, split_identifier)

    @abstractmethod
    def train(self):
        """Train the surrogate model."""
        raise NotImplementedError()

    @abstractmethod
    def validate(self):
        """Evaluate the surrogate model on the validation split."""
        raise NotImplementedError()

    @abstractmethod
    def test(self):
        """Evaluate the surrogate model on the test split."""
        raise NotImplementedError()

    @abstractmethod
    def save(self):
        """Persist the trained model to disk."""
        raise NotImplementedError()

    @abstractmethod
    def load(self, model_path):
        """Load a trained model from model_path."""
        raise NotImplementedError()

    @abstractmethod
    def query(self, config_dict):
        """Predict the validation accuracy for a single configuration dict."""
        raise NotImplementedError()