From 0288b36913048929c8a8a21011c65b429e895581 Mon Sep 17 00:00:00 2001 From: jaegukhyun Date: Thu, 11 Apr 2024 13:19:24 +0900 Subject: [PATCH] Replace mmcv nms to torchvision nms --- src/otx/algo/detection/heads/base_head.py | 3 +- src/otx/algo/detection/ops/__init__.py | 4 + src/otx/algo/detection/ops/nms.py | 221 ++++++++++++++++++++++ 3 files changed, 227 insertions(+), 1 deletion(-) create mode 100644 src/otx/algo/detection/ops/__init__.py create mode 100644 src/otx/algo/detection/ops/nms.py diff --git a/src/otx/algo/detection/heads/base_head.py b/src/otx/algo/detection/heads/base_head.py index b513d90c4b0..89a883bb620 100644 --- a/src/otx/algo/detection/heads/base_head.py +++ b/src/otx/algo/detection/heads/base_head.py @@ -10,13 +10,14 @@ from typing import TYPE_CHECKING import torch -from mmcv.ops import batched_nms from mmdet.models.utils import filter_scores_and_topk, select_single_mlvl, unpack_gt_instances from mmdet.structures.bbox import cat_boxes, get_box_tensor, get_box_wh, scale_boxes from mmengine.model import constant_init from mmengine.structures import InstanceData from torch import Tensor, nn +from otx.algo.detection.ops.nms import batched_nms + if TYPE_CHECKING: from mmdet.structures import SampleList from mmdet.utils import InstanceList, OptInstanceList, OptMultiConfig diff --git a/src/otx/algo/detection/ops/__init__.py b/src/otx/algo/detection/ops/__init__.py new file mode 100644 index 00000000000..8bed33ebe94 --- /dev/null +++ b/src/otx/algo/detection/ops/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""Custom operation implementations for detection task.""" diff --git a/src/otx/algo/detection/ops/nms.py b/src/otx/algo/detection/ops/nms.py new file mode 100644 index 00000000000..968b0ec3f1a --- /dev/null +++ b/src/otx/algo/detection/ops/nms.py @@ -0,0 +1,221 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# +"""NMS implementations for detection task.""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import torch +from torch import Tensor +from torchvision.ops.boxes import nms as torch_nms + + +# This class is from NMSop in mmcv and slightly modified +# https://github.com/open-mmlab/mmcv/blob/265531fa9fe9e071c7d80df549d680ed257d9a16/mmcv/ops/nms.py +class NMSop(torch.autograd.Function): + """NMS operation.""" + + @staticmethod + def forward( + ctx: Any, # noqa: ARG004, ANN401 + bboxes: Tensor, + scores: Tensor, + iou_threshold: float, + offset: int, # noqa: ARG004 + score_threshold: float, + max_num: int, + ) -> Tensor: + """Forward function.""" + is_filtering_by_score = score_threshold > 0 + if is_filtering_by_score: + valid_mask = scores > score_threshold + bboxes, scores = bboxes[valid_mask], scores[valid_mask] + valid_inds = torch.nonzero(valid_mask, as_tuple=False).squeeze(dim=1) + inds = torch_nms(bboxes, scores, iou_threshold) + + if max_num > 0: + inds = inds[:max_num] + if is_filtering_by_score: + inds = valid_inds[inds] + return inds + + +# This method is from nms in mmcv +# https://github.com/open-mmlab/mmcv/blob/265531fa9fe9e071c7d80df549d680ed257d9a16/mmcv/ops/nms.py +def nms( + boxes: Tensor | np.ndarray, + scores: Tensor | np.ndarray, + iou_threshold: float, + offset: int = 0, + score_threshold: float = 0, + max_num: int = -1, +) -> tuple[Tensor | np.ndarray, Tensor | np.ndarray]: + """Dispatch to either CPU or GPU NMS implementations. + + The input can be either torch tensor or numpy array. GPU NMS will be used + if the input is gpu tensor, otherwise CPU NMS + will be used. The returned type will always be the same as inputs. + + Arguments: + boxes (torch.Tensor or np.ndarray): boxes in shape (N, 4). + scores (torch.Tensor or np.ndarray): scores in shape (N, ). + iou_threshold (float): IoU threshold for NMS. + offset (int, 0 or 1): boxes' width or height is (x2 - x1 + offset). + score_threshold (float): score threshold for NMS. + max_num (int): maximum number of boxes after NMS. + + Returns: + tuple: kept dets (boxes and scores) and indice, which always have + the same data type as the input. + + Example: + >>> boxes = np.array([[49.1, 32.4, 51.0, 35.9], + >>> [49.3, 32.9, 51.0, 35.3], + >>> [49.2, 31.8, 51.0, 35.4], + >>> [35.1, 11.5, 39.1, 15.7], + >>> [35.6, 11.8, 39.3, 14.2], + >>> [35.3, 11.5, 39.9, 14.5], + >>> [35.2, 11.7, 39.7, 15.7]], dtype=np.float32) + >>> scores = np.array([0.9, 0.9, 0.5, 0.5, 0.5, 0.4, 0.3],\ + dtype=np.float32) + >>> iou_threshold = 0.6 + >>> dets, inds = nms(boxes, scores, iou_threshold) + >>> assert len(inds) == len(dets) == 3 + """ + is_numpy = False + if isinstance(boxes, np.ndarray): + is_numpy = True + boxes = torch.from_numpy(boxes) + if isinstance(scores, np.ndarray): + scores = torch.from_numpy(scores) + + inds = NMSop.apply(boxes, scores, iou_threshold, offset, score_threshold, max_num) + dets = torch.cat((boxes[inds], scores[inds].reshape(-1, 1)), dim=1) + if is_numpy: + dets = dets.cpu().numpy() + inds = inds.cpu().numpy() + return dets, inds + + +# This method is from batched_nms in mmcv +# https://github.com/open-mmlab/mmcv/blob/265531fa9fe9e071c7d80df549d680ed257d9a16/mmcv/ops/nms.py +def batched_nms( + boxes: Tensor, + scores: Tensor, + idxs: Tensor, + nms_cfg: dict | None = None, + class_agnostic: bool = False, +) -> tuple[Tensor, Tensor]: + r"""Performs non-maximum suppression in a batched fashion. + + Modified from `torchvision/ops/boxes.py#L39 + `_. + In order to perform NMS independently per class, we add an offset to all + the boxes. The offset is dependent only on the class idx, and is large + enough so that boxes from different classes do not overlap. + + Note: + In v1.4.1 and later, ``batched_nms`` supports skipping the NMS and + returns sorted raw results when `nms_cfg` is None. + + Args: + boxes (torch.Tensor): boxes in shape (N, 4) or (N, 5). + scores (torch.Tensor): scores in shape (N, ). + idxs (torch.Tensor): each index value correspond to a bbox cluster, + and NMS will not be applied between elements of different idxs, + shape (N, ). + nms_cfg (dict | optional): Supports skipping the nms when `nms_cfg` + is None, otherwise it should specify nms type and other + parameters like `iou_thr`. Possible keys includes the following. + + - iou_threshold (float): IoU threshold used for NMS. + - split_thr (float): threshold number of boxes. In some cases the + number of boxes is large (e.g., 200k). To avoid OOM during + training, the users could set `split_thr` to a small value. + If the number of boxes is greater than the threshold, it will + perform NMS on each group of boxes separately and sequentially. + Defaults to 10000. + class_agnostic (bool): if true, nms is class agnostic, + i.e. IoU thresholding happens over all boxes, + regardless of the predicted class. Defaults to False. + + Returns: + tuple: kept dets and indice. + + - boxes (Tensor): Bboxes with score after nms, has shape + (num_bboxes, 5). last dimension 5 arrange as + (x1, y1, x2, y2, score) + - keep (Tensor): The indices of remaining boxes in input + boxes. + """ + # skip nms when nms_cfg is None + if nms_cfg is None: + scores, inds = scores.sort(descending=True) + boxes = boxes[inds] + return torch.cat([boxes, scores[:, None]], -1), inds + + nms_cfg_ = nms_cfg.copy() + class_agnostic = nms_cfg_.pop("class_agnostic", class_agnostic) + if class_agnostic: + boxes_for_nms = boxes + # When using rotated boxes, only apply offsets on center. + elif boxes.size(-1) == 5: + # Strictly, the maximum coordinates of the rotating box + # (x,y,w,h,a) should be calculated by polygon coordinates. + # But the conversion from rotated box to polygon will + # slow down the speed. + # So we use max(x,y) + max(w,h) as max coordinate + # which is larger than polygon max coordinate + # max(x1, y1, x2, y2,x3, y3, x4, y4) + max_coordinate = boxes[..., :2].max() + boxes[..., 2:4].max() + offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes)) + boxes_ctr_for_nms = boxes[..., :2] + offsets[:, None] + boxes_for_nms = torch.cat([boxes_ctr_for_nms, boxes[..., 2:5]], dim=-1) + else: + max_coordinate = boxes.max() + offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes)) + boxes_for_nms = boxes + offsets[:, None] + + nms_op = nms_cfg_.pop("type", "nms") + if isinstance(nms_op, str): + nms_op = eval(nms_op) # noqa: S307, PGH001 + + split_thr = nms_cfg_.pop("split_thr", 10000) + # Won't split to multiple nms nodes when exporting to onnx + if boxes_for_nms.shape[0] < split_thr: + dets, keep = nms_op(boxes_for_nms, scores, **nms_cfg_) + boxes = boxes[keep] + + # This assumes `dets` has arbitrary dimensions where + # the last dimension is score. + # Currently it supports bounding boxes [x1, y1, x2, y2, score] or + # rotated boxes [cx, cy, w, h, angle_radian, score]. + + scores = dets[:, -1] + else: + max_num = nms_cfg_.pop("max_num", -1) + total_mask = scores.new_zeros(scores.size(), dtype=torch.bool) + # Some type of nms would reweight the score, such as SoftNMS + scores_after_nms = scores.new_zeros(scores.size()) + for idx in torch.unique(idxs): + mask = (idxs == idx).nonzero(as_tuple=False).view(-1) + dets, keep = nms_op(boxes_for_nms[mask], scores[mask], **nms_cfg_) + total_mask[mask[keep]] = True + scores_after_nms[mask[keep]] = dets[:, -1] + keep = total_mask.nonzero(as_tuple=False).view(-1) + + scores, inds = scores_after_nms[keep].sort(descending=True) + keep = keep[inds] + boxes = boxes[keep] + + if max_num > 0: + keep = keep[:max_num] + boxes = boxes[:max_num] + scores = scores[:max_num] + + boxes = torch.cat([boxes, scores[:, None]], -1) + return boxes, keep