
Batched metrics - Continuation #357

Open · wants to merge 23 commits into base: main

Conversation

@davor10105 (Contributor) commented Nov 28, 2024

Hey @annahedstroem, as promised (albeit late again :<), here's the PR including the rest of the batched metrics. The improvements are shown in the image below:
[Image: other_metrics (runtime comparison of the remaining batched metrics)]
The remaining batched metrics did not yield significant speed-ups because their computations were relatively straightforward, leaving little room for further optimization. This was particularly true for the localization and complexity metrics. Regarding the randomization metrics, with the exception of RandomLogit, I was unable to identify an effective approach to further reduce their runtime.

I also included changes to some of the metrics because they were not implemented as described in their respective papers (results from the batched implementations therefore differ from the current implementation's):

- Non-Sensitivity - now aligned with the original paper here
- Continuity - similarly, aligned with its definition here
- Focus! - bug fix: as I previously mentioned on Discord, the current Quantus implementation appears to assume the mosaic dimensions are B x C x W x H, when in reality the order of the width and height is switched (see my comment below)
- Consistency - similarly, as mentioned in my comment on Discord, I changed the implementation to align with the paper's definition of the metric (see my comment below)

**Comment regarding Focus!** I thought that the x_batch dimensions correspond to B x C x H x W, but looking at the code (for example `quadrant_top_right`), it seems that the assumed dimensions are B x C x W x H. However, when I run the metric on the sample data and scripts provided in Quantus and visualize the original image alongside the extracted top-right quadrant, the bottom-left quadrant gets cropped instead, which is consistent with the dimensions actually being B x C x H x W. Is this an error in the quadrant cropping in the current implementation, or am I missing something here? A sketch of the slicing I would expect is shown below.
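
For reference, this is the slicing I would expect under the B x C x H x W assumption (a minimal sketch; the function name mirrors `quadrant_top_right`, but the code here is illustrative rather than Quantus's actual implementation):

```python
import numpy as np

def quadrant_top_right_sketch(x_batch: np.ndarray) -> np.ndarray:
    # Assumes x_batch has shape (B, C, H, W): axis 2 indexes rows (height,
    # top to bottom) and axis 3 indexes columns (width, left to right).
    # The top-right quadrant is therefore the first half of the rows and
    # the second half of the columns.
    _, _, height, width = x_batch.shape
    return x_batch[:, :, : height // 2, width // 2 :]
```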
**Comment regarding Consistency** According to its definition, the Consistency metric "measures the expected local consistency, i.e., the probability that the prediction label for a given data point coincides with the prediction labels of other data points that share the same explanation." Given this, I would expect the metric to select instances within a batch that share the same explanation (where the "explanation label" is determined via the discretization function in Quantus) and then evaluate the proportion of matching model predictions among those instances.

However, when reviewing the current implementation in Quantus, I'm struggling to understand why the consistency check is performed by comparing the explanation with its label, rather than comparing the explanation labels directly. Could you clarify this for me?

Below is a code snippet from the current implementation in Quantus:

```python
pred_a = y_pred_classes[i]
# Indices where the explanation `a` matches its own label `a_label`.
same_a = np.argwhere(a == a_label).flatten()
# Exclude the instance itself.
diff_a = same_a[same_a != i]
pred_same_a = y_pred_classes[diff_a]

if len(same_a) == 0:
    return 0
return np.sum(pred_same_a == pred_a) / len(diff_a)
```

And this is the way I thought the metric should work (batched):

```python
batch_size = y_pred_classes.shape[0]
# Mask that zeroes out self-comparisons on the diagonal.
off_diagonal = ~np.eye(batch_size, dtype=bool)
# Pairwise equality of prediction labels across the batch.
pred_classes_equality = (y_pred_classes[None] == y_pred_classes[:, None]) & off_diagonal
# Pairwise equality of (discretized) explanation labels across the batch.
a_labels_equality = (a_label_batch[None] == a_label_batch[:, None]) & off_diagonal

# Per instance: the fraction of same-explanation instances that also share
# its prediction; the epsilon guards against division by zero.
return (pred_classes_equality & a_labels_equality).sum(axis=-1) / (
    a_labels_equality.sum(axis=-1) + 1e-9
)
```
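
As a sanity check, here is a toy example of the behaviour I would expect (the `consistency_batched` wrapper is hypothetical; it is just the snippet above wrapped in a function, and the data is made up):

```python
import numpy as np

def consistency_batched(y_pred_classes: np.ndarray, a_label_batch: np.ndarray) -> np.ndarray:
    # Same computation as the snippet above, wrapped for illustration.
    batch_size = y_pred_classes.shape[0]
    off_diagonal = ~np.eye(batch_size, dtype=bool)
    pred_eq = (y_pred_classes[None] == y_pred_classes[:, None]) & off_diagonal
    label_eq = (a_label_batch[None] == a_label_batch[:, None]) & off_diagonal
    return (pred_eq & label_eq).sum(axis=-1) / (label_eq.sum(axis=-1) + 1e-9)

y_pred = np.array([0, 0, 1, 0])    # prediction labels (made-up data)
a_labels = np.array([7, 7, 7, 3])  # discretized explanation labels (made-up data)
print(consistency_batched(y_pred, a_labels))  # ≈ [0.5, 0.5, 0.0, 0.0]
# Instance 0 shares explanation label 7 with instances 1 and 2, but only
# instance 1 also shares its prediction, hence 1 / 2 = 0.5.
```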

Thank you for the help!

As always, let me know if you have any questions!

@annahedstroem (Member) commented:

Hi @davor10105, really amazing work!

  1. Thank you so much for identifying the bug in the Focus! metric; even though it was implemented by the original authors, it is still possible for such bugs to sneak in! Great find, and thanks for including the image sample too.
  2. @dilyabareeva and I discussed a possible implementation at the time. Since we didn't have the original authors' implementation, we made our best interpretation of their method. @dilyabareeva, do you have any comments on @davor10105's suggestion? I would be curious to hear what you think!
