Skip to content

Commit

Permalink
few arg corrections and docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
sineeli committed Oct 24, 2024
1 parent f40cb9d commit d7939f6
Show file tree
Hide file tree
Showing 5 changed files with 291 additions and 160 deletions.
79 changes: 51 additions & 28 deletions keras/src/visualization/draw_bounding_boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,42 +25,65 @@ def draw_bounding_boxes(
font_scale=1.0,
data_format=None,
):
"""Utility to draw bounding boxes on the target image.
"""Draws bounding boxes on images.
Accepts a batch of images and batch of bounding boxes. The function draws
the bounding boxes onto the image, and returns a new image tensor with the
annotated images. This API is intentionally not exported, and is considered
an implementation detail.
This function draws bounding boxes on a batch of images. It supports
different bounding box formats and can optionally display class labels
and confidences.
Args:
images: a batch Tensor of images to plot bounding boxes onto.
bounding_boxes: a Tensor of batched bounding boxes to plot onto the
provided images.
bounding_box_format: The format of bounding boxes to plot onto the
images. Refer
[to the keras.io docs](TODO)
for more details on supported bounding box formats.
color: the color in which to plot the bounding boxes
class_mapping: Dictionary from class ID to class label. Defaults to
`None`.
line_thickness: Line thicknes for the box and text labels.
Defaults to `2`.
text_thickness: The thickness for the text. Defaults to `1.0`.
font_scale: Scale of font to draw in. Defaults to `1.0`.
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape `(batch, height, width, channels)`
while `"channels_first"` corresponds to inputs with shape
`(batch, channels, height, width)`. It defaults to the
`image_data_format` value found in your Keras config file at
images: A batch of images as a 4D tensor or NumPy array. Shape should be
`(batch_size, height, width, channels)`.
bounding_boxes: A dictionary containing bounding box data. Should have
the following keys:
- `boxes`: A tensor or array of shape `(batch_size, num_boxes, 4)`
containing the bounding box coordinates in the specified format.
- `labels`: A tensor or array of shape `(batch_size, num_boxes)`
containing the class labels for each bounding box.
- `confidences` (Optional): A tensor or array of shape
`(batch_size, num_boxes)` containing the confidence scores for
each bounding box.
color: A tuple or list representing the RGB color of the bounding boxes.
For example, `(255, 0, 0)` for red.
bounding_box_format: A string specifying the format of the bounding
boxes. Refer [keras-io](TODO)
class_mapping: A dictionary mapping class IDs (integers) to class labels
(strings). Used to display class labels next to the bounding boxes.
Defaults to None (no labels displayed).
line_thickness: An integer specifying the thickness of the bounding box
lines. Defaults to `2`.
text_thickness: An integer specifying the thickness of the text labels.
Defaults to `1`.
font_scale: A float specifying the scale of the font used for text
labels. Defaults to `1.0`.
data_format: A string, either `"channels_last"` or `"channels_first"`,
specifying the order of dimensions in the input images. Defaults to
the `image_data_format` value found in your Keras config file at
`~/.keras/keras.json`. If you never set it, then it will be
`"channels_last"`.
"channels_last".
Returns:
the input `images` with provided bounding boxes plotted on top of them
A NumPy array of the annotated images with the bounding boxes drawn.
The array will have the same shape as the input `images`.
Raises:
ValueError: If `images` is not a 4D tensor/array, if `bounding_boxes` is
not a dictionary, or if `bounding_boxes` does not contain `"boxes"`
and `"labels"` keys.
TypeError: If `bounding_boxes` is not a dictionary.
ImportError: If `cv2` (OpenCV) is not installed.
"""

if cv2 is None:
raise ImportError(
"The `draw_bounding_boxes` function requires the `cv2` package "
" (OpenCV). Please install it with `pip install opencv-python`."
)

class_mapping = class_mapping or {}
text_thickness = text_thickness or line_thickness
text_thickness = (
text_thickness or line_thickness
) # Default text_thickness if not provided.
data_format = data_format or backend.image_data_format()
images_shape = ops.shape(images)
if len(images_shape) != 4:
Expand Down
39 changes: 38 additions & 1 deletion keras/src/visualization/draw_segmentation_masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,50 @@ def draw_segmentation_masks(
ignore_index=-1,
data_format=None,
):
"""Draws segmentation masks on images.
The function overlays segmentation masks on the input images.
The masks are blended with the images using the specified alpha value.
Args:
images: A batch of images as a 4D tensor or NumPy array. Shape
should be (batch_size, height, width, channels).
segmentation_masks: A batch of segmentation masks as a 3D or 4D tensor
or NumPy array. Shape should be (batch_size, height, width) or
(batch_size, height, width, 1). The values represent class indices
starting from 1 up to `num_classes`. Class 0 is reserved for
the background and will be ignored if `ignore_index` is not 0.
num_classes: The number of segmentation classes. If `None`, it is
inferred from the maximum value in `segmentation_masks`.
color_mapping: A dictionary mapping class indices to RGB colors.
If `None`, a default color palette is generated. The keys should be
integers starting from 1 up to `num_classes`.
alpha: The opacity of the segmentation masks. Must be in the range
`[0, 1]`.
ignore_index: The class index to ignore. Mask pixels with this value
will not be drawn. Defaults to -1.
data_format: Image data format, either `"channels_last"` or
`"channels_first"`. Defaults to the `image_data_format` value found
in your Keras config file at `~/.keras/keras.json`. If you never
set it, then it will be `"channels_last"`.
Returns:
A NumPy array of the images with the segmentation masks overlaid.
Raises:
ValueError: If the input `images` is not a 4D tensor or NumPy array.
TypeError: If the input `segmentation_masks` is not an integer type.
"""
data_format = data_format or backend.image_data_format()
images_shape = ops.shape(images)
if len(images_shape) != 4:
raise ValueError(
"`images` must be batched 4D tensor. "
f"Received: images.shape={images_shape}"
)
if data_format == "channels_first":
images = ops.transpose(images, (0, 2, 3, 1))
segmentation_masks = ops.transpose(segmentation_masks, (0, 2, 3, 1))
images = ops.convert_to_tensor(images, dtype="float32")
segmentation_masks = ops.convert_to_tensor(segmentation_masks)

Expand Down Expand Up @@ -65,4 +102,4 @@ def draw_segmentation_masks(

def _generate_color_palette(num_classes: int):
palette = np.array([2**25 - 1, 2**15 - 1, 2**21 - 1])
return [((i * palette) % 255).tolist() for i in range(num_classes)]
return [((i * palette) % 255).tolist() for i in range(1, num_classes + 1)]
134 changes: 81 additions & 53 deletions keras/src/visualization/plot_bounding_box_gallery.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from keras.src.visualization.plot_image_gallery import plot_image_gallery

try:
from matplotlib import patches
except:
from matplotlib import patches # For legend patches
except ImportError:
patches = None


Expand All @@ -34,57 +34,84 @@ def plot_bounding_box_gallery(
rows=None,
cols=None,
data_format=None,
**kwargs
**kwargs,
):
"""
"""Plots a gallery of images with bounding boxes.
This function can display both ground truth and predicted bounding boxes on
a set of images. It supports various bounding box formats and can include
class labels and a legend.
Args:
images: a Tensor or NumPy array containing images to show in the
gallery.
value_range: Value range of the images. Common examples include
`(0, 255)` and `(0, 1)`.
bounding_box_format: The bounding_box_format the provided bounding boxes
are in.
y_true: Bounding box dictionary representing the
ground truth bounding boxes and labels. Defaults to `None`
y_pred: Bounding box dictionary representing the
ground truth bounding boxes and labels. Defaults to `None`
pred_color: Three element tuple representing the color to use for
plotting predicted bounding boxes.
true_color: three element tuple representing the color to use for
plotting true bounding boxes.
class_mapping: Class mapping from class IDs to strings. Defaults to
`None`.
ground_truth_mapping: Class mapping from class IDs to
strings, defaults to `class_mapping`. Defaults to `None`
prediction_mapping: Class mapping from class IDs to strings.
Defaults to `class_mapping`.
line_thickness: Line thickness for the box and text labels.
Defaults to 2.
text_thickness: The line thickness for the text, defaults to
`1.0`.
font_scale: Font size to draw bounding boxes in.
legend: Whether to create a legend with the specified colors for
`y_true` and `y_pred`. Defaults to False.
rows: int. Number of rows in the gallery to shows. Required if inputs
are unbatched. Defaults to `None`
cols: int. Number of columns in the gallery to show. Required if inputs
are unbatched.Defaults to `None`
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape `(batch, height, width, channels)`
while `"channels_first"` corresponds to inputs with shape
`(batch, channels, height, width)`. It defaults to the
`image_data_format` value found in your Keras config file at
`~/.keras/keras.json`. If you never set it, then it will be
`"channels_last"`.
kwargs: keyword arguments to propagate to
`keras.visualization.plot_image_gallery()`.
images: A 4D tensor or NumPy array of images. Shape should be
`(batch_size, height, width, channels)`.
value_range: A tuple specifying the value range of the images
(e.g., `(0, 255)` or `(0, 1)`).
bounding_box_format: The format of the bounding boxes.
Refer [keras-io](TODO)
y_true: A dictionary containing the ground truth bounding boxes and
labels. Should have the same structure as the `bounding_boxes`
argument in `keras.visualization.draw_bounding_boxes`.
Defaults to `None`.
y_pred: A dictionary containing the predicted bounding boxes and labels.
Should have the same structure as `y_true`. Defaults to `None`.
true_color: A tuple of three integers representing the RGB color for the
ground truth bounding boxes. Defaults to `(0, 188, 212)`.
pred_color: A tuple of three integers representing the RGB color for the
predicted bounding boxes. Defaults to `(255, 235, 59)`.
line_thickness: The thickness of the bounding box lines. Defaults to 2.
font_scale: The scale of the font used for labels. Defaults to 1.0.
text_thickness: The thickness of the bounding box text. Defaults to
`line_thickness`.
class_mapping: A dictionary mapping class IDs to class names. Used f
or both ground truth and predicted boxes if `ground_truth_mapping`
and `prediction_mapping` are not provided. Defaults to `None`.
ground_truth_mapping: A dictionary mapping class IDs to class names
specifically for ground truth boxes. Overrides `class_mapping`
for ground truth. Defaults to `None`.
prediction_mapping: A dictionary mapping class IDs to class names
specifically for predicted boxes. Overrides `class_mapping` for
predictions. Defaults to `None`.
legend: A boolean indicating whether to show a legend.
Defaults to `False`.
legend_handles: A list of matplotlib `Patch` objects to use for the
legend. If this is provided, the `legend` argument will be ignored.
Defaults to `None`.
rows: The number of rows in the image gallery. Required if the images
are not batched. Defaults to `None`.
cols: The number of columns in the image gallery. Required if the images
are not batched. Defaults to `None`.
data_format: The image data format `"channels_last"` or
`"channels_first"`. Defaults to the Keras backend data format.
kwargs: Additional keyword arguments to be passed to
`keras.visualization.plot_image_gallery`.
Returns:
The output of `keras.visualization.plot_image_gallery`.
Raises:
ValueError: If `images` is not a 4D tensor/array or if both `legend` a
nd `legend_handles` are specified.
ImportError: if matplotlib is not installed
"""
if patches is None:
raise ImportError(
"The `plot_bounding_box_gallery` function requires the "
" `matplotlib` package. Please install it with "
" `pip install matplotlib`."
)

prediction_mapping = prediction_mapping or class_mapping
ground_truth_mapping = ground_truth_mapping or class_mapping
data_format = data_format or backend.image_data_format()

images_shape = ops.shape(images)
if len(images_shape) != 4:
raise ValueError(
"`images` must be batched 4D tensor. "
f"Received: images.shape={images_shape}"
)
if data_format == "channels_first": # Ensure correct data format
images = ops.transpose(images, (0, 2, 3, 1))
plotted_images = ops.convert_to_numpy(images)

draw_fn = functools.partial(
Expand All @@ -93,7 +120,6 @@ def plot_bounding_box_gallery(
line_thickness=line_thickness,
text_thickness=text_thickness,
font_scale=font_scale,
data_format=data_format,
)

if y_true is not None:
Expand All @@ -106,22 +132,25 @@ def plot_bounding_box_gallery(

if y_pred is not None:
plotted_images = draw_fn(
plotted_images, y_pred, pred_color, class_mapping=prediction_mapping
plotted_images,
y_pred,
pred_color,
class_mapping=prediction_mapping,
)

if legend:
if legend_handles:
raise ValueError(
"Only pass `legend` OR `legend_handles` to "
"`luketils.visualization.plot_bounding_box_gallery()`."
"`keras.visualization.plot_bounding_box_gallery()`."
)
legend_handles = [
patches.Patch(
color=np.array(true_color) / 255.0,
color=np.array(true_color) / 255.0, # Normalize color
label="Ground Truth",
),
patches.Patch(
color=np.array(pred_color) / 255.0,
color=np.array(pred_color) / 255.0, # Normalize color
label="Prediction",
),
]
Expand All @@ -132,6 +161,5 @@ def plot_bounding_box_gallery(
legend_handles=legend_handles,
rows=rows,
cols=cols,
data_format=data_format,
**kwargs
**kwargs,
)
Loading

0 comments on commit d7939f6

Please sign in to comment.