train_with_cv.py
import argparse
import collections
import copy
import time
import pickle
import pandas as pd
from pathlib import Path
import numpy as np
from sklearn.preprocessing import LabelEncoder
import global_config
from data_loader.ecg_data_set import ECGDataset
from logger import update_logging_setup_for_tune_or_cross_valid
from parse_config import ConfigParser
from test import test_model
from train import train_model, _set_seed
from utils import ensure_dir
# Needed for working with SSH Interpreter...
import os
os.environ["CUDA_VISIBLE_DEVICES"] = global_config.CUDA_VISIBLE_DEVICES
global_config.suppress_warnings()


def test_fold(config, data_dir, test_idx, k_fold, total_num_folds):
    df_class_wise_results, df_single_metric_results = test_model(config, tune_config=None, cv_active=True,
                                                                 cv_data_dir=data_dir, test_idx=test_idx,
                                                                 k_fold=k_fold, total_num_folds=total_num_folds)
    return df_class_wise_results, df_single_metric_results


def train_fold(config, train_idx, valid_idx, k_fold, total_num_folds):
    log_best = train_model(config, train_idx=train_idx, valid_idx=valid_idx, k_fold=k_fold,
                           total_num_folds=total_num_folds, cv_active=True,
                           tune_config=None, train_dl=None, valid_dl=None, checkpoint_dir=None, use_tune=False)
    return log_best


def get_new_labels(y):
    """Convert each multilabel vector to a unique string and encode these strings as integer classes."""
    # Join the entries of each label vector into one string (the original ''.join(str(l)) was a no-op join)
    yy = [''.join(map(str, label_vector)) for label_vector in y]
    y_new = LabelEncoder().fit_transform(yy)
    return y_new
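
# Illustration (hypothetical input): for y = np.array([[0, 1, 0], [1, 0, 0], [0, 1, 0]]),
# get_new_labels builds the strings "010", "100", "010", which LabelEncoder encodes as array([0, 1, 0]).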


def run_cross_validation(config):
    total_num_folds = config["data_loader"]["cross_valid"]["k_fold"]
    data_dir = config["data_loader"]["cross_valid"]["data_dir"]
    assert "PTB_XL" not in data_dir, "Cross-validation is not yet implemented for PTB-XL!"

    # Update the data dir in the dataloader args
    config["data_loader"]["args"]["data_dir"] = data_dir
    dataset = ECGDataset(data_dir)
    n_samples = len(dataset)
    idx_full = np.arange(n_samples)
    np.random.shuffle(idx_full)

    # Keep the main config and the dirs for logging and saving checkpoints
    base_config = copy.deepcopy(config)
    base_log_dir = config.log_dir
    base_save_dir = config.save_dir

    # Divide the samples into k distinct sets
    fold_size = n_samples // total_num_folds
    fold_data = []
    lower_limit = 0
    for i in range(total_num_folds):
        if i != total_num_folds - 1:
            fold_data.append(idx_full[lower_limit:lower_limit + fold_size])
            lower_limit += fold_size
        else:
            # The last fold may be slightly larger
            fold_data.append(idx_full[lower_limit:n_samples])
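    # Example (hypothetical numbers): with n_samples = 1003 and total_num_folds = 10,
    # folds 0-8 each contain 100 shuffled indices and the last fold absorbs the remainder (103 indices).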

    # Run the k-fold cross-validation
    # In each iteration, one of the subsets serves as the validation set and another as the test set
    # Save the results of each run
    class_wise_metrics = ["precision", "recall", "f1-score", "torch_roc_auc", "torch_accuracy", "support"]
    folds = ['fold_' + str(i) for i in range(1, total_num_folds + 1)]
    iterables = [folds, class_wise_metrics]
    multi_index = pd.MultiIndex.from_product(iterables, names=["fold", "metric"])
    valid_results = pd.DataFrame(columns=folds)
    test_results_class_wise = pd.DataFrame(columns=['SNR', 'AF', 'IAVB', 'LBBB', 'RBBB', 'PAC', 'VEB', 'STD', 'STE',
                                                    'macro avg', 'weighted avg'], index=multi_index)
    test_results_single_metrics = pd.DataFrame(columns=['loss', 'sk_subset_accuracy'])

    print(f"Starting with {total_num_folds}-fold cross-validation")
    start = time.time()
    valid_fold_index = total_num_folds - 2
    test_fold_index = total_num_folds - 1
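    # The valid/test assignment rotates by one fold per iteration; e.g. (hypothetically) with
    # total_num_folds = 10, iteration 0 validates on fold 8 and tests on fold 9,
    # iteration 1 validates on fold 9 and tests on fold 0, and so on.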
    for k in range(total_num_folds):
        # Get the indices of the train, valid and test samples
        train_sets = [fold for fold_id, fold in enumerate(fold_data)
                      if fold_id != valid_fold_index and fold_id != test_fold_index]
        train_idx = np.concatenate(train_sets, axis=0)
        valid_idx = fold_data[valid_fold_index]
        test_idx = fold_data[test_fold_index]
        print(f"Starting fold {k}")
        print(f"Valid Set: {valid_fold_index}, Test Set: {test_fold_index}")

        # Adapt the log and save paths for the current fold
        config.save_dir = Path(os.path.join(base_save_dir, "Fold_" + str(k + 1)))
        config.log_dir = Path(os.path.join(base_log_dir, "Fold_" + str(k + 1)))
        ensure_dir(config.save_dir)
        ensure_dir(config.log_dir)
        update_logging_setup_for_tune_or_cross_valid(config.log_dir)

        # Write the record names to a CSV file for reproducing single folds
        data_split = {
            "train_records": np.array(dataset.records)[train_idx],
            "valid_records": np.array(dataset.records)[valid_idx],
            "test_records": np.array(dataset.records)[test_idx]
        }
        with open(os.path.join(config.save_dir, "data_split.csv"), "w") as file:
            pd.DataFrame.from_dict(data=data_split, orient='index').to_csv(file, header=False)
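        # The resulting CSV contains one row per split (record names here are hypothetical):
        #   train_records,rec_001,rec_002,...
        #   valid_records,rec_101,...
        #   test_records,rec_201,...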

        # Run the training and add the fold results to the validation DataFrame
        fold_train_model_best = train_fold(config, train_idx=train_idx, valid_idx=valid_idx, k_fold=k,
                                           total_num_folds=total_num_folds)
        valid_results[folds[k]] = fold_train_model_best

        # Run the testing and add the fold results to the test DataFrames
        config.resume = Path(os.path.join(config.save_dir, "model_best.pth"))
        config.test_output_dir = Path(os.path.join(config.resume.parent, 'test_output_fold_' + str(k + 1)))
        ensure_dir(config.test_output_dir)
        fold_eval_class_wise, fold_eval_single_metrics = test_fold(config, data_dir=data_dir, test_idx=test_idx,
                                                                   k_fold=k, total_num_folds=total_num_folds)
        # Class-wise metrics
        test_results_class_wise.loc[(folds[k], fold_eval_class_wise.index), fold_eval_class_wise.columns] = \
            fold_eval_class_wise.values
        # Single metrics
        pd_series = fold_eval_single_metrics.loc['value']
        pd_series.name = folds[k]
        # DataFrame.append was removed in pandas 2.0; concatenate the fold row instead
        test_results_single_metrics = pd.concat([test_results_single_metrics, pd_series.to_frame().T])

        # Update the fold indices and reset the config (including resume!)
        valid_fold_index = (valid_fold_index + 1) % total_num_folds
        test_fold_index = (test_fold_index + 1) % total_num_folds
        config = copy.deepcopy(base_config)

    # Summarize the results of the cross-validation and write everything to files
    # --------------------------- Test Class-Wise Metrics ---------------------------
    iterables_summary = [["mean", "median"], class_wise_metrics]
    multi_index = pd.MultiIndex.from_product(iterables_summary, names=["merging", "metric"])
    test_results_class_wise_summary = pd.DataFrame(
        columns=['SNR', 'AF', 'IAVB', 'LBBB', 'RBBB', 'PAC', 'VEB', 'STD', 'STE',
                 'macro avg', 'weighted avg'], index=multi_index)
    # xs(metric, level=1) selects the rows for one metric across all folds, so
    # .mean()/.median() aggregate each class column over the folds
    for metric in class_wise_metrics:
        test_results_class_wise_summary.loc[('mean', metric)] = test_results_class_wise.xs(metric, level=1).mean()
        test_results_class_wise_summary.loc[('median', metric)] = test_results_class_wise.xs(metric, level=1).median()

    path = os.path.join(base_save_dir, "test_results_class_wise.p")
    with open(path, 'wb') as file:
        pickle.dump(test_results_class_wise, file)
    path = os.path.join(base_save_dir, "test_results_class_wise_summary.p")
    with open(path, 'wb') as file:
        pickle.dump(test_results_class_wise_summary, file)

    # --------------------------- Test Single Metrics ---------------------------
    test_results_single_metrics.loc['mean'] = test_results_single_metrics.mean()
    # Exclude the freshly added 'mean' row when computing the median
    test_results_single_metrics.loc['median'] = test_results_single_metrics.iloc[:-1].median()
    path = os.path.join(base_save_dir, "test_results_single_metrics.p")
    with open(path, 'wb') as file:
        pickle.dump(test_results_single_metrics, file)

    # --------------------------- Train Results ---------------------------
    valid_results['mean'] = valid_results.mean(axis=1)
    valid_results['median'] = valid_results.median(axis=1)
    path = os.path.join(base_save_dir, "cross_validation_valid_results.p")
    with open(path, 'wb') as file:
        pickle.dump(valid_results, file)

    # --------------------------- Test Metrics To LaTeX ---------------------------
    with open(os.path.join(base_save_dir, 'cross_validation_results.tex'), 'w') as file:
        file.write("Class-Wise Summary:\n\n")
        test_results_class_wise_summary.to_latex(buf=file, index=False, float_format="{:0.3f}".format,
                                                 escape=False)
        file.write("\n\n\nSingle Metrics:\n\n")
        test_results_single_metrics.to_latex(buf=file, index=False, float_format="{:0.3f}".format,
                                             escape=False)

    # Finish everything
    end = time.time()
    ty_res = time.gmtime(end - start)
    res = time.strftime("%H hours, %M minutes, %S seconds", ty_res)
    print("Finished cross-validation")
    print(f"Elapsed time: {res}")


if __name__ == '__main__':
    args = argparse.ArgumentParser(description='MACRO Paper: Cross-Validation')
    args.add_argument('-c', '--config', default=None, type=str,
                      help='config file path (default: None)')
    args.add_argument('-r', '--resume', default=None, type=str,
                      help='path to latest checkpoint (default: None)')
    args.add_argument('-d', '--device', default=None, type=str,
                      help='indices of GPUs to enable (default: all)')
    args.add_argument('-t', '--tune', action='store_true', help='Use to enable tuning')
    args.add_argument('--seed', type=int, default=123, help='Random seed')

    # Custom CLI options to override configuration values given in the JSON config file
    CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
    options = [
        CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'),
        CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size')
        # Options added here can be modified via command line flags
    ]
    config = ConfigParser.from_args(args=args, options=options)
    assert config["data_loader"]["cross_valid"]["enabled"], "Cross-valid should be enabled when running this script"

    # Fix random seeds for reproducibility
    global_config.SEED = config.config.get("SEED", global_config.SEED)
    _set_seed(global_config.SEED)
    run_cross_validation(config)
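
# Example invocation (the config path is hypothetical; the flags are the ones defined above):
#   python train_with_cv.py -c configs/cv_config.json --seed 123 --bs 64
# The referenced JSON config is assumed to set data_loader.cross_valid.enabled = true as well as
# data_loader.cross_valid.k_fold and data_loader.cross_valid.data_dir.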