Skip to content

Commit

Permalink
Handle empty change indices in SAM's mask to rle conversion (huggingf…
Browse files Browse the repository at this point in the history
…ace#35665)

* Handle empty change indices in RLE conversion for masks

* [test] Add unit tests for RLE encoding of masks in SamProcessor

* [test] Update RLE conversion tests to use TensorFlow implementation

* [test] Fix formatting in SamProcessorTest according to check_code_quality action

* [test] Fix formatting in SamProcessorTest according to check_code_quality

* [test] Refactored rle test cases into one test and used tf tensors in tf test cases

* [test] Fix: removed self parameter from refactored methods

* [test] Removed nested methods in run-length encoding tests for PyTorch and TensorFlow

* [test] Added description to individual to run-length encoding tests for PyTorch and TensorFlow.
  • Loading branch information
MSt-10 authored and bursteratom committed Feb 5, 2025
1 parent 18d1de9 commit 7a9c652
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 0 deletions.
16 changes: 16 additions & 0 deletions src/transformers/models/sam/image_processing_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -1373,6 +1373,14 @@ def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
out = []
for i in range(batch_size):
cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
if len(cur_idxs) == 0:
# No changes => either all 0 or all 1
# If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
if input_mask[i, 0] == 0:
out.append({"size": [height, width], "counts": [height * width]})
else:
out.append({"size": [height, width], "counts": [0, height * width]})
continue
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
counts = [] if input_mask[i, 0] == 0 else [0]
counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]]
Expand All @@ -1396,6 +1404,14 @@ def _mask_to_rle_tf(input_mask: "tf.Tensor"):
out = []
for i in range(batch_size):
cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
if len(cur_idxs) == 0:
# No changes => either all 0 or all 1
# If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
if input_mask[i, 0] == 0:
out.append({"size": [height, width], "counts": [height * width]})
else:
out.append({"size": [height, width], "counts": [0, height * width]})
continue
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
counts = [] if input_mask[i, 0] == 0 else [0]
counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]]
Expand Down
76 changes: 76 additions & 0 deletions tests/models/sam/test_processor_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,13 @@
if is_torch_available():
import torch

from transformers.models.sam.image_processing_sam import _mask_to_rle_pytorch

if is_tf_available():
import tensorflow as tf

from transformers.models.sam.image_processing_sam import _mask_to_rle_tf


@require_vision
@require_torchvision
Expand Down Expand Up @@ -161,6 +165,42 @@ def test_post_process_masks(self):
with self.assertRaises(ValueError):
masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size))

def test_rle_encoding(self):
"""
Test the run-length encoding function.
"""
# Test that a mask of all zeros returns a single run [height * width].
input_mask = torch.zeros((1, 2, 2), dtype=torch.long) # shape: 1 x 2 x 2
rle = _mask_to_rle_pytorch(input_mask)

self.assertEqual(len(rle), 1)
self.assertEqual(rle[0]["size"], [2, 2])
# For a 2x2 all-zero mask, we expect a single run of length 4:
self.assertEqual(rle[0]["counts"], [4])

# Test that a mask of all ones returns [0, height * width].
input_mask = torch.ones((1, 2, 2), dtype=torch.long) # shape: 1 x 2 x 2
rle = _mask_to_rle_pytorch(input_mask)

self.assertEqual(len(rle), 1)
self.assertEqual(rle[0]["size"], [2, 2])
# For a 2x2 all-one mask, we expect two runs: [0, 4].
self.assertEqual(rle[0]["counts"], [0, 4])

# Test a mask with mixed 0s and 1s to ensure the run-length encoding is correct.
# Example mask:
# Row 0: [0, 1]
# Row 1: [1, 1]
# This is shape (1, 2, 2).
# Flattened in Fortran order -> [0, 1, 1, 1].
# The RLE for [0,1,1,1] is [1, 3].
input_mask = torch.tensor([[[0, 1], [1, 1]]], dtype=torch.long)
rle = _mask_to_rle_pytorch(input_mask)

self.assertEqual(len(rle), 1)
self.assertEqual(rle[0]["size"], [2, 2])
self.assertEqual(rle[0]["counts"], [1, 3]) # 1 zero, followed by 3 ones


@require_vision
@require_tf
Expand Down Expand Up @@ -244,6 +284,42 @@ def test_post_process_masks(self):
dummy_masks, np.array(original_sizes), np.array(reshaped_input_size), return_tensors="tf"
)

def test_rle_encoding(self):
"""
Test the run-length encoding function.
"""
# Test that a mask of all zeros returns a single run [height * width].
input_mask = tf.zeros((1, 2, 2), dtype=tf.int64) # shape: 1 x 2 x 2
rle = _mask_to_rle_tf(input_mask)

self.assertEqual(len(rle), 1)
self.assertEqual(rle[0]["size"], [2, 2])
# For a 2x2 all-zero mask, we expect a single run of length 4:
self.assertEqual(rle[0]["counts"], [4])

# Test that a mask of all ones returns [0, height * width].
input_mask = tf.ones((1, 2, 2), dtype=tf.int64) # shape: 1 x 2 x 2
rle = _mask_to_rle_tf(input_mask)

self.assertEqual(len(rle), 1)
self.assertEqual(rle[0]["size"], [2, 2])
# For a 2x2 all-one mask, we expect two runs: [0, 4].
self.assertEqual(rle[0]["counts"], [0, 4])

# Test a mask with mixed 0s and 1s to ensure the run-length encoding is correct.
# Example mask:
# Row 0: [0, 1]
# Row 1: [1, 1]
# This is shape (1, 2, 2).
# Flattened in Fortran order -> [0, 1, 1, 1].
# The RLE for [0,1,1,1] is [1, 3].
input_mask = tf.tensor([[[0, 1], [1, 1]]], dtype=tf.int64)
rle = _mask_to_rle_tf(input_mask)

self.assertEqual(len(rle), 1)
self.assertEqual(rle[0]["size"], [2, 2])
self.assertEqual(rle[0]["counts"], [1, 3]) # 1 zero, followed by 3 ones


@require_vision
@require_torchvision
Expand Down

0 comments on commit 7a9c652

Please sign in to comment.