
MeanIoU differs from custom IoU metric implementation #20574

Closed
edge7 opened this issue Dec 1, 2024 · 12 comments

@edge7
Contributor

edge7 commented Dec 1, 2024

Hi,
I'm running a segmentation training job and am using the following function as a custom IoU metric:

import keras
import tensorflow as tf


@keras.saving.register_keras_serializable(package="glass_segm", name="custom_iou_metric")
def custom_iou_metric(y_true, y_pred, num_classes=3):
    # Per-batch mean IoU: average the per-class IoU computed on this batch only.
    y_pred = tf.argmax(y_pred, axis=-1)

    y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
    y_pred = tf.cast(tf.reshape(y_pred, [-1]), tf.int32)

    iou = tf.constant(0.0)

    for i in range(num_classes):
        true_mask = tf.cast(tf.equal(y_true, i), tf.float32)
        pred_mask = tf.cast(tf.equal(y_pred, i), tf.float32)

        intersection = tf.reduce_sum(true_mask * pred_mask)
        union = tf.reduce_sum(true_mask) + tf.reduce_sum(pred_mask) - intersection

        class_iou = tf.cond(
            tf.equal(union, 0), lambda: tf.constant(1.0), lambda: intersection / union
        )
        iou += class_iou

    iou /= tf.cast(num_classes, tf.float32)

    return iou

I was expecting your MeanIoU to be the mean IoU across classes, also averaged over all training and validation batches, but it does not seem to work like that.
For example, given:

import numpy as np
import keras

y_pred = np.array([[[0.8, 0.1, 0.1], [0.7, 0.2, 0.1], [0.1, 0.7, 0.2], [0.2, 0.6, 0.2]],
 [[0.6, 0.3, 0.1], [0.7, 0.2, 0.1], [0.2, 0.5, 0.3], [0.3, 0.4, 0.3]],
 [[0.7, 0.2, 0.1], [0.3, 0.6, 0.1], [0.4, 0.4, 0.2], [0.3, 0.3, 0.4]],
 [[0.5, 0.4, 0.1], [0.3, 0.5, 0.2], [0.3, 0.3, 0.4], [0.4, 0.5, 0.1]]]

)
y_true = np.array([[0, 0, 2, 1],
 [1, 0, 1, 2],
 [1, 2, 2, 0],
 [1, 2, 0, 0]]
)

If I run:

m = keras.metrics.MeanIoU(num_classes=3,sparse_y_pred=False)
for i in range(1):
    y_true[0][0] = i % 2
    y_true[1][0] = i % 2
    m.update_state(y_true, y_pred)
m.result()

and then:

ll = []
for i in range(1):
    y_true[0][0] = i % 2
    y_true[1][0] = i % 2
    ll.append(custom_iou_metric(y_true, y_pred))
np.mean(ll)

the result is the same.
But if I increase the range, they diverge a bit. What is the intuition behind summing confusion matrices as you do? That is not exactly the average.
In my example they diverge only slightly as the range increases, but I see big differences during training:

model.compile(
    optimizer=optimizer,
    loss=focal_loss(),
    metrics=[
        "accuracy",
        custom_iou_metric,
        keras.metrics.MeanIoU(num_classes=3, sparse_y_pred=False),
    ],
)

I am using a custom data generator, if that matters.

Thanks for the clarifications.

@edge7
Contributor Author

edge7 commented Dec 1, 2024

I have been reading around, and it looks like that approach is more of a global one. I am just quite surprised that the results can diverge so much.
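
To make the difference concrete, here is a minimal sketch with made-up confusion matrices (toy numbers, not my real data): averaging IoU batch by batch weights every batch equally, while pooling the confusion matrix weights every pixel equally, so an unbalanced batch pulls the two numbers apart.

import numpy as np

def iou_from_cm(cm):
    # per-class IoU = diag / (row_sum + col_sum - diag), averaged over classes
    diag = np.diag(cm).astype(float)
    union = cm.sum(axis=0) + cm.sum(axis=1) - diag
    return np.mean(diag / union)

# batch 1: two classes, roughly balanced, decent predictions
cm1 = np.array([[50, 10], [10, 30]])
# batch 2: class 1 is rare and predicted badly
cm2 = np.array([[90, 5], [5, 0]])

print(np.mean([iou_from_cm(cm1), iou_from_cm(cm2)]))  # mean of per-batch IoUs, ~0.55
print(iou_from_cm(cm1 + cm2))                         # IoU of the pooled matrix, ~0.66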

@mehtamansi29
Collaborator

Hi @edge7-

Thanks for reporting the issue. Could you share sample training data where you see the results diverge?

@edge7
Contributor Author

edge7 commented Dec 2, 2024

Hi,
As mentioned, the two approaches (averaging batch by batch vs. summing the confusion matrices) are different by nature, but the difference I am seeing here seems far too large.
Custom metric (i.e. averaging batch by batch):

[screenshot: custom IoU metric values during training]

Keras MeanIoU metric:

[screenshot: Keras MeanIoU values during training]

Do you need a subset of the training set? If so, that could be hard to provide, as it is proprietary data.

@edge7
Contributor Author

edge7 commented Dec 2, 2024

Just one more note:

[screenshot]

Is this normal?

@edge7
Contributor Author

edge7 commented Dec 2, 2024

Hi, I am digging in, as it looks like this is not reliable at all.
First of all, with both JAX and TensorFlow as the backend, the sum of the confusion matrices is sometimes wrong for some obscure reason (not always, but I have seen cases where it is).
Then, another problem is here:

# excerpt from the confusion-matrix code in the Keras metrics utils
indices = ops.stack([labels, predictions], axis=1)
values = ops.ones_like(predictions, dtype) if weights is None else weights
indices = ops.cast(indices, dtype="int64")
values = ops.cast(values, dtype=self.dtype)
num_classes = int(num_classes)
confusion_matrix = ops.scatter(indices, values, (num_classes, num_classes))
return confusion_matrix

the scatter is sporadically wrong as well. It gets better if one changes the dtype like this:

values = ops.cast(values, dtype='int64')
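
For reference, a tiny sketch (toy indices, nothing to do with the real data) of what ops.scatter is being used for here: values that land on the same index are summed, so each (true, pred) pixel pair contributes one count to its confusion-matrix cell, and the dtype of values is the dtype those counts accumulate in.

import numpy as np
import keras

indices = np.array([[0, 0], [0, 0], [1, 2]])  # two pixels fall in cell (0, 0), one in (1, 2)
values = np.ones(3)
print(keras.ops.scatter(indices, values, (3, 3)))  # cell (0, 0) gets 2, cell (1, 2) gets 1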

@mehtamansi29
Collaborator

Hi @edge7 -

Could you share full sample code, with the relevant error traceback, so I can dig into the issue further?

@edge7
Contributor Author

edge7 commented Dec 2, 2024

This is a way to reproduce:

import numpy as np
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import keras
from sklearn.metrics import confusion_matrix


def compare_matrices(matrix1, matrix2):

    differences = matrix1 != matrix2
    if np.any(differences):
        print("Differences found at:")
        for row, col in np.argwhere(differences):
            print(
                f"Cell ({row}, {col}): Matrix1 = {matrix1[row, col]}, Matrix2 = {matrix2[row, col]}"
            )
        return True
    return False


all_y_true = np.load("/home/edge7/Downloads/all_y_true.npy")
all_y_pred = np.load("/home/edge7/Downloads/all_y_pred.npy")

all_y_pred_arg = np.argmax(all_y_pred, axis=-1)

for i in range(590, 4000):
    total_cm = None  # Initialize total confusion matrix
    mean_iou_metric = keras.metrics.MeanIoU(num_classes=3, sparse_y_pred=False)

    # Update metric and calculate confusion matrix for each slice
    for j in range(i):

        # Update MeanIoU metric
        mean_iou_metric.update_state(all_y_true[j], all_y_pred[j])

        # Flatten data for confusion matrix calculation
        tmp_true = np.reshape(all_y_true[j], -1)
        tmp_pred = np.reshape(all_y_pred_arg[j], -1)
        tmp_confusion = confusion_matrix(tmp_true, tmp_pred, labels=np.arange(3))

        # Accumulate confusion matrix
        if total_cm is None:
            total_cm = tmp_confusion
        else:
            total_cm += tmp_confusion

        # Ensure consistency between accumulated confusion matrices
        try:
            assert np.array_equal(total_cm, mean_iou_metric.total_cm.numpy())
        except AssertionError:
            print("Results are different at index:", j)
    print(i)
    # Calculate final MeanIoU result for this range
    result1 = round(mean_iou_metric.result().numpy(), 3)
    conf_matrix_a = mean_iou_metric.total_cm.numpy()

    # Alternative calculation over the entire slice range
    mean_iou_metric.reset_state()
    mean_iou_metric.update_state(all_y_true[0:i, :, :], all_y_pred[0:i, :, :])
    conf_matrix_b = mean_iou_metric.total_cm.numpy()
    result2 = round(mean_iou_metric.result().numpy(), 3)

    # Validate confusion matrices and results
    tmp_true = np.reshape(all_y_true[0:i], -1)
    tmp_pred = np.reshape(all_y_pred_arg[0:i], -1)
    tmp_confusion = confusion_matrix(tmp_true, tmp_pred, labels=np.arange(3))

    if compare_matrices(conf_matrix_a, conf_matrix_b):
        print(f"Inconsistency found at range: {i}")
        break

    if result1 != result2:
        print(f"MeanIoU mismatch: {result1} vs {result2} at range: {i}")
        break

y_pred is here
y_true is here

In the first part, you can see the very weird sum, and in the second, the scatter problem. The scatter problem gets solved with:
values = ops.cast(values, dtype='int64')

Happy to jump on a call if needed. The files are pretty heavy, but if I subset them too much, the error no longer shows up.

@edge7
Contributor Author

edge7 commented Dec 2, 2024

Last bit from my side:

This is slightly different code

The output with values = ops.cast(values, dtype=dtype) in metrics utils (current version):

Differences found at:
Cell (0, 0): Matrix1 = 16778371, Matrix2 = 16778372.0
Results are different at index: 592
Differences found at:
Cell (0, 0): Matrix1 = 16778372.0, Matrix2 = 16777216.0
Checking with SkLearn (Matrix 1)
Differences found at:
Cell (0, 0): Matrix1 = 16778371, Matrix2 = 16777216.0
Inconsistency found at range: 593

If I change to values = ops.cast(values, dtype="int64")

I get:

Differences found at:
Cell (0, 0): Matrix1 = 16778371, Matrix2 = 16778372.0
Results are different at index: 592
Differences found at:
Cell (0, 0): Matrix1 = 16778371, Matrix2 = 16778372.0
Results are different at index: 592
Differences found at:
Cell (0, 0): Matrix1 = 16800742, Matrix2 = 16800744.0
Results are different at index: 593
Differences found at:
Cell (0, 0): Matrix1 = 16800744.0, Matrix2 = 16800742.0
Checking with SkLearn (Matrix 1)
Inconsistency found at range: 594

which means the scatter gets messed up with float values. The weird add behaviour is still there, though (and that is what matters when the metric is computed during training).
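
For what it's worth, the cell values above sit right around 2**24 = 16777216, which is where float32 stops being able to represent every integer, so my suspicion is that accumulating counts in float32 is the root cause. A minimal sketch:

import numpy as np

print(np.float32(16778371))               # -> 16778372.0: above 2**24 only every other integer is representable
print(np.float32(2**24) + np.float32(1))  # -> 16777216.0: the +1 increment is silently lost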

❯ pip list | grep keras
keras 3.7.0
keras-tuner 1.4.7

❯ pip list | grep tensorflow
tensorflow 2.18.0
tensorflow-datasets 4.9.6
tensorflow-io-gcs-filesystem 0.37.1
tensorflow-metadata 1.15.0

@edge7
Copy link
Contributor Author

edge7 commented Dec 2, 2024

I have made one more step forward.
You can use these files, which are smaller:

y_true_s
y_pred_s

This is slightly cleaner code to reproduce it.

If I run it as is, I get:

Differences found at:
Cell (0, 0): sklearn = 16778371, keras sum = 16778372.0
Results are different at index: 592 593
Differences found at:
Cell (0, 0): keras sum = 16778372.0, keras one go = 16777216.0
Checking with SkLearn
Differences found at:
Cell (0, 0): sklearn = 16778371, keras one go = 16777216.0
Inconsistency found at range: 593

from which we can see that, for some reason, the Keras sum is wrong for element (0, 0) of the confusion matrix, and also that the second part, where the confusion matrix is not accumulated (there is just one big slice), is wrong as well.
I think this depends on:
confusion_matrix = ops.scatter(indices, values, (num_classes, num_classes))

If I then modify this line to:

current_cm = confusion_matrix(
    y_true,
    y_pred,
    self.num_classes,
    weights=sample_weight,
    dtype="int64",
)

I get:

Differences found at:
Cell (0, 0): sklearn = 16778371, keras sum = 16778372.0
Results are different at index: 592 593
Differences found at:
Cell (0, 0): sklearn = 16778371, keras sum = 16778372.0
Results are different at index: 592 594
Differences found at:
Cell (0, 0): sklearn = 16800742, keras sum = 16800744.0
Results are different at index: 593 594
Differences found at:
Cell (0, 0): keras sum = 16800744.0, keras one go = 16800742.0
Checking with SkLearn
Inconsistency found at range: 594

this implies that the sum issue is still there, but at least the "one go" computation now works, as it is identical to sklearn.
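
My guess for the remaining sum issue: the accumulator itself is a float (the totals above print with a trailing .0), so even when every per-batch confusion matrix is exact, adding them into a float32 total starts losing counts once a cell passes 2**24. A rough simulation with a made-up per-batch count:

import numpy as np

exact = 0
total = np.float32(0.0)
for _ in range(600):
    batch_count = 28337            # hypothetical per-batch count for one confusion-matrix cell
    exact += batch_count
    total = total + np.float32(batch_count)
print(exact, float(total))         # the float32 total drifts away from the exact count above 2**24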

I think now, @mehtamansi29, you really have all the info I can give you.

@edge7
Contributor Author

edge7 commented Dec 3, 2024

Hi @mehtamansi29

Through this snippet you should be able to reproduce the issue without external data:

import numpy as np
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import keras


def confusion_matrix(y_true, y_pred, num_classes):
    """
    Creates a confusion matrix as a numpy array.

    Parameters:
    - y_true: array-like, true class labels.
    - y_pred: array-like, predicted class labels.
    - num_classes: int, number of classes.

    Returns:
    - conf_matrix: np.ndarray, confusion matrix of shape (num_classes, num_classes).
    """
    conf_matrix = np.zeros((num_classes, num_classes), dtype=int)
    for true, pred in zip(y_true, y_pred):
        conf_matrix[true, pred] += 1
    return conf_matrix


def compare_matrices(matrix1, matrix2, m1, m2):
    """
    Compares two matrices and prints the locations of differences, if any.

    Args:
        matrix1 (np.array): First matrix.
        matrix2 (np.array): Second matrix.

    Returns:
        bool: True if differences are found, False otherwise.
    """
    differences = matrix1 != matrix2
    if np.any(differences):
        print("Differences found at:")
        for row, col in np.argwhere(differences):
            print(
                f"Cell ({row}, {col}): {m1} = {matrix1[row, col]}, {m2} = {matrix2[row, col]}"
            )
        return True
    return False


# Set a deterministic seed
np.random.seed(14)

all_y_true = np.random.choice([0, 1, 2], size=(600, 600, 600))

# Generate random probabilities for each channel
random_probs = np.random.rand(600, 600, 600, 3)
# Normalize to ensure the last dimension sums to 1
all_y_pred = random_probs / random_probs.sum(axis=-1, keepdims=True)
# Convert predictions to class indices
all_y_pred_arg = np.argmax(all_y_pred, axis=-1)

START_AT = 0
# Iterate over different slice ranges
for i in range(450, 600):
    total_cm = None  # Initialize total confusion matrix
    mean_iou_metric = keras.metrics.MeanIoU(num_classes=3)

    # Update metric and calculate confusion matrix for each slice
    for j in range(START_AT, i):
        # Update MeanIoU metric
        mean_iou_metric.update_state(all_y_true[j], all_y_pred_arg[j])

        # Flatten data for confusion matrix calculation
        tmp_true = np.reshape(all_y_true[j], -1)
        tmp_pred = np.reshape(all_y_pred_arg[j], -1)
        tmp_confusion = confusion_matrix(tmp_true, tmp_pred, 3)

        # Accumulate confusion matrix
        if total_cm is None:
            total_cm = tmp_confusion
        else:
            total_cm += tmp_confusion

        # Ensure consistency between accumulated confusion matrices
        try:
            assert np.array_equal(total_cm, mean_iou_metric.total_cm.numpy())
        except AssertionError:
            compare_matrices(
                total_cm, mean_iou_metric.total_cm.numpy(), "manual", "keras sum"
            )
            print("Results are different at index:", j, i)
    # Calculate final MeanIoU result for this range
    result1 = round(mean_iou_metric.result().numpy(), 3)
    conf_matrix_a = mean_iou_metric.total_cm.numpy()

    mean_iou_metric_all = keras.metrics.MeanIoU(num_classes=3)
    # Alternative calculation over the entire slice range
    mean_iou_metric_all.reset_state()
    mean_iou_metric_all.update_state(all_y_true[START_AT:i], all_y_pred_arg[START_AT:i])
    conf_matrix_b = mean_iou_metric_all.total_cm.numpy()
    result2 = round(mean_iou_metric_all.result().numpy(), 3)

    # Validate confusion matrices and results
    tmp_true = np.reshape(all_y_true[START_AT:i], -1)
    tmp_pred = np.reshape(all_y_pred_arg[START_AT:i], -1)
    tmp_confusion = confusion_matrix(tmp_true, tmp_pred, 3)

    if compare_matrices(conf_matrix_a, conf_matrix_b, "keras sum", "keras one go"):
        print("Checking with Manual")
        compare_matrices(tmp_confusion, conf_matrix_b, "Manual", "keras one go")
        print(f"Inconsistency found at range: {i}")
        break

    if result1 != result2:
        print(f"MeanIoU mismatch: {result1} vs {result2} at range: {i}")
        break

This change partially fixes the problem:

current_cm = confusion_matrix(
    y_true,
    y_pred,
    self.num_classes,
    weights=sample_weight,
    dtype="int64",
)

There is still a discrepancy at the end, and the summing part (checked at the beginning of the loop) is still very strange.

@edge7
Contributor Author

edge7 commented Dec 10, 2024

Fixed by #20584.

@edge7 edge7 closed this as completed Dec 10, 2024