Feature/#3 shap example (#11)
* change layer to calculate shap #3

* visualize shap images #3

* add comment #3

* delete comment out #3

* use variable for num_epochs

* change the word Grad-CAM to SHAP #3

* use variable for ratio_num #3

* delete garbage collection #3

* change data_iterator to dataset #3

* rename folder and delete unnecessary files #3

* delete unnecessary files and images #3

* delete unnecessary argument #3

* split into precise functions #3

* add copyright #3

* delete unnecessary DS_Store #3

* fix lint error #3

* change readme #3

* change readme for explainable AI #3

* change image #3

* fix readme for shap #3

* change readme for nnabla-example #3

* fix comment #3

* fix readme for shap #3

* delete DS_Store #3

* fix license #3

* delete unnecessary spaces and cells #3

* change the repository where the ipynb file is located #3

* change the referenced repository from ghelia to sony #3

* clear the output of the first cell #3

* change readme #3

* handle global variable as an argument #3

* handle error message with exception handling #3

* delete unnecessary error message #3

* (#3) delete readme and ipynb

* (#3) fix readme line break

* add 50 images and ipynb file #3

* fix comment #3

* fix layer index #3

* (#3) fix redundant processing

* (#3) fix to designate weight and bias layers with a slash ("/W", "/b") so that other variables are not matched by mistake (see the filtering sketch just after the changed-files summary)

* (#3) zip images so they can be used by unzipping them in the notebook

* (#3) fix to import shap at the beginning

* (#3) delete unnecessary part

* (#3) fix explanation

* (#3) run with num_samples=100

Co-authored-by: twintrees <[email protected]>
3 people authored Jun 30, 2021
1 parent d59ab8e commit ee945c5
Showing 53 changed files with 126 additions and 118 deletions.
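One change called out above, designating weight and bias parameters by "/W" and "/b" rather than bare "W"/"b" when filtering the graph's variables, is easy to see in isolation. The sketch below is illustrative only: the variable names are made up, and the small dictionary stands in for the `model_graph.variables` dict used in `utils.py` further down.

# Illustrative only: why matching "/W" and "/b" beats matching bare "W"/"b"
# when dropping parameter variables from a graph's variable dict.
# The names below are made up; real keys come from model_graph.variables.
variables = {
    "VGG16/Conv1/conv/W": "weight parameter",        # should be dropped
    "VGG16/Conv1/conv/b": "bias parameter",          # should be dropped
    "VGG16/GlobalAveragePooling/out": "activation",  # should be kept
    "VGG16/Affine/out": "activation",                # should be kept
}

# Old check: any bare "W" or "b" anywhere in the name.
loose = {k: v for k, v in variables.items()
         if not (("W" in k) or ("b" in k))}
# New check from this commit: the parameter suffix must follow a slash.
strict = {k: v for k, v in variables.items()
          if not (("/W" in k) or ("/b" in k))}

print(sorted(loose))   # ['VGG16/Affine/out']  # pooling output dropped by mistake (its name contains a 'b')
print(sorted(strict))  # ['VGG16/Affine/out', 'VGG16/GlobalAveragePooling/out']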
164 changes: 94 additions & 70 deletions interactive-demos/shap.ipynb

Large diffs are not rendered by default.

Binary file added responsible_ai/shap/imagenet50.zip
Binary file removed responsible_ai/shap/imagenet50/sim_n03049924_13.jpg
(The diffs for the remaining changed files are not rendered.)
80 changes: 32 additions & 48 deletions responsible_ai/shap/utils.py
@@ -15,8 +15,6 @@
 
 import numpy as np
 import nnabla as nn
-import matplotlib
-matplotlib.use('Agg')
 import matplotlib.pyplot as plt
 from matplotlib.colors import LinearSegmentedColormap
 
@@ -30,69 +28,58 @@ def red_blue_map():
     return LinearSegmentedColormap.from_list("red_transparent_blue", colors)
 
 
-def get_model_layers(model, inputs):
-    if len(inputs.shape) == 3:
-        batch_size = 1
-    else:
-        batch_size = len(inputs)
-
-    x = nn.Variable((batch_size,) + model.input_shape)
-    # set training True since gradient of variable is necessary for SHAP
-    model_with_inputs = model(x, training=True, returns_net=True)
+def get_model_layers(model_graph):
     model_layers = dict()
-    for k, v in model_with_inputs.variables.items():
-        if ("W" in k) or ("b" in k):
+    for k, v in model_graph.variables.items():
+        if ("/W" in k) or ("/b" in k):
             continue
         else:
            model_layers[k] = v
 
     return model_layers
 
 
-def gradient(model, inputs, idx, interim_layer_index):
-    model_layers = get_model_layers(model, inputs)
+def get_middle_layer(model_layers, interim_layer_index):
+    return list(model_layers.values())[interim_layer_index]
+
+
+def gradient(model_layers, input_layer, target_layer, output_layer, inputs, idx):
 
     for v in model_layers.values():
         v.grad.zero()
         v.need_grad = True
-    input_layer = list(model_layers.values())[-1]
-    if interim_layer_index == 0:
-        layer = input_layer
-    else:
-        layer = list(model_layers.values())[interim_layer_index]
-    pred = list(model_layers.values())[-2]
-    selected = pred[:, idx]
+    selected = output_layer[:, idx]
     input_layer.d = inputs
     selected.forward()
     selected.backward()
-    grad = layer.g.copy()
-
+    grad = target_layer.g.copy()
     return grad
 
 
-def get_interim_input(model, inputs, interim_layer_index):
-    model_layers = get_model_layers(model, inputs)
-
-    input_layer = list(model_layers.values())[-1]
+def get_interim_input(input_layer, middle_layer, output_layer, inputs):
     input_layer.d = inputs
-    try:
-        middle_layer = list(model_layers.values())[interim_layer_index]
-    except IndexError:
-        print('The interim layer should be an integer between 1 and the number of layers of the model!')
-    pred = list(model_layers.values())[-2]
-    pred.forward()
+    output_layer.forward()
    middle_layer_d = middle_layer.d.copy()
 
     return middle_layer_d
 
 
-def shap(model, X, label, interim_layer_index, num_samples,
+def shap(model_graph, X, label, interim_layer_index, num_samples,
          dataset, batch_size, num_epochs=1):
+    input_layer = list(model_graph.inputs.values())[0]
+    model_layers = get_model_layers(model_graph)
+    middle_layer = get_middle_layer(model_layers, interim_layer_index)
+    output_layer = list(model_graph.outputs.values())[0]
     # get data
+    if len(X.shape) == 3:
+        batch_size = 1
+    else:
+        batch_size = len(X)
+
+    x = nn.Variable((batch_size,) + input_layer.shape)
+    # set training True since gradient of variable is necessary for SHAP
     if interim_layer_index == 0:
         data = X.reshape((1,) + X.shape)
     else:
-        data = get_interim_input(model, X, interim_layer_index)
+        data = get_interim_input(input_layer, middle_layer, output_layer, X)
 
     samples_input = [np.zeros((num_samples, ) + X.shape)]
     samples_delta = [np.zeros((num_samples, ) + data.shape[1:])]
@@ -113,16 +100,14 @@ def shap(model, X, label, interim_layer_index, num_samples,
             if interim_layer_index == 0:
                 samples_delta[j][k] = (x - im.copy()).copy()
             else:
-                samples_delta[j][k] = get_interim_input(
-                    model, samples_input[j][k],
-                    interim_layer_index)[0]
+                samples_delta[j][k] = get_interim_input(input_layer, middle_layer, output_layer, samples_input[j][k])[0]
 
         grads = []
 
         for b in range(0, num_samples, batch_size):
             batch_last = min(b + batch_size, num_samples)
             batch = samples_input[j][b:batch_last].copy()
-            grads.append(gradient(model, batch, label, interim_layer_index))
+            grads.append(gradient(model_layers, input_layer, middle_layer, output_layer, batch, label))
         grad = [np.concatenate([g for g in grads], 0)]
         samples = grad[0] * samples_delta[0]
         phis[0][j] = samples.mean(0)
@@ -132,6 +117,8 @@ def shap(model, X, label, interim_layer_index, num_samples,
 
 
 def visualize(X, output_phis, output, ratio_num=10):
+    import matplotlib
+    matplotlib.use('Agg')
     img = X.copy()
     height = img.shape[1]
     width = img.shape[2]
@@ -154,19 +141,16 @@ def visualize(X, output_phis, output, ratio_num=10):
 
     ax.imshow(img_gray, cmap=plt.get_cmap('gray'), alpha=0.15,
               extent=(-1, shap_plot.shape[1], shap_plot.shape[0], -1))
-    im = ax.imshow(shap_plot, cmap=red_blue_map(),
-                   vmin=min_border, vmax=max_border)
+    ax.imshow(shap_plot, cmap=red_blue_map(), vmin=min_border, vmax=max_border)
     ax.axis("off")
 
     fig.tight_layout()
 
     fig.savefig(output)
     fig.clf()
     plt.close()
 
 
-def shap_computation(model, X, label, interim_layer_index, num_samples,
+def shap_computation(model_graph, X, label, interim_layer_index, num_samples,
                      dataset, batch_size, output):
-    output_phis = shap(model, X, label, interim_layer_index, num_samples,
+    output_phis = shap(model_graph, X, label, interim_layer_index, num_samples,
                        dataset, batch_size)
     visualize(X, output_phis, output)
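
For orientation, here is a minimal usage sketch of the refactored entry point, pieced together from this diff: the graph-construction lines removed from get_model_layers now live with the caller, which passes the resulting network object to shap_computation. The pretrained model choice (VGG16), the label index, the random stand-in image, and the in-memory reference_images array are assumptions for illustration; the actual notebook (interactive-demos/shap.ipynb) uses the bundled imagenet50 images and may prepare the dataset argument differently.

import numpy as np
import nnabla as nn
from nnabla.models.imagenet import VGG16  # assumed model; any nnabla ImageNet model exposing input_shape should work

from utils import shap_computation  # the refactored helper in responsible_ai/shap/utils.py

# Build the differentiable graph once, outside the SHAP utilities (this mirrors
# the code removed from get_model_layers): training=True keeps gradients
# available, returns_net=True exposes .inputs / .outputs / .variables.
model = VGG16()
x = nn.Variable((1,) + model.input_shape)
model_graph = model(x, training=True, returns_net=True)

# One preprocessed image in CHW layout (random stand-in here) and a set of
# reference images used as the SHAP background (assumed to be a numpy array;
# the notebook unzips imagenet50.zip for this purpose).
X = np.random.rand(*model.input_shape).astype(np.float32)
reference_images = np.random.rand(50, *model.input_shape).astype(np.float32)

shap_computation(model_graph, X, label=0, interim_layer_index=0,
                 num_samples=100, dataset=reference_images, batch_size=10,
                 output="shap_result.png")

Passing the prebuilt graph in, rather than the model object, is what lets gradient and get_interim_input reuse the same variables instead of rebuilding the network on every call, which is the core of this refactor.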
