-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgrad_cam.py
176 lines (138 loc) · 5.56 KB
/
grad_cam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
"""
Core Module for Grad CAM Algorithm
"""
import cv2
import numpy as np
import tensorflow as tf
from tf_explain.utils.display import grid_display, heatmap_display
from tf_explain.utils.saver import save_rgb
class GradCAM:
    """
    Perform Grad CAM algorithm for a given input

    Paper: [Grad-CAM: Visual Explanations from Deep Networks
            via Gradient-based Localization](https://arxiv.org/abs/1610.02391)
    """

    def explain(
        self,
        validation_data,
        model,
        class_index,
        layer_name=None,
        colormap=cv2.COLORMAP_VIRIDIS,
        use_guided_grads=True,
    ):
        """
        Compute GradCAM for a specific class index.

        Args:
            validation_data (Tuple[np.ndarray, Optional[np.ndarray]]): Validation data
                to perform the method on. Tuple containing (x, y); only x is used.
            model (tf.keras.Model): tf.keras model to inspect
            class_index (int): Index of targeted class
            layer_name (str): Targeted layer for GradCAM. If no layer is provided, it is
                automatically inferred from the model architecture (closest layer
                with a 4D output to the end of the network).
            colormap (int): OpenCV Colormap to use for heatmap visualization
            use_guided_grads (bool): If True (default, historical behavior), only
                gradients that are positive where the target layer output is also
                positive are kept before averaging. If False, the unmodified
                gradients are used.

        Returns:
            Tuple[numpy.ndarray, tf.Tensor, tf.Tensor]: (grid, outputs, grads), where
                grid is the grid of all the GradCAM heatmaps, outputs are the target
                layer outputs and grads are the gradients used for the ponderation
                (guided unless use_guided_grads is False).
        """
        images, _ = validation_data

        if layer_name is None:
            layer_name = self.infer_grad_cam_target_layer(model)

        outputs, guided_grads = GradCAM.get_gradients_and_filters(
            model, images, layer_name, class_index, use_guided_grads
        )

        cams = GradCAM.generate_ponderated_output(outputs, guided_grads)

        # One colored heatmap per input image, overlaid on that image.
        heatmaps = np.array(
            [
                heatmap_display(cam.numpy(), image, colormap)
                for cam, image in zip(cams, images)
            ]
        )

        grid = grid_display(heatmaps)

        return grid, outputs, guided_grads

    @staticmethod
    def infer_grad_cam_target_layer(model):
        """
        Search for the last convolutional layer to perform Grad CAM, as stated
        in the original paper.

        The closest layer to the end of the network whose (single) output is 4D
        is selected.

        Args:
            model (tf.keras.Model): tf.keras model to inspect

        Returns:
            str: Name of the target layer

        Raises:
            ValueError: If no layer with a single 4D output is found.
        """
        for layer in reversed(model.layers):
            output_shape = layer.output_shape
            # A layer reporting a list of shapes (one per output) must not be
            # mistaken for a 4D layer: len() would count outputs, not dimensions.
            reports_several_shapes = isinstance(output_shape, list) and any(
                isinstance(dim, (list, tuple)) for dim in output_shape
            )
            if not reports_several_shapes and len(output_shape) == 4:
                return layer.name

        raise ValueError(
            "Model does not seem to contain 4D layer. Grad CAM cannot be applied."
        )

    @staticmethod
    @tf.function
    def get_gradients_and_filters(
        model, images, layer_name, class_index, use_guided_grads=True
    ):
        """
        Generate guided gradients and convolutional outputs with an inference.

        Args:
            model (tf.keras.Model): tf.keras model to inspect
            images (numpy.ndarray): 4D-Tensor with shape (batch_size, H, W, 3)
            layer_name (str): Targeted layer for GradCAM
            class_index (int): Index of targeted class
            use_guided_grads (bool): If True (default), mask the gradients so only
                positive gradients at positive activations are kept. If False,
                return the unmodified gradients.

        Returns:
            Tuple[tf.Tensor, tf.Tensor]: (Target layer outputs, Gradients of the
                class score with respect to those outputs, guided if requested)
        """
        # NOTE: model, layer_name, class_index and use_guided_grads are Python
        # values, so tf.function retraces whenever one of them changes.
        grad_model = tf.keras.models.Model(
            [model.inputs], [model.get_layer(layer_name).output, model.output]
        )

        with tf.GradientTape() as tape:
            inputs = tf.cast(images, tf.float32)
            conv_outputs, predictions = grad_model(inputs)
            # Score of the targeted class for every image in the batch.
            loss = predictions[:, class_index]

        grads = tape.gradient(loss, conv_outputs)

        if use_guided_grads:
            # Keep gradients only where both the activation and the gradient
            # are positive.
            grads = (
                tf.cast(conv_outputs > 0, "float32")
                * tf.cast(grads > 0, "float32")
                * grads
            )

        return conv_outputs, grads

    @staticmethod
    def generate_ponderated_output(outputs, grads):
        """
        Apply Grad CAM algorithm scheme.

        Inputs are the convolutional outputs and gradients, both batched with
        shape (batch_size, Hl, Wl, Nf). For each batch element:
            - we compute the spatial average of the gradients
            - we build a ponderated sum of the convolutional outputs based on
              those averaged weights

        Args:
            outputs (tf.Tensor): Target layer outputs, with shape (batch_size, Hl, Wl, Nf),
                where Hl and Wl are the target layer output height and width, and Nf the
                number of filters.
            grads (tf.Tensor): Guided gradients with shape (batch_size, Hl, Wl, Nf)

        Returns:
            List[tf.Tensor]: One ponderated output per batch element, each of
                shape (Hl, Wl).
        """
        maps = [
            GradCAM.ponderate_output(output, grad)
            for output, grad in zip(outputs, grads)
        ]

        return maps

    @staticmethod
    def ponderate_output(output, grad):
        """
        Perform the ponderation of filters output with respect to average of gradients values.

        Args:
            output (tf.Tensor): Target layer outputs, with shape (Hl, Wl, Nf),
                where Hl and Wl are the target layer output height and width, and Nf the
                number of filters.
            grad (tf.Tensor): Guided gradients with shape (Hl, Wl, Nf)

        Returns:
            tf.Tensor: Ponderated output of shape (Hl, Wl)
        """
        # Average each filter's gradient over the two spatial axes -> shape (Nf,).
        weights = tf.reduce_mean(grad, axis=(0, 1))

        # Perform ponderated sum : w_i * output[:, :, i]
        cam = tf.reduce_sum(tf.multiply(weights, output), axis=-1)

        return cam

    def save(self, grid, output_dir, output_name):
        """
        Save the output to a specific dir.

        Args:
            grid (numpy.ndarray): Grid of all the heatmaps
            output_dir (str): Output directory path
            output_name (str): Output name
        """
        save_rgb(grid, output_dir, output_name)