Skip to content

Commit

Permalink
add segmentation mask draw and plot functions
Browse files Browse the repository at this point in the history
  • Loading branch information
sineeli committed Oct 24, 2024
1 parent 13ae91f commit f40cb9d
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 2 deletions.
6 changes: 6 additions & 0 deletions keras/api/_tf_keras/keras/visualization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
"""

from keras.src.visualization.draw_bounding_boxes import draw_bounding_boxes
from keras.src.visualization.draw_segmentation_masks import (
draw_segmentation_masks,
)
from keras.src.visualization.plot_bounding_box_gallery import (
plot_bounding_box_gallery,
)
from keras.src.visualization.plot_image_gallery import plot_image_gallery
from keras.src.visualization.plot_segmentation_mask_gallery import (
plot_segmentation_mask_gallery,
)
6 changes: 6 additions & 0 deletions keras/api/visualization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
"""

from keras.src.visualization.draw_bounding_boxes import draw_bounding_boxes
from keras.src.visualization.draw_segmentation_masks import (
draw_segmentation_masks,
)
from keras.src.visualization.plot_bounding_box_gallery import (
plot_bounding_box_gallery,
)
from keras.src.visualization.plot_image_gallery import plot_image_gallery
from keras.src.visualization.plot_segmentation_mask_gallery import (
plot_segmentation_mask_gallery,
)
8 changes: 8 additions & 0 deletions keras/src/visualization/draw_bounding_boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ def draw_bounding_boxes(
Defaults to `2`.
text_thickness: The thickness for the text. Defaults to `1.0`.
font_scale: Scale of font to draw in. Defaults to `1.0`.
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape `(batch, height, width, channels)`
while `"channels_first"` corresponds to inputs with shape
`(batch, channels, height, width)`. It defaults to the
`image_data_format` value found in your Keras config file at
`~/.keras/keras.json`. If you never set it, then it will be
`"channels_last"`.
Returns:
the input `images` with provided bounding boxes plotted on top of them
Expand Down
68 changes: 68 additions & 0 deletions keras/src/visualization/draw_segmentation_masks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import numpy as np

from keras.src import backend
from keras.src import ops
from keras.src.api_export import keras_export


@keras_export("keras.visualization.draw_segmentation_masks")
def draw_segmentation_masks(
images,
segmentation_masks,
num_classes=None,
color_mapping=None,
alpha=0.8,
ignore_index=-1,
data_format=None,
):
data_format = data_format or backend.image_data_format()
images_shape = ops.shape(images)
if len(images_shape) != 4:
raise ValueError(
"`images` must be batched 4D tensor. "
f"Received: images.shape={images_shape}"
)
images = ops.convert_to_tensor(images, dtype="float32")
segmentation_masks = ops.convert_to_tensor(segmentation_masks)

if not backend.is_int_dtype(segmentation_masks.dtype):
dtype = backend.standardize_dtype(segmentation_masks.dtype)
raise TypeError(
"`segmentation_masks` must be in integer dtype. "
f"Received: segmentation_masks.dtype={dtype}"
)

# Infer num_classes
if num_classes is None:
num_classes = int(ops.convert_to_numpy(ops.max(segmentation_masks)))
if color_mapping is None:
colors = _generate_color_palette(num_classes)
else:
colors = [color_mapping[i] for i in range(num_classes)]
valid_masks = ops.not_equal(segmentation_masks, ignore_index)
valid_masks = ops.squeeze(valid_masks, axis=-1)
segmentation_masks = ops.one_hot(segmentation_masks, num_classes)
segmentation_masks = segmentation_masks[..., 0, :]
segmentation_masks = ops.convert_to_numpy(segmentation_masks)

# Replace class with color
masks = segmentation_masks
masks = np.transpose(masks, axes=(3, 0, 1, 2)).astype("bool")
images_to_draw = ops.convert_to_numpy(images).copy()
for mask, color in zip(masks, colors):
color = np.array(color, dtype=images_to_draw.dtype)
images_to_draw[mask, ...] = color[None, :]
images_to_draw = ops.convert_to_tensor(images_to_draw)
images_to_draw = ops.cast(images_to_draw, dtype="float32")

# Apply blending
outputs = images * (1 - alpha) + images_to_draw * alpha
outputs = ops.where(valid_masks[..., None], outputs, images)
outputs = ops.cast(outputs, dtype="uint8")
outputs = ops.convert_to_numpy(outputs)
return outputs


def _generate_color_palette(num_classes: int):
palette = np.array([2**25 - 1, 2**15 - 1, 2**21 - 1])
return [((i * palette) % 255).tolist() for i in range(num_classes)]
21 changes: 19 additions & 2 deletions keras/src/visualization/plot_bounding_box_gallery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np

from keras.src import backend
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.visualization.draw_bounding_boxes import draw_bounding_boxes
Expand Down Expand Up @@ -30,8 +31,9 @@ def plot_bounding_box_gallery(
prediction_mapping=None,
legend=False,
legend_handles=None,
rows=3,
cols=3,
rows=None,
cols=None,
data_format=None,
**kwargs
):
"""
Expand Down Expand Up @@ -63,12 +65,25 @@ def plot_bounding_box_gallery(
font_scale: Font size to draw bounding boxes in.
legend: Whether to create a legend with the specified colors for
`y_true` and `y_pred`. Defaults to False.
rows: int. Number of rows in the gallery to shows. Required if inputs
are unbatched. Defaults to `None`
cols: int. Number of columns in the gallery to show. Required if inputs
are unbatched.Defaults to `None`
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape `(batch, height, width, channels)`
while `"channels_first"` corresponds to inputs with shape
`(batch, channels, height, width)`. It defaults to the
`image_data_format` value found in your Keras config file at
`~/.keras/keras.json`. If you never set it, then it will be
`"channels_last"`.
kwargs: keyword arguments to propagate to
`keras.visualization.plot_image_gallery()`.
"""

prediction_mapping = prediction_mapping or class_mapping
ground_truth_mapping = ground_truth_mapping or class_mapping
data_format = data_format or backend.image_data_format()

plotted_images = ops.convert_to_numpy(images)

Expand All @@ -78,6 +93,7 @@ def plot_bounding_box_gallery(
line_thickness=line_thickness,
text_thickness=text_thickness,
font_scale=font_scale,
data_format=data_format,
)

if y_true is not None:
Expand Down Expand Up @@ -116,5 +132,6 @@ def plot_bounding_box_gallery(
legend_handles=legend_handles,
rows=rows,
cols=cols,
data_format=data_format,
**kwargs
)
13 changes: 13 additions & 0 deletions keras/src/visualization/plot_image_gallery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np

from keras.src import backend
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
Expand Down Expand Up @@ -42,6 +43,7 @@ def plot_image_gallery(
transparent=True,
dpi=60,
legend_handles=None,
data_format=None,
):
"""Displays a gallery of images.
Expand All @@ -63,6 +65,14 @@ def plot_image_gallery(
legend_handles: (Optional) matplotlib.patches List of legend handles.
I.e. passing: `[patches.Patch(color='red', label='mylabel')]` will
produce a legend with a single red patch and the label 'mylabel'.
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape `(batch, height, width, channels)`
while `"channels_first"` corresponds to inputs with shape
`(batch, channels, height, width)`. It defaults to the
`image_data_format` value found in your Keras config file at
`~/.keras/keras.json`. If you never set it, then it will be
`"channels_last"`.
"""

if path is not None and show:
Expand All @@ -72,6 +82,7 @@ def plot_image_gallery(
)
# set show to True by default if path is None
show = True if path is None else False
data_format = data_format or backend.image_data_format()

batch_size = (
ops.shape(images)[0] if len(ops.shape(images)) == 4 else 1
Expand Down Expand Up @@ -107,6 +118,8 @@ def plot_image_gallery(
)

images = ops.convert_to_numpy(images)
if data_format == "channels_first":
images = images.transpose(0, 3, 1, 2)

for row in range(rows):
for col in range(cols):
Expand Down
84 changes: 84 additions & 0 deletions keras/src/visualization/plot_segmentation_mask_gallery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import functools

import numpy as np

from keras.src import backend
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.visualization.draw_segmentation_masks import (
draw_segmentation_masks,
)
from keras.src.visualization.plot_image_gallery import plot_image_gallery


@keras_export("keras.visualization.plot_segmentation_mask_gallery")
def plot_segmentation_mask_gallery(
images,
value_range,
num_classes,
y_true=None,
y_pred=None,
rows=None,
cols=None,
color_mapping=None,
data_format=None,
**kwargs
):
"""Plots a gallery of images with corresponding segmentation masks.
Args:
images: a Tensor or NumPy array containing images to show in the
gallery. The images should be batched and of shape (B, H, W, C).
value_range: value range of the images. Common examples include
`(0, 255)` and `(0, 1)`.
num_classes: number of segmentation classes.
y_true: A Tensor or NumPy array representing the ground truth
segmentation masks. The ground truth segmentation maps should be
batched.
y_pred: A Tensor or NumPy array representing the predicted
segmentation masks. The predicted segmentation masks should be
batched.
rows: int. Number of rows in the gallery to shows. Required if inputs
are unbatched. Defaults to `None`
cols: int. Number of columns in the gallery to show. Required if inputs
are unbatched.Defaults to `None`
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape `(batch, height, width, channels)`
while `"channels_first"` corresponds to inputs with shape
`(batch, channels, height, width)`. It defaults to the
`image_data_format` value found in your Keras config file at
`~/.keras/keras.json`. If you never set it, then it will be
`"channels_last"`.
kwargs: keyword arguments to propagate to
`keras.visualization.plot_image_gallery()`.
"""
data_format = data_format or backend.image_data_format()
plotted_images = ops.convert_to_numpy(images)
masks_to_contatenate = [plotted_images]

draw_fn = functools.partial(
draw_segmentation_masks,
num_classes=num_classes,
color_mapping=color_mapping,
data_format=data_format,
)

if y_true is not None:
plotted_y_true = draw_fn(plotted_images, y_true)
masks_to_contatenate.append(plotted_y_true)

if y_pred is not None:
plotted_y_pred = draw_fn(plotted_images, y_pred)
masks_to_contatenate.append(plotted_y_pred)

# Concatenate the images and the masks together.
plotted_images = np.concatenate(masks_to_contatenate, axis=2)

return plot_image_gallery(
plotted_images,
value_range=value_range,
rows=rows,
cols=cols,
data_format=data_format,
)

0 comments on commit f40cb9d

Please sign in to comment.