Merge pull request #84 from aertslab/tf_explainer_multi_output_improvement

Speed improvement for calculating (expected_)integrated_grad for multiple outputs at the same time.
LukasMahieu authored Feb 13, 2025
2 parents f457204 + 44dcf89 commit 2f7438c
Showing 1 changed file with 46 additions and 10 deletions.
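In short, the new code does a single forward pass and reuses a persistent `tf.GradientTape` for one backward pass per requested class, instead of rerunning the model for every output. A minimal standalone sketch of that pattern (hypothetical toy model, not from the repository):

    import tensorflow as tf

    # hypothetical toy model: 3 output classes over a (10, 4) one-hot input
    model = tf.keras.Sequential(
        [tf.keras.Input(shape=(10, 4)), tf.keras.layers.Flatten(), tf.keras.layers.Dense(3)]
    )
    X = tf.random.uniform((2, 10, 4))  # batch of 2 sequences

    with tf.GradientTape(persistent=True) as tape:
        tape.watch(X)
        output = model(X)  # single forward pass
        per_class = [output[:, c] for c in [0, 2]]

    # one backward pass per requested class, all reusing the same tape
    grads = [tape.gradient(y, X) for y in per_class]
    del tape  # a persistent tape must be released explicitly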
56 changes: 46 additions & 10 deletions src/crested/tl/_explainer_tf.py
@@ -4,6 +4,8 @@
Adapted from: https://github.com/p-koo/tfomics/blob/master/tfomics/
"""

+from __future__ import annotations
+
import numpy as np
import tensorflow as tf

@@ -59,7 +61,7 @@ def integrated_grad(self, X, baseline_type="random", num_steps=25):
                func=self.func,
            )
            scores.append(intgrad_scores)
-        return np.concatenate(scores, axis=0)
+        return np.array(scores)

    def expected_integrated_grad(
        self, X, num_baseline=25, baseline_type="random", num_steps=25
@@ -78,7 +80,7 @@
                func=self.func,
            )
            scores.append(intgrad_scores)
-        return np.concatenate(scores, axis=0)
+        return np.array(scores)

    def mutagenesis(self, X, class_index=None):
        """In silico mutagenesis analysis for a given sequence."""
@@ -97,18 +99,38 @@ def set_baseline(self, x, baseline, num_samples):
        return baseline


-def saliency_map(X, model, class_index=None, func=tf.math.reduce_mean):
+def saliency_map(
+    X, model, class_index: int | list[int] | None = None, func=tf.math.reduce_mean
+):
    """Fast function to generate saliency maps."""
    if not tf.is_tensor(X):
        X = tf.Variable(X)

-    with tf.GradientTape() as tape:
+    # use a persistent tape so the gradient can be calculated for each output in
+    # class_index, in case class_index is a list of indices.
+    with tf.GradientTape(persistent=True) as tape:
        tape.watch(X)
-        if class_index is not None:
-            outputs = model(X)[:, class_index]
+        output = model(X)
+        if isinstance(class_index, int):
+            # get the output for a single class (C)
+            outputs_C = [output[:, class_index]]
+        elif isinstance(class_index, list):
+            # get the outputs for multiple classes
+            outputs_C = [output[:, c] for c in class_index]
+        elif class_index is None:
+            # legacy mode -- unclear whether `func` is still needed here
+            outputs_C = [func(model(X))]
        else:
-            outputs = func(model(X))
-    return tape.gradient(outputs, X)
+            raise ValueError(
+                f"class_index should be an integer, a list of integers, or None, not: {class_index}."
+            )
+    grads = np.empty((len(outputs_C), *X.shape))
+    for i in range(len(outputs_C)):
+        grads[i] = tape.gradient(outputs_C[i], X)
+    # explicitly delete the tape, needed because persistent=True
+    del tape
+    # squeeze grads so the first dimension is dropped when class_index is a single int
+    return grads.squeeze()
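Assuming the module imports as `crested.tl._explainer_tf` (the file shown above), the shape returned by `saliency_map` now depends on the type of `class_index`; a quick sketch with the same hypothetical toy model:

    import tensorflow as tf

    from crested.tl._explainer_tf import saliency_map

    model = tf.keras.Sequential(
        [tf.keras.Input(shape=(10, 4)), tf.keras.layers.Flatten(), tf.keras.layers.Dense(3)]
    )
    X = tf.random.uniform((2, 10, 4))  # batch of 2 sequences

    # single class: the leading class axis is squeezed away
    print(saliency_map(X, model, class_index=0).shape)       # (2, 10, 4)
    # list of classes: (num_classes, batch, seq_len, alphabet)
    print(saliency_map(X, model, class_index=[0, 2]).shape)  # (2, 2, 10, 4)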


@tf.function
@@ -148,7 +170,12 @@ def smoothgrad(


def integrated_grad(
-    x, model, baseline, num_steps=25, class_index=None, func=tf.math.reduce_mean
+    x,
+    model,
+    baseline,
+    num_steps=25,
+    class_index: int | list[int] | None = None,
+    func=tf.math.reduce_mean,
):
    """Calculate integrated gradients for a given sequence."""

@@ -167,8 +194,17 @@ def interpolate_data(baseline, x, steps):
    steps = tf.linspace(start=0.0, stop=1.0, num=num_steps + 1)
    x_interp = interpolate_data(baseline, x, steps)
    grad = saliency_map(x_interp, model, class_index=class_index, func=func)
+    # at this point the shape of grad is either:
+    # - (num_steps + 1, *x.shape) if class_index is None or a single int
+    # - (len(class_index), num_steps + 1, *x.shape) if class_index is a list of ints
+    if len(grad.shape) == 4:
+        # second case: put num_steps + 1 on the first axis
+        grad = grad.swapaxes(0, 1)
    avg_grad = integral_approximation(grad)
-    avg_grad = np.expand_dims(avg_grad, axis=0)
+    if len(avg_grad.shape) != 3:
+        # first case: the leading class dimension must be added;
+        # in the second case it is already present
+        avg_grad = np.expand_dims(avg_grad, axis=0)
    return avg_grad
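With a list of classes, the module-level `integrated_grad` accordingly returns one attribution map per class; a sketch under the same assumptions (hypothetical toy model, module importable as `crested.tl._explainer_tf`):

    import numpy as np
    import tensorflow as tf

    from crested.tl._explainer_tf import integrated_grad

    model = tf.keras.Sequential(
        [tf.keras.Input(shape=(10, 4)), tf.keras.layers.Flatten(), tf.keras.layers.Dense(3)]
    )
    x = np.random.rand(1, 10, 4).astype("float32")  # one sequence
    baseline = np.zeros_like(x)

    scores = integrated_grad(x, model, baseline, num_steps=25, class_index=[0, 2])
    print(np.asarray(scores).shape)  # (2, 10, 4): one map per requested class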

