-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathdemo_KMplot.py
59 lines (50 loc) · 2.03 KB
/
demo_KMplot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import pickle
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
from argparse import ArgumentParser
from lifelines import KaplanMeierFitter
from survival4D.paths import DATA_DIR
def parse_args(argv=None):
    """Parse the command-line options of the Kaplan-Meier demo script.

    Args:
        argv: Optional sequence of argument strings to parse. When left as
            ``None`` the arguments are taken from the process command line
            (``sys.argv[1:]``), which is what callers passing no argument get.

    Returns:
        argparse.Namespace with two attributes:
            data_dir: directory given with ``-d/--data-dir`` as a string, or
                ``None`` when the option is omitted.
            file_name: file name given with ``-f/--file-name``; defaults to
                ``"bootout_conv.pkl"``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "-d", "--data-dir", dest="data_dir", type=str, default=None,
        help="Directory where the data file is.",
    )
    parser.add_argument(
        "-f", "--file-name", dest="file_name", type=str, default="bootout_conv.pkl",
        help="Data file name.",
    )
    # Forwarding argv lets other code (and tests) drive the parser without
    # touching sys.argv; argparse falls back to sys.argv[1:] when it is None.
    return parser.parse_args(argv)
def main():
    """Plot Kaplan-Meier curves for the low- and high-risk groups.

    Loads a pickled result file, averages each sample's bootstrap predictions
    over the bootstraps in which that sample was out-of-bag, splits the
    samples at the median of those averages, and draws one Kaplan-Meier curve
    per group on a shared axis.

    The pickle is indexed as ``[0]`` outcomes, ``[1]`` per-bootstrap
    predictions and ``[2]`` per-bootstrap in-bag indices. The two outcome
    columns are labelled ``status`` and ``time`` in that order.
    NOTE(review): the column order and the time unit ("days" in the axis
    label) are taken from the labels used here -- confirm against whatever
    script writes the pickle.
    """
    args = parse_args()
    data_dir = DATA_DIR if args.data_dir is None else Path(args.data_dir)
    data_path = data_dir.joinpath(args.file_name)

    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only point this script at result files you produced yourself.
    with open(str(data_path), 'rb') as f:
        inputdata_list = pickle.load(f)
    y_orig = inputdata_list[0]
    preds_bootfull = inputdata_list[1]
    inds_inbag = inputdata_list[2]
    del inputdata_list

    # One column per bootstrap in both matrices.
    preds_bootfull_mat = np.concatenate(preds_bootfull, axis=1)
    inds_inbag_mat = np.array(inds_inbag).T

    # inbag_mask[i, b] is True when sample index i was drawn in bootstrap b.
    # Assumes each bootstrap draws as many indices as there are samples, so
    # the mask lines up with preds_bootfull_mat -- TODO confirm.
    n_rows = inds_inbag_mat.shape[0]
    inbag_mask = np.array(
        [np.any(inds_inbag_mat == sample_idx, axis=0) for sample_idx in range(n_rows)]
    )
    oob_mask = 1 - inbag_mask

    # Mean prediction over the bootstraps where the sample was out-of-bag.
    # NOTE: a sample that is in-bag in every bootstrap has an OOB count of 0,
    # which yields NaN here (and a NaN median below).
    oob_sum = np.sum(np.multiply(oob_mask, preds_bootfull_mat), axis=1)
    oob_count = np.sum(oob_mask, axis=1)
    preds_bootave_oob = np.divide(oob_sum, oob_count)

    # 1 = above-median predicted risk, 0 = at or below the median.
    risk_groups = 1 * (preds_bootave_oob > np.median(preds_bootave_oob))

    wdf = pd.DataFrame(
        np.concatenate(
            (y_orig, preds_bootave_oob[:, np.newaxis], risk_groups[:, np.newaxis]),
            axis=-1,
        ),
        columns=['status', 'time', 'preds', 'risk_groups'],
    )

    kmf = KaplanMeierFitter()
    ax = plt.subplot(111)
    for group_value, group_label in ((0, "Low Risk"), (1, "High Risk")):
        # Boolean-mask selection always yields a Series, even for a group
        # containing a single sample.
        group = wdf[wdf['risk_groups'] == group_value]
        kmf.fit(durations=group['time'], event_observed=group['status'], label=group_label)
        ax = kmf.plot(ax=ax)

    plt.ylim(0, 1)
    plt.title("Kaplan-Meier Plots")
    plt.xlabel('Time (days)')
    plt.ylabel('Survival Probability')
    # Without this the figure is built and then discarded when the script
    # exits, so nothing is ever displayed.
    plt.show()
# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()