Skip to content

Commit

Permalink
Feature/#3 shap example (#10)
Browse files Browse the repository at this point in the history
* change layer to calculate shap #3

* visualize shap images #3

* add comment #3

* delete comment out #3

* use variable for num_epochs

* change the word Grad-CAM to SHAP #3

* use variable for ratio_num #3

* delete garbage collection #3

* change data_iterator to dataset #3

* rename folder and delete unnecessary files #3

* delete unnecessary files and images #3

* delete unnecessary argument #3

* split into precise functions #3

* add copyright #3

* delete unnecessary DS_Store #3

* fix lint error #3

* change readme #3

* change readme for explainable AI #3

* change image #3

* fix readme for shap #3

* change readme for nnabla-example #3

* fix comment #3

* fix readme for shap #3

* delete DSstore #3

* fix license #3

* delete unnecessary spaces and cells #3

* change the repository where ipynb file is in #3

* change the repository to reference from ghelia to sony #3

* clear the output of the first cell #3

* change readme #3

* deal gloabl variable as an argument #3

* deal error message as an exception handling #3

* delete unnecessary error message #3

* (#3) delete readme and ipynb

* (#3) fix readme line break

* add 50 images and ipynb file #3

* fix comment #3

* fix layer index #3

Co-authored-by: ohmorimori <[email protected]>
  • Loading branch information
twintrees and ohmorimori authored Jun 30, 2021
1 parent fbdabec commit d59ab8e
Show file tree
Hide file tree
Showing 53 changed files with 632 additions and 0 deletions.
460 changes: 460 additions & 0 deletions interactive-demos/shap.ipynb

Large diffs are not rendered by default.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added responsible_ai/shap/images/sample.png
172 changes: 172 additions & 0 deletions responsible_ai/shap/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Copyright (c) 2021 Sony Group Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
import nnabla as nn
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap


def red_blue_map():
colors = []
for i in np.linspace(1, 0, 100):
colors.append((30. / 255, 136. / 255, 229. / 255, i))
for i in np.linspace(0, 1, 100):
colors.append((255. / 255, 13. / 255, 87. / 255, i))
return LinearSegmentedColormap.from_list("red_transparent_blue", colors)


def get_model_layers(model, inputs):
if len(inputs.shape) == 3:
batch_size = 1
else:
batch_size = len(inputs)

x = nn.Variable((batch_size,) + model.input_shape)
# set training True since gradient of variable is necessary for SHAP
model_with_inputs = model(x, training=True, returns_net=True)
model_layers = dict()
for k, v in model_with_inputs.variables.items():
if ("W" in k) or ("b" in k):
continue
else:
model_layers[k] = v

return model_layers


def gradient(model, inputs, idx, interim_layer_index):
model_layers = get_model_layers(model, inputs)

for v in model_layers.values():
v.grad.zero()
v.need_grad = True
input_layer = list(model_layers.values())[-1]
if interim_layer_index == 0:
layer = input_layer
else:
layer = list(model_layers.values())[interim_layer_index]
pred = list(model_layers.values())[-2]
selected = pred[:, idx]
input_layer.d = inputs
selected.forward()
selected.backward()
grad = layer.g.copy()

return grad


def get_interim_input(model, inputs, interim_layer_index):
model_layers = get_model_layers(model, inputs)

input_layer = list(model_layers.values())[-1]
input_layer.d = inputs
try:
middle_layer = list(model_layers.values())[interim_layer_index]
except IndexError:
print('The interim layer should be an integer between 1 and the number of layers of the model!')
pred = list(model_layers.values())[-2]
pred.forward()
middle_layer_d = middle_layer.d.copy()

return middle_layer_d


def shap(model, X, label, interim_layer_index, num_samples,
dataset, batch_size, num_epochs=1):
# get data
if interim_layer_index == 0:
data = X.reshape((1,) + X.shape)
else:
data = get_interim_input(model, X, interim_layer_index)

samples_input = [np.zeros((num_samples, ) + X.shape)]
samples_delta = [np.zeros((num_samples, ) + data.shape[1:])]

rseed = np.random.randint(0, 1e6)
np.random.seed(rseed)
phis = [np.zeros((1,) + data.shape[1:])]

output_phis = []

for j in range(num_epochs):
for k in range(num_samples):
rind = np.random.choice(len(dataset))
t = np.random.uniform()
im = dataset[rind]
x = X.copy()
samples_input[j][k] = (t * x + (1 - t) * im.copy()).copy()
if interim_layer_index == 0:
samples_delta[j][k] = (x - im.copy()).copy()
else:
samples_delta[j][k] = get_interim_input(
model, samples_input[j][k],
interim_layer_index)[0]

grads = []

for b in range(0, num_samples, batch_size):
batch_last = min(b + batch_size, num_samples)
batch = samples_input[j][b:batch_last].copy()
grads.append(gradient(model, batch, label, interim_layer_index))
grad = [np.concatenate([g for g in grads], 0)]
samples = grad[0] * samples_delta[0]
phis[0][j] = samples.mean(0)

output_phis.append(phis[0])
return output_phis


def visualize(X, output_phis, output, ratio_num=10):
img = X.copy()
height = img.shape[1]
width = img.shape[2]
ratio = ratio_num / height
fig_size = np.array([width * ratio, ratio_num])
fig, ax = plt.subplots(figsize=fig_size, dpi=1 / ratio)
shap_plot = output_phis[0][0].sum(0)

if img.max() > 1:
img = img / 255.
if img.shape[0] == 3:
img_gray = (0.2989 * img[0, :, :] +
0.5870 * img[1, :, :] + 0.1140 * img[2, :, :])
else:
img_gray = img.reshape(img.shape[1:])

abs_phis = np.abs(output_phis[0].sum(1)).flatten()
max_border = np.nanpercentile(abs_phis, 99.9)
min_border = -np.nanpercentile(abs_phis, 99.9)

ax.imshow(img_gray, cmap=plt.get_cmap('gray'), alpha=0.15,
extent=(-1, shap_plot.shape[1], shap_plot.shape[0], -1))
im = ax.imshow(shap_plot, cmap=red_blue_map(),
vmin=min_border, vmax=max_border)
ax.axis("off")

fig.tight_layout()

fig.savefig(output)
fig.clf()
plt.close()


def shap_computation(model, X, label, interim_layer_index, num_samples,
dataset, batch_size, output):
output_phis = shap(model, X, label, interim_layer_index, num_samples,
dataset, batch_size)
visualize(X, output_phis, output)

0 comments on commit d59ab8e

Please sign in to comment.