forked from aildnont/covid-cxr
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlime_explain.py
111 lines (92 loc) · 4.87 KB
/
lime_explain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from lime.lime_image import *
import pandas as pd
import yaml
import os
import datetime
import dill
import cv2
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from src.visualization.visualize import visualize_explanation
from src.predict import predict_instance, predict_and_explain
from src.data.preprocess import remove_text
def setup_lime():
    '''
    Load relevant information and create a LIME Explainer.

    Reads "config.yml" from the current working directory, loads the train and test set CSVs, builds a
    non-shuffled test set image generator with a batch size of 1, creates a LimeImageExplainer and serializes
    it to the path given by cfg['PATHS']['LIME_EXPLAINER'], and loads the model at cfg['PATHS']['MODEL_TO_LOAD'].
    :return: dict containing important information and objects for explanation experiments
    '''

    # Load relevant constants from project config file. The context manager closes the file handle
    # as soon as parsing is done instead of leaving it open until garbage collection.
    with open(os.path.join(os.getcwd(), "config.yml"), 'r') as cfg_file:
        cfg = yaml.full_load(cfg_file)
    lime_dict = {}
    lime_dict['NUM_SAMPLES'] = cfg['LIME']['NUM_SAMPLES']
    lime_dict['NUM_FEATURES'] = cfg['LIME']['NUM_FEATURES']
    lime_dict['IMG_PATH'] = cfg['PATHS']['IMAGES']
    lime_dict['TEST_IMG_PATH'] = cfg['PATHS']['TEST_IMGS']
    lime_dict['IMG_DIM'] = cfg['DATA']['IMG_DIM']
    lime_dict['PRED_THRESHOLD'] = cfg['PREDICTION']['THRESHOLD']
    lime_dict['CLASSES'] = cfg['DATA']['CLASSES']
    lime_dict['CLASS_MODE'] = cfg['TRAIN']['CLASS_MODE']
    lime_dict['COVID_ONLY'] = cfg['LIME']['COVID_ONLY']
    KERNEL_WIDTH = cfg['LIME']['KERNEL_WIDTH']
    FEATURE_SELECTION = cfg['LIME']['FEATURE_SELECTION']

    # Load train and test sets
    lime_dict['TRAIN_SET'] = pd.read_csv(cfg['PATHS']['TRAIN_SET'])
    lime_dict['TEST_SET'] = pd.read_csv(cfg['PATHS']['TEST_SET'])

    # Create ImageDataGenerator for test set. shuffle=False and batch_size=1 mean the generator yields
    # one image at a time in the same order as the rows of the test set dataframe.
    test_img_gen = ImageDataGenerator(preprocessing_function=remove_text,
                                      samplewise_std_normalization=True, samplewise_center=True)
    test_generator = test_img_gen.flow_from_dataframe(dataframe=lime_dict['TEST_SET'],
                                                      directory=lime_dict['TEST_IMG_PATH'],
                                                      x_col="filename", y_col='label_str',
                                                      target_size=tuple(lime_dict['IMG_DIM']), batch_size=1,
                                                      class_mode='categorical', shuffle=False)
    lime_dict['TEST_GENERATOR'] = test_generator

    # Define the LIME explainer
    lime_dict['EXPLAINER'] = LimeImageExplainer(kernel_width=KERNEL_WIDTH, feature_selection=FEATURE_SELECTION,
                                                verbose=True)
    # Serialize the explainer; the context manager guarantees the file is flushed and closed
    with open(cfg['PATHS']['LIME_EXPLAINER'], 'wb') as explainer_file:
        dill.dump(lime_dict['EXPLAINER'], explainer_file)

    # Load trained model's weights
    lime_dict['MODEL'] = load_model(cfg['PATHS']['MODEL_TO_LOAD'], compile=False)
    # NOTE(review): the predictions from this full pass over the test set are discarded. The call is kept
    # only to preserve existing behavior - confirm whether it can be removed.
    lime_dict['MODEL'].predict_generator(test_generator, verbose=0)
    return lime_dict
def explain_xray(lime_dict, idx, save_exp=True):
    '''
    Make a prediction and provide a LIME explanation
    :param lime_dict: dict containing important information and objects for explanation experiments
    :param idx: index of image in test set to explain
    :param save_exp: Boolean indicating whether to save the explanation visualization
    :raises IndexError: if idx is not a valid position in the test set
    :raises FileNotFoundError: if the original image could not be read from disk
    '''
    test_set = lime_dict['TEST_SET']
    test_generator = lime_dict['TEST_GENERATOR']

    # Fail early on a bad index: a negative idx would skip the loop below and leave the image undefined,
    # and an idx past the end would iterate the generator beyond the rows of the test set.
    if not 0 <= idx < len(test_set):
        raise IndexError("idx " + str(idx) + " is out of range for a test set of size " + str(len(test_set)))

    # Get image filename and label
    img_filename = test_set['filename'][idx]
    label = test_set['label'][idx]

    # Get i'th preprocessed image in test set by restarting the generator and stepping forward to it
    test_generator.reset()
    for _ in range(idx + 1):
        x, _ = next(test_generator)
    x = np.squeeze(x, axis=0)   # Drop the batch dimension (batch size is expected to be 1)

    # Get the corresponding original image (no preprocessing)
    orig_img_path = os.path.join(lime_dict['TEST_IMG_PATH'], img_filename)
    orig_img = cv2.imread(orig_img_path)
    if orig_img is None:
        # cv2.imread signals failure by returning None rather than raising
        raise FileNotFoundError("Could not read image: " + orig_img_path)
    new_dim = tuple(lime_dict['IMG_DIM'])
    orig_img = cv2.resize(orig_img, new_dim, interpolation=cv2.INTER_NEAREST)     # Resize image

    # Make a prediction for this image and retrieve a LIME explanation for the prediction
    start_time = datetime.datetime.now()
    explanation, probs = predict_and_explain(x, lime_dict['MODEL'], lime_dict['EXPLAINER'],
                                             lime_dict['NUM_FEATURES'], lime_dict['NUM_SAMPLES'])
    print("Explanation time = " + str((datetime.datetime.now() - start_time).total_seconds()) + " seconds")

    # Rearrange prediction probability vector to reflect original ordering of classes in project config.
    # For each class in config order, pick the probability at that class's generator index.
    # NOTE(review): assumes the model's outputs follow the generator's class_indices ordering, which is
    # also what label_to_see below relies on - confirm against the training pipeline.
    class_indices = test_generator.class_indices
    probs = [probs[0][class_indices[c]] for c in lime_dict['CLASSES']]

    # Visualize the LIME explanation and optionally save it to disk
    file_path = lime_dict['IMG_PATH'] if save_exp else None
    if lime_dict['COVID_ONLY']:
        label_to_see = class_indices['COVID-19']    # Only explain the COVID-19 class
    else:
        label_to_see = 'top'
    _ = visualize_explanation(orig_img, explanation, img_filename, label, probs, lime_dict['CLASSES'],
                              label_to_see=label_to_see, file_path=file_path)
    return
if __name__ == '__main__':
    # Script entry point: explain a single test set image and save the visualization
    i = 0                       # Position of the test set image to explain
    lime_dict = setup_lime()    # Build the explainer, test generator and model
    explain_xray(lime_dict, idx=i, save_exp=True)