Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batched metrics - Continuation #357

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
4f636de
rest of the implementation start
Aug 18, 2024
ad03d35
added sufficiency
davor10105 Aug 18, 2024
06b07c1
added sensitivity and road (not sure if completely done)
davor10105 Aug 18, 2024
a953a7c
added irof
davor10105 Aug 20, 2024
0f75eb2
added more metrics
davor10105 Aug 24, 2024
002a79c
added all robustness metrics
davor10105 Aug 24, 2024
21a410a
added all localization metrics
davor10105 Aug 31, 2024
5ddab45
added complexity and axiomatic (ask about non-sensitivity)
davor10105 Sep 1, 2024
4bae057
added randomization metrics
davor10105 Sep 14, 2024
4c2374d
merging latest changes from batched-metrics
davor10105 Nov 13, 2024
83649fd
added all metrics
davor10105 Nov 15, 2024
9f6e82f
removing unneeded functions
davor10105 Nov 15, 2024
6f2a7b5
add batch randomization to init
davor10105 Nov 15, 2024
cab2b6b
removing non-batched implementations
davor10105 Nov 15, 2024
23f3911
added fixes for robustness metrics
davor10105 Nov 17, 2024
e913248
resolved tests for robustness and localization
davor10105 Nov 27, 2024
eb99beb
resolving the remaining tests
davor10105 Nov 27, 2024
1282ff9
adding nan return condition to continuity
davor10105 Nov 27, 2024
4a5d29e
remove unused imports
davor10105 Nov 28, 2024
4d2d59e
removed unused param
davor10105 Nov 28, 2024
d885178
Merge branch 'understandable-machine-intelligence-lab:main' into batc…
davor10105 Nov 28, 2024
19e99b3
resolving lint tests
davor10105 Nov 28, 2024
c5a335d
Merge branch 'batched-metrics' of https://github.com/davor10105/Quant…
davor10105 Nov 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 12 additions & 28 deletions quantus/functions/mosaic_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ def build_single_mosaic(mosaic_images_list: List[np.ndarray]) -> np.ndarray:
mosaic: np.ndarray
The single 2x2 mosaic built from a list of images.
"""
first_row = np.concatenate((mosaic_images_list[0], mosaic_images_list[1]), axis=1)
second_row = np.concatenate((mosaic_images_list[2], mosaic_images_list[3]), axis=1)
mosaic = np.concatenate((first_row, second_row), axis=2)
first_row = np.concatenate((mosaic_images_list[0], mosaic_images_list[1]), axis=2)
second_row = np.concatenate((mosaic_images_list[2], mosaic_images_list[3]), axis=2)
mosaic = np.concatenate((first_row, second_row), axis=1)
return mosaic


Expand All @@ -38,9 +38,7 @@ def mosaic_creation(
labels: np.ndarray,
mosaics_per_class: int,
seed: Optional[int] = None,
) -> Tuple[
Any, List[Tuple[Any, ...]], List[Tuple[Any, ...]], List[Tuple[int, ...]], List[Any]
]:
) -> Tuple[Any, List[Tuple[Any, ...]], List[Tuple[Any, ...]], List[Tuple[int, ...]], List[Any]]:
"""
Build a mosaic dataset from an image dataset (images). Each mosaic corresponds to a 2x2 grid. Each one
is composed by four images: two belonging to the target class and the other two are chosen randomly from
Expand Down Expand Up @@ -89,33 +87,21 @@ def mosaic_creation(

target_class_images = images[labels == target_class]
target_class_image_indices = np.where(labels == target_class)[0]
target_class_images_and_indices = list(
zip(target_class_images, target_class_image_indices)
)

no_repetitions = int(
math.ceil((2 * mosaics_per_class) / len(target_class_images))
)
total_target_class_images_and_indices = (
target_class_images_and_indices * no_repetitions
)
target_class_images_and_indices = list(zip(target_class_images, target_class_image_indices))

no_repetitions = int(math.ceil((2 * mosaics_per_class) / len(target_class_images)))
total_target_class_images_and_indices = target_class_images_and_indices * no_repetitions
rng.shuffle(total_target_class_images_and_indices)

no_outer_images_per_class = int(
math.ceil((2 * mosaics_per_class) / len(outer_classes))
)
no_outer_images_per_class = int(math.ceil((2 * mosaics_per_class) / len(outer_classes)))
total_outer_images_and_indices = []
total_outer_labels = []
for outer_class in outer_classes:
outer_class_images = images[labels == outer_class]
outer_class_images_indices = np.where(labels == outer_class)[0]
outer_class_images_and_indices = list(
zip(outer_class_images, outer_class_images_indices)
)
outer_class_images_and_indices = list(zip(outer_class_images, outer_class_images_indices))

current_outer_images_and_indices = rng.choices(
outer_class_images_and_indices, k=no_outer_images_per_class
)
current_outer_images_and_indices = rng.choices(outer_class_images_and_indices, k=no_outer_images_per_class)
total_outer_images_and_indices += current_outer_images_and_indices
total_outer_labels += [outer_class] * no_outer_images_per_class

Expand All @@ -142,9 +128,7 @@ def mosaic_creation(
current_targets = tuple(elem[1] for elem in mosaic_elems)
mosaic_labels_list.append(current_targets)

current_p_batch = tuple(
int(elem[1] == target_class) for elem in mosaic_elems
)
current_p_batch = tuple(int(elem[1] == target_class) for elem in mosaic_elems)
p_batch_list.append(current_p_batch)

target_list.append(target_class)
Expand Down
66 changes: 61 additions & 5 deletions quantus/functions/perturb_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,11 +592,13 @@ def translation_x_direction(
np.moveaxis(arr, 0, -1),
matrix,
(arr.shape[1], arr.shape[2]),
borderValue=get_baseline_value(
value=perturb_baseline,
arr=arr,
return_shape=(arr.shape[0]),
**kwargs,
borderValue=float(
get_baseline_value(
value=perturb_baseline,
arr=arr,
return_shape=(arr.shape[0]),
**kwargs,
)[0]
),
)
arr_perturbed = np.moveaxis(arr_perturbed, -1, 0)
Expand Down Expand Up @@ -651,6 +653,60 @@ def translation_y_direction(
return arr_perturbed


def batched_translation(
    arr: np.array,
    perturb_baseline: Union[float, int, str, np.array],
    perturb_dx: int = 10,
    direction: str = "x",
    **kwargs,
) -> np.array:
    """
    Translate a batch of images by some given value in the x or y direction,
    filling the vacated region with baseline values. Assumes image type data
    and channel first layout.

    Parameters
    ----------
    arr: np.ndarray
        Batch of images to be perturbed, shape (batch, channels, height, width).
    perturb_baseline: float, int, str, np.ndarray
        The baseline values to fill the region uncovered by the translation.
    perturb_dx: integer
        The translation length in features, e.g., pixels. The sign selects the
        translation direction along the chosen axis.
    direction: string
        Axis of the translation: "x" (width axis) or "y" (height axis).
    kwargs: optional
        Keyword arguments.

    Returns
    -------
    arr_perturbed: np.ndarray
        The array which some of its indices have been perturbed.
    """
    assert direction in {"x", "y"}, "direction must be one of {'x', 'y'}"
    assert len(arr.shape) == 4, "Input arr must be a batch of 3D images"

    arr_shape = arr.shape
    batch_size, _, height, width = arr_shape
    # Axis 3 is width ("x"), axis 2 is height ("y") in channel-first layout.
    translated_axis = 3 if direction == "x" else 2

    # get_baseline_value expects flattened per-sample features in batched mode.
    # NOTE(review): **kwargs is not forwarded here, unlike the non-batched
    # translation functions — confirm whether batched baselines need it.
    arr = arr.reshape(batch_size, -1)
    translation_padding = get_baseline_value(
        value=perturb_baseline,
        arr=arr,
        return_shape=tuple(arr.shape),
        batched=True,
    )
    arr = arr.reshape(*arr_shape)
    translation_padding = translation_padding.reshape(*arr_shape)
    # Keep only the last |perturb_dx| slices of the baseline along the
    # translated axis — exactly the width of the region that gets uncovered.
    translation_padding = np.take(translation_padding, np.arange(-abs(perturb_dx), 0), axis=translated_axis)

    original_dims = width if direction == "x" else height
    if perturb_dx < 0:
        # Negative shift: pad at the far end, keep the trailing original_dims slices.
        x_batch_padded = np.concatenate([arr, translation_padding], axis=translated_axis)
        x_batch_perturbed = np.take(x_batch_padded, np.arange(-original_dims, 0), axis=translated_axis)
    else:
        # Positive shift: pad at the near end, keep the leading original_dims slices.
        x_batch_padded = np.concatenate([translation_padding, arr], axis=translated_axis)
        x_batch_perturbed = np.take(x_batch_padded, np.arange(original_dims), axis=translated_axis)
    return x_batch_perturbed


def noisy_linear_imputation(
arr: np.array,
indices: Union[Sequence[int], Tuple[np.array]],
Expand Down
76 changes: 22 additions & 54 deletions quantus/metrics/axiomatic/completeness.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import numpy as np

from quantus.functions.perturb_func import baseline_replacement_by_indices
from quantus.functions.perturb_func import batch_baseline_replacement_by_indices
from quantus.helpers import warn
from quantus.helpers.enums import (
DataType,
Expand Down Expand Up @@ -133,23 +133,20 @@ def __init__(
**kwargs,
)
if perturb_func is None:
perturb_func = baseline_replacement_by_indices
perturb_func = batch_baseline_replacement_by_indices

# Save metric-specific attributes.
if output_func is None:
output_func = identity
self.output_func = output_func
self.perturb_func = make_perturb_func(
perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline
)
self.perturb_func = make_perturb_func(perturb_func, perturb_func_kwargs, perturb_baseline=perturb_baseline)

# Asserts and warnings.
if not self.disable_warnings:
warn.warn_parameterisation(
metric_name=self.__class__.__name__,
sensitive_params=(
"baseline value 'perturb_baseline' and the function to modify the "
"model response 'output_func'"
"baseline value 'perturb_baseline' and the function to modify the " "model response 'output_func'"
),
citation=(
"Sundararajan, Mukund, Ankur Taly, and Qiqi Yan. 'Axiomatic attribution for "
Expand Down Expand Up @@ -263,49 +260,6 @@ def __call__(
**kwargs,
)

def evaluate_instance(
    self,
    model: ModelInterface,
    x: np.ndarray,
    y: np.ndarray,
    a: np.ndarray,
) -> bool:
    """
    Evaluate instance gets model and data for a single instance as input and returns the evaluation result.

    Checks the completeness axiom: the attributions should sum to the change in
    model output between the input and its fully-perturbed baseline.

    Parameters
    ----------
    model: ModelInterface
        A ModelInteface that is subject to explanation.
    x: np.ndarray
        The input to be evaluated on an instance-basis.
    y: np.ndarray
        The output to be evaluated on an instance-basis.
    a: np.ndarray
        The explanation to be evaluated on an instance-basis.

    Returns
    -------
    score: boolean
        The evaluation results.
    """
    # Perturb every feature of x (all indices over all axes) to obtain the baseline input.
    x_baseline = self.perturb_func(
        arr=x, indices=np.arange(0, x.size), indexed_axes=np.arange(0, x.ndim)
    )

    # Predict on input.
    x_input = model.shape_input(x, x.shape, channel_first=True)
    y_pred = float(model.predict(x_input)[:, y])

    # Predict on baseline.
    x_input = model.shape_input(x_baseline, x.shape, channel_first=True)
    y_pred_baseline = float(model.predict(x_input)[:, y])

    # NOTE(review): exact float equality between the attribution sum and the
    # (transformed) output delta — almost always False for real-valued
    # attributions; confirm whether a tolerance-based comparison is intended.
    if np.sum(a) == self.output_func(y_pred - y_pred_baseline):
        return True
    else:
        return False

def evaluate_batch(
self,
model: ModelInterface,
Expand Down Expand Up @@ -337,7 +291,21 @@ def evaluate_batch(
The evaluation results.
"""

return [
self.evaluate_instance(model=model, x=x, y=y, a=a)
for x, y, a in zip(x_batch, y_batch, a_batch)
]
# Flatten the attributions.
x_batch_shape = x_batch.shape
batch_size = x_batch.shape[0]
a_batch = a_batch.reshape(batch_size, -1)
n_features = a_batch.shape[-1]
indices = np.stack([np.arange(n_features) for _ in x_batch])
x_baseline = self.perturb_func(arr=x_batch.reshape(batch_size, -1), indices=indices)

# Predict on input.
x_input = model.shape_input(x_batch, x_batch_shape, channel_first=True, batched=True)
y_pred = model.predict(x_input)[np.arange(batch_size), y_batch]

# Predict on baseline.
x_baseline = x_baseline.reshape(*x_batch_shape)
x_input = model.shape_input(x_baseline, x_batch_shape, channel_first=True, batched=True)
y_pred_baseline = model.predict(x_input)[np.arange(batch_size), y_batch]

return a_batch.sum(axis=-1) == self.output_func(y_pred - y_pred_baseline)
Loading