-
Notifications
You must be signed in to change notification settings - Fork 181
/
tstr.py
58 lines (49 loc) · 1.7 KB
/
tstr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#!/usr/bin/env ipython
# Run TSTR on a trained model. (helper script)
import glob
import os
import pdb
import sys

import numpy as np

from eval import TSTR_mnist, TSTR_eICU
# Usage: tstr.py <identifier> <model> <task>
#   identifier -- experiment identifier; parameter files are looked up as
#                 params_dir + identifier + '_<epoch>.npy'
#   model      -- 'CNN' selects CNN=True; anything else prints 'Using RF'
#                 and passes CNN=False
#   task       -- 'mnist' runs TSTR_mnist; anything else runs TSTR_eICU
#
# For every epoch with saved parameters, TSTR is run on the validation set;
# the epoch with the highest returned score is then evaluated with vali=False,
# both forwards and with reverse=True (TRTS).
if len(sys.argv) < 4:
    # argv[1..3] are all required below; an assert would be stripped under -O
    # and a too-short list would otherwise surface as an IndexError.
    sys.exit('usage: ' + sys.argv[0] + ' <identifier> <CNN|RF> <mnist|eicu>')

identifier = sys.argv[1]
print(identifier)

model = sys.argv[2]
CNN = model == 'CNN'
print('Using CNN' if CNN else 'Using RF')

task = sys.argv[3]
mnist = task == 'mnist'
print('testing on mnist' if mnist else 'testing on eicu')

params_dir = 'REDACTED'
param_prefix = params_dir + identifier + '_'
# Escape the literal prefix so glob metacharacters in the identifier are not
# treated as wildcards.
params = glob.glob(glob.escape(param_prefix) + '*.npy')
print(params)

# Keep only files named exactly <identifier>_<digits>.npy. Without the digit
# check, identifier 'exp' would also pick up e.g. 'exp_b_5.npy' from a
# different run whose identifier merely starts with 'exp_'.
prefix_name = os.path.basename(param_prefix)
epochs = []
for p in params:
    name = os.path.basename(p)
    epoch_str = name[len(prefix_name):-len('.npy')]
    if epoch_str.isdecimal():
        epochs.append(int(epoch_str))
# Sorted so the run order is deterministic and np.argmax (which returns the
# first maximum) breaks ties in favour of the earliest epoch.
epochs.sort()

if not epochs:
    sys.exit('No parameter files found matching ' + param_prefix + '<epoch>.npy')

# Validation score per epoch, as returned by TSTR_* for the synthetic-trained
# model. Per the original author, this is named "f1" but is not actually the
# F1 score.
epoch_f1 = np.zeros(len(epochs))
print('Running TSTR on validation set across all epochs for which parameters are available')
for i, e in enumerate(epochs):
    if mnist:
        synth_f1, _real_f1 = TSTR_mnist(identifier, e, generate=True, vali=True, CNN=CNN)
    else:
        print('testing eicu')
        synth_f1 = TSTR_eICU(identifier, e, generate=True, vali=True, CNN=CNN)
    epoch_f1[i] = synth_f1

best_epoch_index = np.argmax(epoch_f1)
best_epoch = epochs[best_epoch_index]
print('Running TSTR on', identifier, 'at epoch', best_epoch, '(validation f1 was', epoch_f1[best_epoch_index], ')')

run_tstr = TSTR_mnist if mnist else TSTR_eICU
run_tstr(identifier, best_epoch, generate=True, vali=False, CNN=CNN)
# also run TRTS (reverse=True) at that epoch
run_tstr(identifier, best_epoch, generate=True, vali=False, CNN=CNN, reverse=True)