forked from rish-16/cs4243-project
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tuner.py
93 lines (70 loc) · 2.97 KB
/
tuner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch, copy, random
# bug fix: module is matplotlib.pyplot, not matplotlib.pylot (typo crashed at import)
import matplotlib.pyplot as plt
class Tuner:
    """Greedy per-hyperparameter tuner.

    Tunes one hyperparameter at a time: for each hparam, every remaining
    (not-yet-tuned) hparam is held at its first candidate value, hparams
    tuned in earlier rounds are held at their best found value, and each
    candidate for the current hparam is evaluated via ``trainer.fit``.
    """

    def __init__(self, config, trainer):
        '''
        config: dict mapping hparam name -> list of candidate values.
        trainer: object exposing fit(hparams, verbose=..., return_history=...)
                 and returning (val_acc, avg_loss, history).
        '''
        self.config = config
        self.trainer = trainer

    def set_other_hparams(self, temp_config):
        '''
        Pick a provisional value for every hparam in temp_config.

        Deterministically takes the first candidate of each hparam
        (the original docstring said "randomly", but the code has always
        used values[0]; behavior kept).
        '''
        return {hparam: values[0] for hparam, values in temp_config.items()}

    def tune(self):
        '''
        Tune each hparam in turn.

        Returns:
            final_set: dict mapping hparam -> best value found.
            all_histories: dict mapping hparam -> {choice: {"history": ...,
                "hparam_setup": ...}} for ablation plots.
        '''
        config_cp = copy.deepcopy(self.config)
        final_set = {hparam: None for hparam in config_cp}
        all_histories = {}

        for hparam, values in self.config.items():
            print ("Currently tuning : ", hparam)
            # Remove the hparam being tuned; the rest get default values.
            config_cp.pop(hparam)
            best_acc = float('-inf')
            best_choice = None
            # bug fix: the original if/else on best_choice had two identical
            # branches — collapsed to a single call.
            current_hparams = self.set_other_hparams(config_cp)
            # Overwrite defaults with the best values found in earlier rounds.
            # bug fix: skip None entries — applying untuned (None) values
            # clobbered the defaults during the first tuning round.
            for done_hparam, val in final_set.items():
                if val is not None:
                    current_hparams[done_hparam] = val

            # bug fix: was initialized as a list but indexed by choice keys.
            hparam_history = {}
            for choice in values:
                current_hparams[hparam] = choice
                print ("Complete list: ", current_hparams)
                val_acc, avg_loss, history = self.trainer.fit(
                    current_hparams, verbose=False, return_history=True)
                hparam_history[choice] = {
                    "history": history,
                    # bug fix: snapshot the dict — storing the live reference
                    # made every choice record the final configuration.
                    "hparam_setup": dict(current_hparams),
                }
                if val_acc > best_acc:
                    best_acc = val_acc
                    best_choice = choice

            final_set[hparam] = best_choice
            all_histories[hparam] = hparam_history

        return final_set, all_histories

    def ablate_hparams_val_accs(self, all_histories):
        '''
        Plot per-choice validation-accuracy curves, one figure per hparam.

        all_histories: as returned by tune().
        '''
        for hparam, hparam_history in all_histories.items():
            for choice, metadata in hparam_history.items():
                training_history = metadata['history']
                # bug fix: epochs/val_acc live in the history dict returned by
                # trainer.fit, not in metadata (which only holds
                # "history"/"hparam_setup") — assumes history has
                # "epochs"/"val_acc" keys; TODO confirm against trainer.fit.
                plt.plot(training_history['epochs'], training_history['val_acc'],
                         label=repr(choice))
            plt.legend()
            plt.title("Sensitivity of {} on Validation Accuracy".format(hparam))
            plt.show()

    def ablate_hparam_train_losses(self, all_histories):
        '''
        Plot per-choice training-loss curves, one figure per hparam.

        all_histories: as returned by tune().
        '''
        for hparam, hparam_history in all_histories.items():
            for choice, metadata in hparam_history.items():
                training_history = metadata['history']
                # bug fix: read from the history dict, not metadata (same
                # issue as ablate_hparams_val_accs).
                plt.plot(training_history['epochs'], training_history['train_loss'],
                         label=repr(choice))
            plt.legend()
            # bug fix: copy-pasted title said "Validation Accuracy" on the
            # training-loss plot.
            plt.title("Sensitivity of {} on Training Loss".format(hparam))
            plt.show()