# hyper_tune.py
import os
import uuid

import torch
from colorama import Fore, Style
from ray.tune import Trainable

from average import EWMA
from eval import Evaluation
from helpers import flatten_dict
from train import setup_training, get_optimizers, training_step


def replace_value_in_nested_dict(node, kv, new_value):
    """Recursively walk a nested dict/list structure and set every occurrence of the
    key ``kv`` to ``new_value``, yielding each replaced value.

    This is a generator, so it must be fully consumed (e.g. via ``list(...)``)
    for the replacements to take effect.
    """
    if isinstance(node, list):
        for item in node:
            yield from replace_value_in_nested_dict(item, kv, new_value)
    elif isinstance(node, dict):
        if kv in node:
            node[kv] = new_value
            yield node[kv]
        for value in node.values():
            yield from replace_value_in_nested_dict(value, kv, new_value)
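
# Illustrative example (not part of the original module): consuming the generator with
#     list(replace_value_in_nested_dict({'a': {'lr': 0.1}, 'b': [{'lr': 0.2}]}, 'lr', 0.5))
# rewrites both nested occurrences of 'lr' in place and returns [0.5, 0.5].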


# TODO: global config is a kludge
def inject_tuned_hyperparameters(global_config, config):
    """Write tuned values from ``config`` into ``global_config``: keys named
    'allocate|<key>' replace every nested occurrence of <key>, and keys named
    'nested|<name>' are descended into recursively."""
    for k in config.keys():
        if k.split('|')[0] == 'allocate':
            # list() drains the generator so the replacements actually happen.
            list(replace_value_in_nested_dict(global_config, k.split('|')[1], config[k]))
        elif k.split('|')[0] == 'nested':
            inject_tuned_hyperparameters(global_config, config[k])
    return config
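
# Illustrative example (hypothetical keys, not from the original code): a Ray Tune
# search space such as
#     {'allocate|learning_rate': 3e-4, 'nested|training': {'allocate|batch_size': 32}, ...}
# would, via inject_tuned_hyperparameters(config, config) in _setup below, write 3e-4
# into every nested 'learning_rate' entry and 32 into every nested 'batch_size' entry
# of the same config dict.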


class TuneTrainable(Trainable):
    """Ray Tune Trainable wrapping this project's training and evaluation loop."""

    def _setup(self, config):
        inject_tuned_hyperparameters(config, config)
        os.chdir(os.path.dirname(os.path.realpath(__file__)))
        print('Trainable got the following config after injection:', config)
        self.config = config
        self.device = self.config['device']
        self.exp, self.model, self.train_dataloader, self.eval_dataloader = setup_training(self.config)
        self.exp_name = config['experiment_name'] + self._experiment_id
        self.exp.set_name(self.exp_name)
        self.exp.send_notification(title='Experiment ' + str(self._experiment_id) + ' ended')
        self.train_data_iter = iter(self.train_dataloader)
        self.model = self.model.to(self.device)
        self.model.train()
        n_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        log_dict = flatten_dict(config)
        log_dict.update({'trainable_params': n_params})
        self.exp.log_parameters(log_dict)
        self.optimizers = get_optimizers(self.model, self.config)
        self.evaluator = Evaluation(self.eval_dataloader, self.config)
        self.num_examples = 0
        self.batch_idx = 0
        self.epoch = 1
        self.ewma = EWMA(beta=0.75)
        self.last_accu = -1.0
        self.max_accu = -1.0
        self.back_prop_every_n_batches = config['training']['back_prop_every_n_batches']
        self.checkpoint_best = config['training']['checkpoint_best']

    def get_batch(self):
        """Return the next training batch, restarting the dataloader at the end of an epoch."""
        try:
            batch = next(self.train_data_iter)
            return batch
        except StopIteration:
            # Dataloader exhausted: start a new epoch.
            self.train_data_iter = iter(self.train_dataloader)
            batch = next(self.train_data_iter)
            self.batch_idx = 0
            self.epoch += 1
            return batch

    def _train(self):
        total_log_step_loss = 0
        total_log_step_train_accu = 0
        total_log_step_n = 0
        for opt in self.optimizers:
            opt.zero_grad()
        while True:
            batch = self.get_batch()
            self.batch_idx += 1
            self.num_examples += len(batch[0])
            batch = (batch[0].to(self.device), batch[1].to(self.device))
            loss, train_accu = training_step(batch, self.model, self.optimizers,
                                             step=(self.batch_idx % self.back_prop_every_n_batches == 0))
            total_log_step_loss += loss.cpu().detach().numpy()
            total_log_step_train_accu += train_accu
            total_log_step_n += 1
            if self.batch_idx % self.config['training']['log_every_n_batches'] == 0:
                avg_loss = total_log_step_loss / total_log_step_n
                avg_accu = total_log_step_train_accu / total_log_step_n
                total_log_step_n = 0
                print(f'{Fore.YELLOW}Total number of seen examples:', self.num_examples,
                      'Average loss of current log step:', avg_loss,
                      'Average train accuracy of current log step:', avg_accu, f'{Style.RESET_ALL}')
                self.exp.log_metric('train_loss', avg_loss, step=self.num_examples, epoch=self.epoch)
                self.exp.log_metric('train_accuracy', avg_accu, step=self.num_examples, epoch=self.epoch)
                total_log_step_loss = 0
                total_log_step_train_accu = 0
            if (self.batch_idx + 1) % self.config['training']['eval_every_n_batches'] == 0:
                metric = self.config['tune']['discriminating_metric']
                results, assets, image_fns = self.evaluator.eval_model(self.model)
                print(metric, results[metric])
                self.exp.log_metrics(results, step=self.num_examples, epoch=self.epoch)
                for asset in assets:
                    self.exp.log_asset_data(asset, step=self.num_examples)
                for fn in image_fns:
                    self.exp.log_image(fn, step=self.num_examples)
                # Consider the run stalled when the metric has barely moved relative to
                # both its running average and its previous value.
                accu_diff_avg = abs(results[metric] - self.ewma.get())
                accu_diff_cons = abs(results[metric] - self.last_accu)
                no_change_in_accu = 1 if accu_diff_avg < 0.0005 and accu_diff_cons < 0.002 and self.num_examples > 70000 else 0
                self.ewma.update(results[metric])
                self.last_accu = results[metric]
                if self.max_accu < results[metric]:
                    self.max_accu = results[metric]
                    if self.checkpoint_best:
                        self.save_checkpoint('checkpoints', self.exp_name + '.pt')
                        print(f'{Fore.GREEN}New best model saved.{Style.RESET_ALL}')
                self.exp.log_metric('max_accuracy', self.max_accu, step=self.num_examples, epoch=self.epoch)
                training_results = {
                    metric: self.max_accu,
                    'num_examples': self.num_examples,
                    'no_change_in_accu': no_change_in_accu}
                return training_results

    def _save(self, checkpoint_dir):
        return self.save_checkpoint(checkpoint_dir, 'checkpoint_file.pt')

    def save_checkpoint(self, checkpoint_dir, fname='checkpoint_file.pt'):
        print(f'{Fore.CYAN}Saving model ...{Style.RESET_ALL}')
        os.makedirs(checkpoint_dir, exist_ok=True)
        save_dict = {'model_state_dict': self.model.state_dict()}
        for i, optimizer in enumerate(self.optimizers):
            save_dict['op_' + str(i) + '_state_dict'] = optimizer.state_dict()
        path = os.path.join(checkpoint_dir, fname)
        torch.save(save_dict, path)
        return path

    def _restore(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        for i, optimizer in enumerate(self.optimizers):
            optimizer.load_state_dict(checkpoint['op_' + str(i) + '_state_dict'])

    def stop(self):
        # Run a final evaluation and log its outputs before shutting down.
        results, assets, image_fns = self.evaluator.eval_model(self.model, finished_training=True)
        self.exp.log_metrics(results, step=self.num_examples, epoch=self.epoch)
        for asset in assets:
            self.exp.log_asset_data(asset, step=self.num_examples)
        for fn in image_fns:
            self.exp.log_image(fn, step=self.num_examples)
        return super().stop()
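
# Minimal usage sketch (illustrative, not part of the original file): with the legacy
# class-based Trainable API used above (_setup/_train/_save/_restore), this class would
# typically be launched through ray.tune, e.g.
#
#     from ray import tune
#     analysis = tune.run(
#         TuneTrainable,
#         config=my_search_space,         # hypothetical dict using the 'allocate|...' keys
#         stop={'no_change_in_accu': 1},  # end a trial once the metric has plateaued
#         num_samples=10)
#
# Exact arguments depend on the project's config files and the installed Ray version.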