diff --git a/src/otx/algo/instance_segmentation/heads/__init__.py b/src/otx/algo/instance_segmentation/heads/__init__.py
index 91cafe6b4aa..17066a70e1f 100644
--- a/src/otx/algo/instance_segmentation/heads/__init__.py
+++ b/src/otx/algo/instance_segmentation/heads/__init__.py
@@ -4,5 +4,6 @@
 """Custom head architecture for OTX instance segmentation models."""
 
 from .custom_roi_head import CustomConvFCBBoxHead, CustomRoIHead
+from .custom_rtmdet_ins_head import CustomRTMDetInsSepBNHead
 
-__all__ = ["CustomRoIHead", "CustomConvFCBBoxHead"]
+__all__ = ["CustomRoIHead", "CustomConvFCBBoxHead", "CustomRTMDetInsSepBNHead"]
diff --git a/src/otx/algo/instance_segmentation/heads/custom_rtmdet_ins_head.py b/src/otx/algo/instance_segmentation/heads/custom_rtmdet_ins_head.py
new file mode 100644
index 00000000000..948c6a4de3c
--- /dev/null
+++ b/src/otx/algo/instance_segmentation/heads/custom_rtmdet_ins_head.py
@@ -0,0 +1,407 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+"""Custom RTMDetInsSepBNHead for OTX RTMDet-InstSeg instance segmentation models."""
+
+from __future__ import annotations
+
+import math
+from typing import TYPE_CHECKING
+
+import torch
+import torch.nn.functional as F  # noqa: N812
+from mmcv.ops import RoIAlign, batched_nms
+from mmdeploy.codebase.mmdet import get_post_processing_params
+from mmdeploy.codebase.mmdet.models.dense_heads.rtmdet_ins_head import _parse_dynamic_params
+from mmdeploy.core import FUNCTION_REWRITER
+from mmdeploy.mmcv.ops.nms import multiclass_nms
+from mmdet.models.dense_heads.rtmdet_ins_head import RTMDetInsSepBNHead
+from mmdet.registry import MODELS
+from mmdet.structures.bbox import get_box_tensor, get_box_wh, scale_boxes
+from mmengine.config import ConfigDict
+
+if TYPE_CHECKING:
+    from mmengine.structures import InstanceData
+
+
+@MODELS.register_module()
+class CustomRTMDetInsSepBNHead(RTMDetInsSepBNHead):
+    """Custom RTMDet instance segmentation head.
+
+    Note: Compared to the original RTMDetInsSepBNHead, this class overrides
+    _bbox_mask_post_process to post-process masks in smaller chunks that are handled
+    individually. This mitigates the risk of running out of memory, particularly when
+    handling a large number of masks.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.roi_align = RoIAlign(output_size=(28, 28))
+
+    def mask_postprocess(
+        self,
+        mask_logits: torch.Tensor,
+        img_h: int,
+        img_w: int,
+        gpu_mem_limit: float = 1.0,
+        threshold: float = 0.5,
+    ) -> torch.Tensor:
+        """Postprocess mask logits to binary masks.
+
+        Args:
+            mask_logits (torch.Tensor): Mask logits with shape (B, N, H, W).
+            img_h (int): Image height to resize masks to.
+            img_w (int): Image width to resize masks to.
+            gpu_mem_limit (float, optional): GPU memory limit in GB. Defaults to 1.0.
+            threshold (float, optional): Threshold for binary masks. Defaults to 0.5.
+
+        Returns:
+            torch.Tensor: Binary masks with shape (B, N, img_h, img_w).
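+
+        Example (illustrative sketch; assumes ``head`` is an already constructed
+        CustomRTMDetInsSepBNHead and keeps the default 1.0 GB chunk budget):
+            >>> logits = torch.randn(1, 200, 160, 160)  # (B, N, H, W) mask logits
+            >>> binary = head.mask_postprocess(logits, img_h=640, img_w=640)
+            >>> binary.shape, binary.dtype
+            (torch.Size([1, 200, 640, 640]), torch.bool)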
+ """ + masks = torch.zeros( + size=(mask_logits.shape[:2] + (img_h, img_w)), + dtype=torch.bool, + device=mask_logits.device, + ) + + total_bytes = mask_logits.element_size() * masks.nelement() + num_chunks = int(math.ceil(total_bytes / (gpu_mem_limit * 1024) ** 3)) + n = mask_logits.shape[1] + chunks = torch.chunk( + torch.arange(n, device=mask_logits.device), + num_chunks, + ) + mask_logits = mask_logits.sigmoid() + for inds in chunks: + masks[:, inds] = ( + F.interpolate( + mask_logits[:, inds], + size=[ + img_w, + img_h, + ], + mode="bilinear", + align_corners=False, + ) + >= threshold + ).to(dtype=torch.bool) + return masks + + def _bbox_mask_post_process( + self, + results: InstanceData, + mask_feat: torch.Tensor, + cfg: ConfigDict | dict, + rescale: bool = False, + with_nms: bool = True, + img_meta: dict | None = None, + ) -> InstanceData: + """Bbox and mask post-processing method. + + The boxes would be rescaled to the original image scale and do + the nms operation. Usually `with_nms` is False is used for aug test. + + Args: + results (:obj:`InstaceData`): Detection instance results, + each item has shape (num_bboxes, ). + mask_feat (Tensor): Mask prototype features of a single image + cfg (ConfigDict): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Default to False. + with_nms (bool): If True, do nms before return boxes. + Default to True. + img_meta (dict, optional): Image meta info. Defaults to None. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, h, w). 
+ """ + if img_meta is None: + img_meta = {} + img_meta["scale_factor"] = [1.0, 1.0] + + if cfg is None: + cfg = ConfigDict( + nms_pre=300, + mask_thr_binary=0.5, + max_per_img=100, + score_thr=0.05, + nms=ConfigDict(type="nms", iou_threshold=0.6), + min_bbox_size=0, + ) + + stride = self.prior_generator.strides[0][0] + if rescale: + scale_factor = [1 / s for s in img_meta["scale_factor"]] + results.bboxes = scale_boxes(results.bboxes, scale_factor) + + if hasattr(results, "score_factors"): + score_factors = results.pop("score_factors") + results.scores = results.scores * score_factors + + # filter small size bboxes + if cfg.min_bbox_size >= 0: + w, h = get_box_wh(results.bboxes) + valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size) + if not valid_mask.all(): + results = results[valid_mask] + + if not with_nms: + msg = "with_nms must be True for RTMDet-Ins" + raise RuntimeError(msg) + + if results.bboxes.numel() > 0: + bboxes = get_box_tensor(results.bboxes) + # NOTE: mmcv.batched_nms Ops does not support half precision bboxes + if bboxes.dtype != torch.float32: + bboxes = bboxes.float() + det_bboxes, keep_idxs = batched_nms(bboxes, results.scores, results.labels, cfg.nms) + results = results[keep_idxs] + # some nms would reweight the score, such as softnms + results.scores = det_bboxes[:, -1] + results = results[: cfg.max_per_img] + + # process masks + mask_logits = self._mask_predict_by_feat_single(mask_feat, results.kernels, results.priors) + mask_logits = F.interpolate(mask_logits.unsqueeze(0), scale_factor=stride, mode="bilinear") + + if rescale: + ori_h, ori_w = img_meta["ori_shape"][:2] + masks = self.mask_postprocess( + mask_logits, + math.ceil(mask_logits.shape[-1] * scale_factor[1]), + math.ceil(mask_logits.shape[-2] * scale_factor[0]), + threshold=cfg.mask_thr_binary, + )[..., :ori_h, :ori_w] + masks = masks.squeeze(0) + else: + masks = mask_logits.sigmoid().squeeze(0) + masks = masks > cfg.mask_thr_binary + results.masks = masks + else: + h, w = img_meta["ori_shape"][:2] if rescale else img_meta["img_shape"][:2] + results.masks = torch.zeros( + size=(results.bboxes.shape[0], h, w), + dtype=torch.bool, + device=results.bboxes.device, + ) + + return results + + +def _custom_mask_predict_by_feat_single( + self: CustomRTMDetInsSepBNHead, + mask_feat: torch.Tensor, + kernels: torch.Tensor, + priors: torch.Tensor, +) -> torch.Tensor: + """Decode mask with dynamic conv. + + Note: Prior Generator has cuda device set as default. + However, this would cause some problems on CPU only devices. 
+ """ + num_inst = priors.shape[1] + batch_size = priors.shape[0] + hw = mask_feat.size()[-2:] + # NOTE: had to force to set device in prior generator + coord = self.prior_generator.single_level_grid_priors(hw, level_idx=0, device=mask_feat.device).to(mask_feat.device) + coord = coord.unsqueeze(0).unsqueeze(0).repeat(batch_size, 1, 1, 1) + priors = priors.unsqueeze(2) + points = priors[..., :2] + relative_coord = (points - coord).permute(0, 1, 3, 2) / (priors[..., 2:3] * 8) + relative_coord = relative_coord.reshape(batch_size, num_inst, 2, hw[0], hw[1]) + + mask_feat = torch.cat([relative_coord, mask_feat.unsqueeze(1).repeat(1, num_inst, 1, 1, 1)], dim=2) + weights, biases = _parse_dynamic_params(self, kernels) + + n_layers = len(weights) + x = mask_feat.flatten(0, 1).flatten(2) + for i, (weight, bias) in enumerate(zip(weights, biases)): + # replace dynamic conv with bmm + weight = weight.flatten(0, 1) # noqa: PLW2901 + bias = bias.flatten(0, 1).unsqueeze(2) # noqa: PLW2901 + x = torch.bmm(weight, x) + x = x + bias + if i < n_layers - 1: + x = x.clamp_(min=0) + return x.reshape(batch_size, num_inst, hw[0], hw[1]) + + +def _custom_nms_with_mask_static( + self: CustomRTMDetInsSepBNHead, + priors: torch.Tensor, + bboxes: torch.Tensor, + scores: torch.Tensor, + kernels: torch.Tensor, + mask_feats: torch.Tensor, + max_output_boxes_per_class: int = 1000, + iou_threshold: float = 0.5, + score_threshold: float = 0.05, + pre_top_k: int = -1, + keep_top_k: int = -1, + mask_thr_binary: float = 0.5, # noqa: ARG001 +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Wrapper for `multiclass_nms` with ONNXRuntime. + + Note: + Compared with the original _nms_with_mask_static, this function + crops masks using RoIAlign and returns the cropped masks. + + Args: + self: The instance of `RTMDetInsHead`. + priors (Tensor): The prior boxes of shape [num_boxes, 4]. + boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4]. + scores (Tensor): The detection scores of shape + [N, num_boxes, num_classes]. + kernels (Tensor): The dynamic conv kernels. + mask_feats (Tensor): The mask feature. + max_output_boxes_per_class (int): Maximum number of output + boxes per class of nms. Defaults to 1000. + iou_threshold (float): IOU threshold of nms. Defaults to 0.5. + score_threshold (float): score threshold of nms. + Defaults to 0.05. + pre_top_k (int): Number of top K boxes to keep before nms. + Defaults to -1. + keep_top_k (int): Number of top K boxes to keep after nms. + Defaults to -1. + mask_thr_binary (float): Binarization threshold for masks. + + Returns: + tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5] + and `labels` of shape [N, num_det]. 
+ """ + dets, labels, inds = multiclass_nms( + bboxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + pre_top_k=pre_top_k, + keep_top_k=keep_top_k, + output_index=True, + ) + + batch_size = bboxes.shape[0] + batch_inds = torch.arange(batch_size, device=bboxes.device).view(-1, 1) + kernels = kernels[batch_inds, inds, :] + priors = priors.unsqueeze(0).repeat(batch_size, 1, 1) + priors = priors[batch_inds, inds, :] + mask_logits = _custom_mask_predict_by_feat_single(self, mask_feats, kernels, priors) + stride = self.prior_generator.strides[0][0] + mask_logits = F.interpolate(mask_logits, scale_factor=stride, mode="bilinear") + masks = mask_logits.sigmoid() + + batch_index = ( + torch.arange(dets.size(0), device=dets.device).float().view(-1, 1, 1).expand(dets.size(0), dets.size(1), 1) + ) + rois = torch.cat([batch_index, dets[..., :4]], dim=-1) + cropped_masks = self.roi_align(masks, rois[0]) + cropped_masks = cropped_masks[torch.arange(cropped_masks.size(0)), torch.arange(cropped_masks.size(0))] + cropped_masks = cropped_masks.unsqueeze(0) + return dets, labels, cropped_masks + + +@FUNCTION_REWRITER.register_rewriter( + func_name="otx.algo.instance_segmentation.heads.custom_rtmdet_ins_head.CustomRTMDetInsSepBNHead.predict_by_feat", +) +def rtmdet_ins_head__predict_by_feat( + self: CustomRTMDetInsSepBNHead, + cls_scores: list[torch.Tensor], + bbox_preds: list[torch.Tensor], + kernel_preds: list[torch.Tensor], + mask_feat: torch.Tensor, + score_factors: list[torch.Tensor] | None = None, # noqa: ARG001 + batch_img_metas: list[dict] | None = None, # noqa: ARG001 + cfg: ConfigDict | None = None, + rescale: bool = False, # noqa: ARG001 + with_nms: bool = True, # noqa: ARG001 +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Rewrite `predict_by_feat` of `RTMDet-Ins` for default backend. + + Rewrite this function to deploy model, transform network output for a + batch into bbox predictions. + + Args: + ctx: Context that contains original meta information. + cls_scores (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * 4, H, W). + batch_img_metas (list[dict], Optional): Batch image meta info. + Defaults to None. + cfg (ConfigDict, optional): Test / postprocessing + configuration, if None, test_cfg would be used. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + + Returns: + tuple[Tensor, Tensor]: The first item is an (N, num_box, 5) tensor, + where 5 represent (tl_x, tl_y, br_x, br_y, score), N is batch + size and the score between 0 and 1. The shape of the second + tensor in the tuple is (N, num_box), and each element + represents the class label of the corresponding box. + """ + if len(cls_scores) != len(bbox_preds): + msg = "The length of cls_scores and bbox_preds should be the same." 
+        raise ValueError(msg)
+    device = cls_scores[0].device
+    cfg = self.test_cfg if cfg is None else cfg
+    batch_size = bbox_preds[0].shape[0]
+    featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
+    mlvl_priors = self.prior_generator.grid_priors(featmap_sizes, device=device, with_stride=True)
+
+    flatten_cls_scores = [
+        cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1, self.cls_out_channels) for cls_score in cls_scores
+    ]
+    flatten_bbox_preds = [bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4) for bbox_pred in bbox_preds]
+    flatten_kernel_preds = [
+        kernel_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, self.num_gen_params) for kernel_pred in kernel_preds
+    ]
+    flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
+    _flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
+    flatten_kernel_preds = torch.cat(flatten_kernel_preds, dim=1)
+    priors = torch.cat(mlvl_priors)
+    # decode (left, top, right, bottom) distances from the prior centers into xyxy boxes
+    tl_x = priors[..., 0] - _flatten_bbox_preds[..., 0]
+    tl_y = priors[..., 1] - _flatten_bbox_preds[..., 1]
+    br_x = priors[..., 0] + _flatten_bbox_preds[..., 2]
+    br_y = priors[..., 1] + _flatten_bbox_preds[..., 3]
+    bboxes = torch.stack([tl_x, tl_y, br_x, br_y], -1)
+    scores = flatten_cls_scores
+
+    ctx = FUNCTION_REWRITER.get_context()
+    deploy_cfg = ctx.cfg
+    post_params = get_post_processing_params(deploy_cfg)
+    max_output_boxes_per_class = post_params.max_output_boxes_per_class
+    iou_threshold = cfg.nms.get("iou_threshold", post_params.iou_threshold)
+    score_threshold = cfg.get("score_thr", post_params.score_threshold)
+    pre_top_k = post_params.pre_top_k
+    keep_top_k = cfg.get("max_per_img", post_params.keep_top_k)
+    mask_thr_binary = cfg.get("mask_thr_binary", 0.5)
+
+    return _custom_nms_with_mask_static(
+        self,
+        priors,
+        bboxes,
+        scores,
+        flatten_kernel_preds,
+        mask_feat,
+        max_output_boxes_per_class,
+        iou_threshold,
+        score_threshold,
+        pre_top_k,
+        keep_top_k,
+        mask_thr_binary,
+    )
diff --git a/src/otx/algo/instance_segmentation/mmconfigs/rtmdet_inst_tiny.yaml b/src/otx/algo/instance_segmentation/mmconfigs/rtmdet_inst_tiny.yaml
index fecd05202d1..8abe02d6398 100644
--- a/src/otx/algo/instance_segmentation/mmconfigs/rtmdet_inst_tiny.yaml
+++ b/src/otx/algo/instance_segmentation/mmconfigs/rtmdet_inst_tiny.yaml
@@ -46,7 +46,7 @@ neck:
     type: SiLU
     inplace: true
 bbox_head:
-  type: RTMDetInsSepBNHead
+  type: CustomRTMDetInsSepBNHead
   num_classes: 80
   in_channels: 96
   stacked_convs: 2
diff --git a/src/otx/core/data/dataset/instance_segmentation.py b/src/otx/core/data/dataset/instance_segmentation.py
index 77aa8057261..03be87d3be3 100644
--- a/src/otx/core/data/dataset/instance_segmentation.py
+++ b/src/otx/core/data/dataset/instance_segmentation.py
@@ -55,7 +55,7 @@ def _get_item_impl(self, index: int) -> InstanceSegDataEntity | None:
                 gt_masks.append(polygon_to_bitmap([annotation], *img_shape)[0])
 
         # convert xywh to xyxy format
-        bboxes = np.array(gt_bboxes, dtype=np.float32)
+        bboxes = np.array(gt_bboxes, dtype=np.float32) if gt_bboxes else np.empty((0, 4), dtype=np.float32)
         bboxes[:, 2:] += bboxes[:, :2]
 
         masks = np.stack(gt_masks, axis=0) if gt_masks else np.zeros((0, *img_shape), dtype=bool)
diff --git a/src/otx/core/model/detection.py b/src/otx/core/model/detection.py
index a29b3432f06..cba375a7352 100644
--- a/src/otx/core/model/detection.py
+++ b/src/otx/core/model/detection.py
@@ -443,7 +443,7 @@ def _customize_outputs(
                 tv_tensors.BoundingBoxes(
                     output.pred_instances.bboxes,
                     format="XYXY",
-                    canvas_size=output.img_shape,
+                    canvas_size=output.ori_shape,
                 ),
             )
labels.append(output.pred_instances.labels) diff --git a/src/otx/core/utils/mask_util.py b/src/otx/core/utils/mask_util.py index bd5b70da1b5..f4108fea1d3 100644 --- a/src/otx/core/utils/mask_util.py +++ b/src/otx/core/utils/mask_util.py @@ -51,7 +51,9 @@ def polygon_to_rle( list[dict]: List of RLE masks. """ polygons = [polygon.points for polygon in polygons] - return mask_utils.frPyObjects(polygons, height, width) + if len(polygons): + return mask_utils.frPyObjects(polygons, height, width) + return [] def encode_rle(mask: torch.Tensor) -> dict: diff --git a/src/otx/recipe/instance_segmentation/openvino_model.yaml b/src/otx/recipe/instance_segmentation/openvino_model.yaml index 63c2c4b3479..f4ed7dae823 100644 --- a/src/otx/recipe/instance_segmentation/openvino_model.yaml +++ b/src/otx/recipe/instance_segmentation/openvino_model.yaml @@ -35,4 +35,4 @@ overrides: image_color_channel: RGB data_format: coco_instances test_subset: - batch_size: 2 + batch_size: 64 diff --git a/src/otx/recipe/instance_segmentation/rtmdet_inst_tiny.yaml b/src/otx/recipe/instance_segmentation/rtmdet_inst_tiny.yaml index ad7db5e6431..b509b19078a 100644 --- a/src/otx/recipe/instance_segmentation/rtmdet_inst_tiny.yaml +++ b/src/otx/recipe/instance_segmentation/rtmdet_inst_tiny.yaml @@ -5,10 +5,11 @@ model: variant: tiny optimizer: - class_path: torch.optim.AdamW + class_path: torch.optim.SGD init_args: - lr: 0.004 - weight_decay: 0.05 + lr: 0.001 + momentum: 0.9 + weight_decay: 0.0001 scheduler: - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler @@ -30,7 +31,7 @@ callback_monitor: val/map_50 data: ../_base_/data/mmdet_base.yaml overrides: - precision: 32 # 16/"16-true" does not work + precision: 16 max_epochs: 100 gradient_clip_val: 35.0 data: @@ -39,7 +40,7 @@ overrides: include_polygons: true train_subset: batch_size: 4 - num_workers: 10 + num_workers: 8 transforms: - type: LoadImageFromFile backend_args: null @@ -61,6 +62,7 @@ overrides: - 0.5 - 2.0 keep_ratio: true + _scope_: mmdet - type: RandomCrop crop_size: - 640 @@ -91,8 +93,8 @@ overrides: - 1 - type: PackDetInputs val_subset: - batch_size: 2 - num_workers: 10 + batch_size: 1 + num_workers: 4 transforms: - type: LoadImageFromFile backend_args: null @@ -115,8 +117,8 @@ overrides: - img_shape - scale_factor test_subset: - batch_size: 2 - num_workers: 10 + batch_size: 1 + num_workers: 4 transforms: - type: LoadImageFromFile backend_args: null diff --git a/src/otx/recipe/instance_segmentation/rtmdet_inst_tiny_tile.yaml b/src/otx/recipe/instance_segmentation/rtmdet_inst_tiny_tile.yaml new file mode 100644 index 00000000000..5c138c7da8c --- /dev/null +++ b/src/otx/recipe/instance_segmentation/rtmdet_inst_tiny_tile.yaml @@ -0,0 +1,117 @@ +model: + class_path: otx.algo.instance_segmentation.rtmdet_inst.RTMDetInst + init_args: + num_classes: 80 + variant: tiny + +optimizer: + class_path: torch.optim.SGD + init_args: + lr: 0.001 + momentum: 0.9 + weight_decay: 0.0001 + +scheduler: + - class_path: otx.algo.schedulers.warmup_schedulers.LinearWarmupScheduler + init_args: + num_warmup_steps: 20 + - class_path: lightning.pytorch.cli.ReduceLROnPlateau + init_args: + mode: max + factor: 0.1 + patience: 9 + monitor: val/map_50 + min_lr: 4e-06 + +engine: + task: INSTANCE_SEGMENTATION + device: auto + +callback_monitor: val/map_50 + +data: ../_base_/data/mmdet_base.yaml +overrides: + precision: 32 + max_epochs: 100 + gradient_clip_val: 35.0 + data: + task: INSTANCE_SEGMENTATION + config: + tile_config: + enable_tiler: true + enable_adaptive_tiling: true + 
include_polygons: true + train_subset: + batch_size: 4 + num_workers: 8 + transforms: + - type: LoadImageFromFile + backend_args: null + - type: LoadAnnotations + with_bbox: true + with_mask: true + - type: Resize + scale: + - 640 + - 640 + keep_ratio: false + - type: Pad + size: + - 640 + - 640 + pad_val: 114 + - type: RandomFlip + prob: 0.5 + - type: PackDetInputs + val_subset: + batch_size: 1 + num_workers: 4 + transforms: + - type: LoadImageFromFile + backend_args: null + - type: Resize + scale: + - 640 + - 640 + keep_ratio: true + - type: Pad + size: + - 640 + - 640 + pad_val: 114 + - type: LoadAnnotations + with_bbox: true + with_mask: true + - type: PackDetInputs + meta_keys: + - img_id + - img_path + - ori_shape + - img_shape + - scale_factor + test_subset: + batch_size: 1 + num_workers: 4 + transforms: + - type: LoadImageFromFile + backend_args: null + - type: Resize + scale: + - 640 + - 640 + keep_ratio: true + - type: Pad + size: + - 640 + - 640 + pad_val: 114 + - type: LoadAnnotations + with_bbox: true + with_mask: true + - type: PackDetInputs + meta_keys: + - img_id + - img_path + - ori_shape + - img_shape + - scale_factor diff --git a/tests/integration/test_tiling.py b/tests/integration/test_tiling.py deleted file mode 100644 index 574d8db14eb..00000000000 --- a/tests/integration/test_tiling.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright (C) 2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -from __future__ import annotations - -import numpy as np -import pytest -from datumaro import Dataset as DmDataset -from omegaconf import DictConfig, OmegaConf -from otx.core.config.data import ( - DataModuleConfig, - SubsetConfig, - TileConfig, - VisualPromptingConfig, -) -from otx.core.data.dataset.tile import OTXTileTransform -from otx.core.data.entity.detection import DetBatchDataEntity -from otx.core.data.entity.tile import TileBatchDetDataEntity -from otx.core.data.module import OTXDataModule -from otx.core.types.task import OTXTaskType - - -class TestOTXTiling: - @pytest.fixture() - def fxt_mmcv_det_transform_config(self) -> list[DictConfig]: - mmdet_base = OmegaConf.load("src/otx/recipe/_base_/data/mmdet_base.yaml") - return mmdet_base.config.train_subset.transforms - - @pytest.fixture() - def fxt_det_data_config(self, fxt_asset_dir, fxt_mmcv_det_transform_config) -> OTXDataModule: - data_root = fxt_asset_dir / "car_tree_bug" - - batch_size = 8 - num_workers = 0 - return DataModuleConfig( - data_format="coco_instances", - data_root=data_root, - train_subset=SubsetConfig( - subset_name="train", - batch_size=batch_size, - num_workers=num_workers, - transform_lib_type="MMDET", - transforms=fxt_mmcv_det_transform_config, - ), - val_subset=SubsetConfig( - subset_name="val", - batch_size=batch_size, - num_workers=num_workers, - transform_lib_type="MMDET", - transforms=fxt_mmcv_det_transform_config, - ), - test_subset=SubsetConfig( - subset_name="test", - batch_size=batch_size, - num_workers=num_workers, - transform_lib_type="MMDET", - transforms=fxt_mmcv_det_transform_config, - ), - tile_config=TileConfig(), - vpm_config=VisualPromptingConfig(), - ) - - def test_tile_transform(self): - dataset = DmDataset.import_from("tests/assets/car_tree_bug", format="coco_instances") - first_item = next(iter(dataset), None) - height, width = first_item.media.data.shape[:2] - - rng = np.random.default_rng() - tile_size = rng.integers(low=100, high=500, size=(2,)) - overlap = rng.random(2) - threshold_drop_ann = rng.random() - tiled_dataset = 
DmDataset.import_from("tests/assets/car_tree_bug", format="coco_instances") - tiled_dataset.transform( - OTXTileTransform, - tile_size=tile_size, - overlap=overlap, - threshold_drop_ann=threshold_drop_ann, - ) - - h_stride = max(int((1 - overlap[0]) * tile_size[0]), 1) - w_stride = max(int((1 - overlap[1]) * tile_size[1]), 1) - num_tile_rows = (height + h_stride - 1) // h_stride - num_tile_cols = (width + w_stride - 1) // w_stride - assert len(tiled_dataset) == (num_tile_rows * num_tile_cols * len(dataset)), "Incorrect number of tiles" - - def test_adaptive_tiling(self, fxt_det_data_config): - # Enable tile adapter - fxt_det_data_config.tile_config.enable_tiler = True - fxt_det_data_config.tile_config.enable_adaptive_tiling = True - tile_datamodule = OTXDataModule( - task=OTXTaskType.DETECTION, - config=fxt_det_data_config, - ) - tile_datamodule.prepare_data() - - assert tile_datamodule.config.tile_config.tile_size == (6750, 6750), "Tile size should be [6750, 6750]" - assert ( - pytest.approx(tile_datamodule.config.tile_config.overlap, rel=1e-3) == 0.03608 - ), "Overlap should be 0.03608" - assert tile_datamodule.config.tile_config.max_num_instances == 3, "Max num instances should be 3" - - def test_tile_sampler(self, fxt_det_data_config): - rng = np.random.default_rng() - - fxt_det_data_config.tile_config.enable_tiler = True - fxt_det_data_config.tile_config.enable_adaptive_tiling = False - fxt_det_data_config.tile_config.sampling_ratio = rng.random() - tile_datamodule = OTXDataModule( - task=OTXTaskType.DETECTION, - config=fxt_det_data_config, - ) - tile_datamodule.prepare_data() - sampled_count = max( - 1, - int(len(tile_datamodule._get_dataset("train")) * fxt_det_data_config.tile_config.sampling_ratio), - ) - - count = 0 - for batch in tile_datamodule.train_dataloader(): - count += batch.batch_size - assert isinstance(batch, DetBatchDataEntity) - - assert sampled_count == count, "Sampled count should be equal to the count of the dataloader batch size" - - def test_train_dataloader(self, fxt_det_data_config) -> None: - # Enable tile adapter - fxt_det_data_config.tile_config.enable_tiler = True - tile_datamodule = OTXDataModule( - task=OTXTaskType.DETECTION, - config=fxt_det_data_config, - ) - tile_datamodule.prepare_data() - for batch in tile_datamodule.train_dataloader(): - assert isinstance(batch, DetBatchDataEntity) - - def test_val_dataloader(self, fxt_det_data_config) -> None: - # Enable tile adapter - fxt_det_data_config.tile_config.enable_tiler = True - tile_datamodule = OTXDataModule( - task=OTXTaskType.DETECTION, - config=fxt_det_data_config, - ) - tile_datamodule.prepare_data() - for batch in tile_datamodule.val_dataloader(): - assert isinstance(batch, TileBatchDetDataEntity) - - def test_tile_merge(self): - pytest.skip("Not implemented yet") diff --git a/tests/test_helpers.py b/tests/test_helpers.py new file mode 100644 index 00000000000..99b68c55dcd --- /dev/null +++ b/tests/test_helpers.py @@ -0,0 +1,42 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +"""Helper functions for tests.""" + +import numpy as np + + +def generate_random_bboxes( + image_width: int, + image_height: int, + num_boxes: int, + min_width: int = 10, + min_height: int = 10, +) -> np.ndarray: + """Generate random bounding boxes. + Parameters: + image_width (int): Width of the image. + image_height (int): Height of the image. + num_boxes (int): Number of bounding boxes to generate. + min_width (int): Minimum width of the bounding box. Default is 10. 
+        min_height (int): Minimum height of the bounding box. Default is 10.
+
+    Returns:
+        ndarray: A NumPy array of shape (num_boxes, 4) representing bounding
+            boxes in format (x_min, y_min, x_max, y_max).
+    """
+    max_width = image_width - min_width
+    max_height = image_height - min_height
+
+    bg = np.random.MT19937(seed=42)
+    rg = np.random.Generator(bg)
+
+    x_min = rg.integers(0, max_width, size=num_boxes)
+    y_min = rg.integers(0, max_height, size=num_boxes)
+    x_max = x_min + rg.integers(min_width, image_width, size=num_boxes)
+    y_max = y_min + rg.integers(min_height, image_height, size=num_boxes)
+
+    # clip to the image bounds and drop degenerate (zero-area) boxes
+    x_max[x_max > image_width] = image_width
+    y_max[y_max > image_height] = image_height
+    areas = (x_max - x_min) * (y_max - y_min)
+    bboxes = np.column_stack((x_min, y_min, x_max, y_max))
+    return bboxes[areas > 0]
diff --git a/tests/unit/algo/instance_segmentation/heads/test_custom_rtmdet_ins_head.py b/tests/unit/algo/instance_segmentation/heads/test_custom_rtmdet_ins_head.py
new file mode 100644
index 00000000000..2ff04967585
--- /dev/null
+++ b/tests/unit/algo/instance_segmentation/heads/test_custom_rtmdet_ins_head.py
@@ -0,0 +1,85 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import tempfile
+from pathlib import Path
+
+import torch
+from mmengine.config import ConfigDict
+from otx.algo.instance_segmentation.heads.custom_rtmdet_ins_head import CustomRTMDetInsSepBNHead
+from otx.algo.instance_segmentation.rtmdet_inst import RTMDetInst
+from otx.core.types.export import OTXExportFormatType
+
+
+class TestCustomRTMDetInsSepBNHead:
+    def test_mask_pred(self, mocker) -> None:
+        num_samples = 1
+        num_classes = 1
+        test_cfg = ConfigDict(
+            nms_pre=100,
+            score_thr=0.0,
+            nms={"type": "nms", "iou_threshold": 1.0},
+            max_per_img=100,
+            mask_thr_binary=0.0,
+            min_bbox_size=-1,
+        )
+        s = 128
+        img_metas = {
+            "img_shape": (s, s, 3),
+            "scale_factor": (1, 1),
+            "ori_shape": (s, s, 3),
+        }
+        mask_head = CustomRTMDetInsSepBNHead(
+            num_classes=num_classes,
+            in_channels=1,
+            num_prototypes=1,
+            num_dyconvs=1,
+            anchor_generator={
+                "type": "MlvlPointGenerator",
+                "offset": 0,
+                "strides": (1,),
+            },
+            bbox_coder={"type": "DistancePointBBoxCoder"},
+        )
+        cls_scores = [torch.rand((num_samples, num_classes, 14, 14))]
+        bbox_preds = [torch.rand((num_samples, 4, 14, 14))]
+        kernel_preds = [torch.rand((1, 32, 14, 14))]
+        mask_feat = torch.rand(num_samples, 1, 14, 14)
+
+        mocker.patch.object(
+            CustomRTMDetInsSepBNHead,
+            "_mask_predict_by_feat_single",
+            return_value=torch.rand(100, 14, 14),
+        )
+
+        results = mask_head.predict_by_feat(
+            cls_scores=cls_scores,
+            bbox_preds=bbox_preds,
+            kernel_preds=kernel_preds,
+            mask_feat=mask_feat,
+            batch_img_metas=[img_metas],
+            cfg=test_cfg,
+            rescale=True,
+        )
+
+        mask_head._bbox_mask_post_process(
+            results[0],
+            mask_feat,
+            cfg=test_cfg,
+        )
+
+        mask_head._bbox_mask_post_process(
+            results[0],
+            mask_feat,
+            cfg=None,
+        )
+
+    def test_predict_by_feat_ov(self) -> None:
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            lit_module = RTMDetInst(num_classes=1, variant="tiny")
+            exported_model_path = lit_module.export(
+                output_dir=Path(tmpdirname),
+                base_name="exported_model",
+                export_format=OTXExportFormatType.OPENVINO,
+            )
+            assert exported_model_path.exists()
diff --git a/tests/unit/core/data/test_tiling.py b/tests/unit/core/data/test_tiling.py
new file mode 100644
index 00000000000..68e9aaaa42a
--- /dev/null
+++ b/tests/unit/core/data/test_tiling.py
@@ -0,0 +1,321 @@
+# Copyright (C) 2024 Intel Corporation
+# 
SPDX-License-Identifier: Apache-2.0 +# + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import create_autospec + +import numpy as np +import pytest +import torch +from datumaro import Dataset as DmDataset +from omegaconf import DictConfig, OmegaConf +from otx.core.config.data import ( + DataModuleConfig, + SubsetConfig, + TileConfig, + VisualPromptingConfig, +) +from otx.core.data.dataset.tile import OTXTileTransform +from otx.core.data.entity.detection import DetBatchDataEntity, DetBatchPredEntity +from otx.core.data.entity.instance_segmentation import InstanceSegBatchDataEntity, InstanceSegBatchPredEntity +from otx.core.data.entity.tile import TileBatchDetDataEntity +from otx.core.data.module import OTXDataModule +from otx.core.model.detection import OTXDetectionModel +from otx.core.model.instance_segmentation import OTXInstanceSegModel +from otx.core.types.task import OTXTaskType +from torchvision import tv_tensors + +from tests.test_helpers import generate_random_bboxes + + +class TestOTXTiling: + @pytest.fixture() + def mock_otx_det_model(self) -> OTXDetectionModel: + return create_autospec(OTXDetectionModel) + + @pytest.fixture() + def fxt_mmcv_det_transform_config(self) -> list[DictConfig]: + mmdet_base = OmegaConf.load("src/otx/recipe/_base_/data/mmdet_base.yaml") + return mmdet_base.config.train_subset.transforms + + @pytest.fixture() + def fxt_det_data_config(self, fxt_mmcv_det_transform_config) -> OTXDataModule: + data_root = Path(__file__).parent.parent.parent.parent / "assets" / "car_tree_bug" + + batch_size = 8 + num_workers = 0 + return DataModuleConfig( + data_format="coco_instances", + data_root=data_root, + train_subset=SubsetConfig( + subset_name="train", + batch_size=batch_size, + num_workers=num_workers, + transform_lib_type="MMDET", + transforms=fxt_mmcv_det_transform_config, + ), + val_subset=SubsetConfig( + subset_name="val", + batch_size=batch_size, + num_workers=num_workers, + transform_lib_type="MMDET", + transforms=fxt_mmcv_det_transform_config, + ), + test_subset=SubsetConfig( + subset_name="test", + batch_size=batch_size, + num_workers=num_workers, + transform_lib_type="MMDET", + transforms=fxt_mmcv_det_transform_config, + ), + tile_config=TileConfig(), + vpm_config=VisualPromptingConfig(), + ) + + @pytest.fixture() + def fxt_instseg_data_config(self, fxt_mmcv_det_transform_config) -> OTXDataModule: + data_root = Path(__file__).parent.parent.parent.parent / "assets" / "car_tree_bug" + + for transform in fxt_mmcv_det_transform_config: + if transform.type == "LoadAnnotations": + transform.with_mask = True + + batch_size = 8 + num_workers = 0 + return DataModuleConfig( + data_format="coco_instances", + data_root=data_root, + train_subset=SubsetConfig( + subset_name="train", + batch_size=batch_size, + num_workers=num_workers, + transform_lib_type="MMDET", + transforms=fxt_mmcv_det_transform_config, + ), + val_subset=SubsetConfig( + subset_name="val", + batch_size=batch_size, + num_workers=num_workers, + transform_lib_type="MMDET", + transforms=fxt_mmcv_det_transform_config, + ), + test_subset=SubsetConfig( + subset_name="test", + batch_size=batch_size, + num_workers=num_workers, + transform_lib_type="MMDET", + transforms=fxt_mmcv_det_transform_config, + ), + tile_config=TileConfig(), + vpm_config=VisualPromptingConfig(), + ) + + def test_tile_transform(self): + dataset = DmDataset.import_from("tests/assets/car_tree_bug", format="coco_instances") + first_item = next(iter(dataset), None) + height, width = 
first_item.media.data.shape[:2] + + rng = np.random.default_rng() + tile_size = rng.integers(low=100, high=500, size=(2,)) + overlap = rng.random(2) + threshold_drop_ann = rng.random() + tiled_dataset = DmDataset.import_from("tests/assets/car_tree_bug", format="coco_instances") + tiled_dataset.transform( + OTXTileTransform, + tile_size=tile_size, + overlap=overlap, + threshold_drop_ann=threshold_drop_ann, + ) + + h_stride = max(int((1 - overlap[0]) * tile_size[0]), 1) + w_stride = max(int((1 - overlap[1]) * tile_size[1]), 1) + num_tile_rows = (height + h_stride - 1) // h_stride + num_tile_cols = (width + w_stride - 1) // w_stride + assert len(tiled_dataset) == (num_tile_rows * num_tile_cols * len(dataset)), "Incorrect number of tiles" + + def test_adaptive_tiling(self, fxt_det_data_config): + # Enable tile adapter + fxt_det_data_config.tile_config.enable_tiler = True + fxt_det_data_config.tile_config.enable_adaptive_tiling = True + tile_datamodule = OTXDataModule( + task=OTXTaskType.DETECTION, + config=fxt_det_data_config, + ) + tile_datamodule.prepare_data() + + assert tile_datamodule.config.tile_config.tile_size == (6750, 6750), "Tile size should be [6750, 6750]" + assert ( + pytest.approx(tile_datamodule.config.tile_config.overlap, rel=1e-3) == 0.03608 + ), "Overlap should be 0.03608" + assert tile_datamodule.config.tile_config.max_num_instances == 3, "Max num instances should be 3" + + def test_tile_sampler(self, fxt_det_data_config): + rng = np.random.default_rng() + + fxt_det_data_config.tile_config.enable_tiler = True + fxt_det_data_config.tile_config.enable_adaptive_tiling = False + fxt_det_data_config.tile_config.sampling_ratio = rng.random() + tile_datamodule = OTXDataModule( + task=OTXTaskType.DETECTION, + config=fxt_det_data_config, + ) + tile_datamodule.prepare_data() + sampled_count = max( + 1, + int(len(tile_datamodule._get_dataset("train")) * fxt_det_data_config.tile_config.sampling_ratio), + ) + + count = 0 + for batch in tile_datamodule.train_dataloader(): + count += batch.batch_size + assert isinstance(batch, DetBatchDataEntity) + + assert sampled_count == count, "Sampled count should be equal to the count of the dataloader batch size" + + def test_train_dataloader(self, fxt_det_data_config) -> None: + # Enable tile adapter + fxt_det_data_config.tile_config.enable_tiler = True + tile_datamodule = OTXDataModule( + task=OTXTaskType.DETECTION, + config=fxt_det_data_config, + ) + tile_datamodule.prepare_data() + for batch in tile_datamodule.train_dataloader(): + assert isinstance(batch, DetBatchDataEntity) + + def test_val_dataloader(self, fxt_det_data_config) -> None: + # Enable tile adapter + fxt_det_data_config.tile_config.enable_tiler = True + tile_datamodule = OTXDataModule( + task=OTXTaskType.DETECTION, + config=fxt_det_data_config, + ) + tile_datamodule.prepare_data() + for batch in tile_datamodule.val_dataloader(): + assert isinstance(batch, TileBatchDetDataEntity) + + def test_det_tile_merge(self, fxt_det_data_config): + def dummy_forward(x: DetBatchDataEntity) -> DetBatchPredEntity: + """Dummy forward function for testing. + + This function creates random bounding boxes for each image in the batch. + Args: + x (DetBatchDataEntity): Input batch data entity. + + Returns: + DetBatchPredEntity: Output batch prediction entity. 
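+
+            Note: the generated boxes, labels, and scores are random placeholders;
+            the test below only exercises the tile-merge path of
+            ``forward_tiles``, not prediction quality.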
+ """ + bboxes = [] + labels = [] + scores = [] + for img_info in x.imgs_info: + img_h, img_w = img_info.ori_shape + img_bboxes = generate_random_bboxes( + image_width=img_w, + image_height=img_h, + num_boxes=100, + ) + bboxes.append( + tv_tensors.BoundingBoxes( + img_bboxes, + canvas_size=img_info.ori_shape, + format=tv_tensors.BoundingBoxFormat.XYXY, + dtype=torch.float64, + ), + ) + labels.append( + torch.LongTensor(len(img_bboxes)).random_(3), + ) + scores.append( + torch.rand(len(img_bboxes), dtype=torch.float64), + ) + + return DetBatchPredEntity( + batch_size=x.batch_size, + images=x.images, + imgs_info=x.imgs_info, + scores=scores, + bboxes=bboxes, + labels=labels, + ) + + model = OTXDetectionModel(num_classes=3) + fxt_det_data_config.tile_config.enable_tiler = True + tile_datamodule = OTXDataModule( + task=OTXTaskType.DETECTION, + config=fxt_det_data_config, + ) + model.forward = dummy_forward + + tile_datamodule.prepare_data() + for batch in tile_datamodule.val_dataloader(): + model.forward_tiles(batch) + + def test_instseg_tile_merge(self, fxt_instseg_data_config): + def dummy_forward(x: InstanceSegBatchDataEntity) -> InstanceSegBatchPredEntity: + """Dummy forward function for testing. + + This function creates random bounding boxes/masks for each image in the batch. + Args: + x (InstanceSegBatchDataEntity): Input batch data entity. + + Returns: + InstanceSegBatchPredEntity: Output batch prediction entity. + """ + bboxes = [] + labels = [] + scores = [] + masks = [] + for img_info in x.imgs_info: + img_h, img_w = img_info.ori_shape + img_bboxes = generate_random_bboxes( + image_width=img_w, + image_height=img_h, + num_boxes=100, + ) + bboxes.append( + tv_tensors.BoundingBoxes( + img_bboxes, + canvas_size=img_info.ori_shape, + format=tv_tensors.BoundingBoxFormat.XYXY, + dtype=torch.float64, + ), + ) + labels.append( + torch.LongTensor(len(img_bboxes)).random_(3), + ) + scores.append( + torch.rand(len(img_bboxes), dtype=torch.float64), + ) + masks.append( + tv_tensors.Mask( + torch.randint(0, 2, (len(img_bboxes), img_h, img_w)), + dtype=torch.bool, + ), + ) + + return InstanceSegBatchPredEntity( + batch_size=x.batch_size, + images=x.images, + imgs_info=x.imgs_info, + scores=scores, + bboxes=bboxes, + masks=masks, + labels=labels, + polygons=x.polygons, + ) + + model = OTXInstanceSegModel(num_classes=3) + fxt_instseg_data_config.tile_config.enable_tiler = True + tile_datamodule = OTXDataModule( + task=OTXTaskType.INSTANCE_SEGMENTATION, + config=fxt_instseg_data_config, + ) + model.forward = dummy_forward + + tile_datamodule.prepare_data() + for batch in tile_datamodule.val_dataloader(): + model.forward_tiles(batch)