From fe130b556db9be69c9f5c0802f3dd67eeac9d72f Mon Sep 17 00:00:00 2001 From: jaegukhyun Date: Thu, 11 Apr 2024 14:23:19 +0900 Subject: [PATCH 1/5] Remove BaseBoxes --- src/otx/algo/detection/heads/anchor_head.py | 11 +- src/otx/algo/detection/heads/base_sampler.py | 13 +- .../detection/heads/delta_xywh_bbox_coder.py | 16 +- .../algo/detection/heads/iou2d_calculator.py | 10 +- .../algo/detection/heads/max_iou_assigner.py | 2 +- src/otx/algo/detection/structures/__init__.py | 3 + .../algo/detection/structures/structures.py | 209 ++++++++++++++++++ 7 files changed, 234 insertions(+), 30 deletions(-) create mode 100644 src/otx/algo/detection/structures/__init__.py create mode 100644 src/otx/algo/detection/structures/structures.py diff --git a/src/otx/algo/detection/heads/anchor_head.py b/src/otx/algo/detection/heads/anchor_head.py index ab3ec9e9d8c..cba4104dd13 100644 --- a/src/otx/algo/detection/heads/anchor_head.py +++ b/src/otx/algo/detection/heads/anchor_head.py @@ -12,7 +12,6 @@ from mmdet.models.task_modules.prior_generators import anchor_inside_flags from mmdet.models.utils import images_to_levels, multi_apply, unmap from mmdet.registry import MODELS, TASK_UTILS -from mmdet.structures.bbox import BaseBoxes, cat_boxes, get_box_tensor from mmengine.structures import InstanceData from torch import Tensor, nn @@ -199,7 +198,7 @@ def get_anchors( def _get_targets_single( self, - flat_anchors: Tensor | BaseBoxes, + flat_anchors: Tensor, valid_flags: Tensor, gt_instances: InstanceData, img_meta: dict, @@ -209,7 +208,7 @@ def _get_targets_single( """Compute regression and classification targets for anchors in a single image. Args: - flat_anchors (Tensor or :obj:`BaseBoxes`): Multi-level anchors + flat_anchors (Tensor): Multi-level anchors of the image, which are concatenated into a single tensor or box type of shape (num_anchors, 4) valid_flags (Tensor): Multi level valid flags of the image, @@ -277,7 +276,6 @@ def _get_targets_single( pos_bbox_targets = self.bbox_coder.encode(sampling_result.pos_priors, sampling_result.pos_gt_bboxes) else: pos_bbox_targets = sampling_result.pos_gt_bboxes - pos_bbox_targets = get_box_tensor(pos_bbox_targets) bbox_targets[pos_inds, :] = pos_bbox_targets bbox_weights[pos_inds, :] = 1.0 @@ -364,7 +362,7 @@ def get_targets( concat_anchor_list = [] concat_valid_flag_list = [] for i in range(num_imgs): - concat_anchor_list.append(cat_boxes(anchor_list[i])) + concat_anchor_list.append(torch.cat(anchor_list[i])) concat_valid_flag_list.append(torch.cat(valid_flag_list[i])) # compute targets for each image @@ -455,7 +453,6 @@ def loss_by_feat_single( # decodes the already encoded coordinates to absolute format. 
anchors = anchors.reshape(-1, anchors.size(-1)) bbox_pred = self.bbox_coder.decode(anchors, bbox_pred) - bbox_pred = get_box_tensor(bbox_pred) loss_bbox = self.loss_bbox(bbox_pred, bbox_targets, bbox_weights, avg_factor=avg_factor) return loss_cls, loss_bbox @@ -504,7 +501,7 @@ def loss_by_feat( # anchor number of multi levels num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] # concat all level anchors and flags to a single tensor - concat_anchor_list = [cat_boxes(anchor) for anchor in anchor_list] + concat_anchor_list = [torch.cat(anchor) for anchor in anchor_list] all_anchor_list = images_to_levels(concat_anchor_list, num_level_anchors) losses_cls, losses_bbox = multi_apply( diff --git a/src/otx/algo/detection/heads/base_sampler.py b/src/otx/algo/detection/heads/base_sampler.py index 4fe04e131c4..8f8299732cc 100644 --- a/src/otx/algo/detection/heads/base_sampler.py +++ b/src/otx/algo/detection/heads/base_sampler.py @@ -4,11 +4,11 @@ from abc import ABCMeta, abstractmethod import torch -from mmdet.models.task_modules.assigners import AssignResult from mmdet.models.task_modules.samplers.sampling_result import SamplingResult -from mmdet.structures.bbox import BaseBoxes, cat_boxes from mmengine.structures import InstanceData +from otx.algo.detection.structures.structures import AssignResult + class BaseSampler(metaclass=ABCMeta): """Base class of samplers. @@ -101,13 +101,8 @@ def sample( gt_flags = priors.new_zeros((priors.shape[0],), dtype=torch.uint8) if self.add_gt_as_proposals and len(gt_bboxes) > 0: - # When `gt_bboxes` and `priors` are all box type, convert - # `gt_bboxes` type to `priors` type. - if isinstance(gt_bboxes, BaseBoxes) and isinstance(priors, BaseBoxes): - gt_bboxes_ = gt_bboxes.convert_to(type(priors)) - else: - gt_bboxes_ = gt_bboxes - priors = cat_boxes([gt_bboxes_, priors], dim=0) + gt_bboxes_ = gt_bboxes + priors = torch.cat([gt_bboxes_, priors], dim=0) assign_result.add_gt_(gt_labels) gt_ones = priors.new_ones(gt_bboxes_.shape[0], dtype=torch.uint8) gt_flags = torch.cat([gt_ones, gt_flags]) diff --git a/src/otx/algo/detection/heads/delta_xywh_bbox_coder.py b/src/otx/algo/detection/heads/delta_xywh_bbox_coder.py index 126b5de4eef..79331cf0f4d 100644 --- a/src/otx/algo/detection/heads/delta_xywh_bbox_coder.py +++ b/src/otx/algo/detection/heads/delta_xywh_bbox_coder.py @@ -6,7 +6,7 @@ import numpy as np import torch -from mmdet.structures.bbox import BaseBoxes, HorizontalBoxes, get_box_tensor +from mmdet.structures.bbox import HorizontalBoxes, get_box_tensor from torch import Tensor @@ -51,13 +51,13 @@ def __init__( self.add_ctr_clamp = add_ctr_clamp self.ctr_clamp = ctr_clamp - def encode(self, bboxes: Tensor | BaseBoxes, gt_bboxes: Tensor | BaseBoxes) -> Tensor: + def encode(self, bboxes: Tensor, gt_bboxes: Tensor) -> Tensor: """Get box regression transformation deltas that can be used to transform the bboxes into the gt_bboxes. Args: - bboxes (torch.Tensor or :obj:`BaseBoxes`): Source boxes, + bboxes (torch.Tensor): Source boxes, e.g., object proposals. - gt_bboxes (torch.Tensor or :obj:`BaseBoxes`): Target of the + gt_bboxes (torch.Tensor): Target of the transformation, e.g., ground-truth boxes. Returns: @@ -69,15 +69,15 @@ def encode(self, bboxes: Tensor | BaseBoxes, gt_bboxes: Tensor | BaseBoxes) -> T def decode( self, - bboxes: Tensor | BaseBoxes, + bboxes: Tensor, pred_bboxes: Tensor, max_shape: tuple[int, ...] | Tensor | tuple[tuple[int, ...], ...] 
| None = None, wh_ratio_clip: float = 16 / 1000, - ) -> Tensor | BaseBoxes: + ) -> Tensor: """Apply transformation `pred_bboxes` to `boxes`. Args: - bboxes (torch.Tensor or :obj:`BaseBoxes`): Basic boxes. Shape + bboxes (torch.Tensor): Basic boxes. Shape (B, N, 4) or (N, 4) pred_bboxes (Tensor): Encoded offsets with respect to each roi. Has shape (B, N, num_classes * 4) or (B, N, 4) or @@ -92,7 +92,7 @@ def decode( width and height. Returns: - Union[torch.Tensor, :obj:`BaseBoxes`]: Decoded boxes. + torch.Tensor: Decoded boxes. """ bboxes = get_box_tensor(bboxes) decoded_bboxes = delta2bbox( diff --git a/src/otx/algo/detection/heads/iou2d_calculator.py b/src/otx/algo/detection/heads/iou2d_calculator.py index fd46b436b5d..76d36bb28a7 100644 --- a/src/otx/algo/detection/heads/iou2d_calculator.py +++ b/src/otx/algo/detection/heads/iou2d_calculator.py @@ -5,7 +5,7 @@ from __future__ import annotations import torch -from mmdet.structures.bbox import BaseBoxes, bbox_overlaps, get_box_tensor +from mmdet.structures.bbox import bbox_overlaps, get_box_tensor # This class and its supporting functions below lightly adapted from the mmdet BboxOverlaps2D available at: @@ -19,18 +19,18 @@ def __init__(self, scale: float = 1.0, dtype: str | None = None): def __call__( self, - bboxes1: torch.Tensor | BaseBoxes, - bboxes2: torch.Tensor | BaseBoxes, + bboxes1: torch.Tensor, + bboxes2: torch.Tensor, mode: str = "iou", is_aligned: bool = False, ) -> torch.Tensor: """Calculate IoU between 2D bboxes. Args: - bboxes1 (Tensor or :obj:`BaseBoxes`): bboxes have shape (m, 4) + bboxes1 (Tensor): bboxes have shape (m, 4) in format, or shape (m, 5) in format. - bboxes2 (Tensor or :obj:`BaseBoxes`): bboxes have shape (m, 4) + bboxes2 (Tensor): bboxes have shape (m, 4) in format, shape (m, 5) in format, or be empty. If ``is_aligned `` is ``True``, then m and n must be equal. diff --git a/src/otx/algo/detection/heads/max_iou_assigner.py b/src/otx/algo/detection/heads/max_iou_assigner.py index 7bcbb0353ab..9097eecfc7a 100644 --- a/src/otx/algo/detection/heads/max_iou_assigner.py +++ b/src/otx/algo/detection/heads/max_iou_assigner.py @@ -8,10 +8,10 @@ from typing import TYPE_CHECKING, Callable import torch -from mmdet.models.task_modules.assigners.assign_result import AssignResult from torch import Tensor from otx.algo.detection.heads.iou2d_calculator import BboxOverlaps2D +from otx.algo.detection.structures.structures import AssignResult if TYPE_CHECKING: from mmengine.structures import InstanceData diff --git a/src/otx/algo/detection/structures/__init__.py b/src/otx/algo/detection/structures/__init__.py new file mode 100644 index 00000000000..a8647baa335 --- /dev/null +++ b/src/otx/algo/detection/structures/__init__.py @@ -0,0 +1,3 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Data structures for detection task.""" diff --git a/src/otx/algo/detection/structures/structures.py b/src/otx/algo/detection/structures/structures.py new file mode 100644 index 00000000000..4bab3c4f79c --- /dev/null +++ b/src/otx/algo/detection/structures/structures.py @@ -0,0 +1,209 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Data structures for detection task.""" + +from __future__ import annotations + +from typing import Any + +import torch +from torch import Tensor + + +class AssignResult: + """Stores assignments between predicted and truth boxes. 
+ + Attributes: + num_gts (int): the number of truth boxes considered when computing this + assignment + gt_inds (Tensor): for each predicted box indicates the 1-based + index of the assigned truth box. 0 means unassigned and -1 means + ignore. + max_overlaps (Tensor): the iou between the predicted box and its + assigned truth box. + labels (Tensor): If specified, for each predicted box + indicates the category label of the assigned truth box. + + Example: + >>> # An assign result between 4 predicted boxes and 9 true boxes + >>> # where only two boxes were assigned. + >>> num_gts = 9 + >>> max_overlaps = torch.LongTensor([0, .5, .9, 0]) + >>> gt_inds = torch.LongTensor([-1, 1, 2, 0]) + >>> labels = torch.LongTensor([0, 3, 4, 0]) + >>> self = AssignResult(num_gts, gt_inds, max_overlaps, labels) + >>> print(str(self)) # xdoctest: +IGNORE_WANT + + >>> # Force addition of gt labels (when adding gt as proposals) + >>> new_labels = torch.LongTensor([3, 4, 5]) + >>> self.add_gt_(new_labels) + >>> print(str(self)) # xdoctest: +IGNORE_WANT + + """ + + def __init__(self, num_gts: int, gt_inds: Tensor, max_overlaps: Tensor, labels: Tensor) -> None: + self.num_gts = num_gts + self.gt_inds = gt_inds + self.max_overlaps = max_overlaps + self.labels = labels + # Interface for possible user-defined properties + self._extra_properties: dict[str, Any] = {} + + @property + def num_preds(self) -> int: + """int: the number of predictions in this assignment.""" + return len(self.gt_inds) + + def set_extra_property(self, key: str, value: Any) -> None: # noqa: ANN401 + """Set user-defined new property.""" + self._extra_properties[key] = value + + def get_extra_property(self, key: str) -> Any: # noqa: ANN401 + """Get user-defined property.""" + return self._extra_properties.get(key, None) + + @property + def info(self) -> dict: + """Return a dictionary of info about the object.""" + basic_info = { + "num_gts": self.num_gts, + "num_preds": self.num_preds, + "gt_inds": self.gt_inds, + "max_overlaps": self.max_overlaps, + "labels": self.labels, + } + basic_info.update(self._extra_properties) + return basic_info + + def add_gt_(self, gt_labels: Tensor) -> None: + """Add ground truth as assigned results. + + Args: + gt_labels (torch.Tensor): Labels of gt boxes + """ + self_inds = torch.arange(1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device) + self.gt_inds = torch.cat([self_inds, self.gt_inds]) + + self.max_overlaps = torch.cat([self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps]) + + self.labels = torch.cat([gt_labels, self.labels]) + + +class SamplingResult: + """Bbox sampling result. + + Args: + pos_inds (Tensor): Indices of positive samples. + neg_inds (Tensor): Indices of negative samples. + priors (Tensor): The priors can be anchors or points, + or the bboxes predicted by the previous stage. + gt_bboxes (Tensor): Ground truth of bboxes. + assign_result (:obj:`AssignResult`): Assigning results. + gt_flags (Tensor): The Ground truth flags. + avg_factor_with_neg (bool): If True, ``avg_factor`` equal to + the number of total priors; Otherwise, it is the number of + positive priors. Defaults to True. 
+ + Example: + >>> # xdoctest: +IGNORE_WANT + >>> from mmdet.models.task_modules.samplers.sampling_result import * # NOQA + >>> self = SamplingResult.random(rng=10) + >>> print(f'self = {self}') + self = + """ + + def __init__( + self, + pos_inds: Tensor, + neg_inds: Tensor, + priors: Tensor, + gt_bboxes: Tensor, + assign_result: AssignResult, + gt_flags: Tensor, + avg_factor_with_neg: bool = True, + ) -> None: + self.pos_inds = pos_inds + self.neg_inds = neg_inds + self.num_pos = max(pos_inds.numel(), 1) + self.num_neg = max(neg_inds.numel(), 1) + self.avg_factor_with_neg = avg_factor_with_neg + self.avg_factor = self.num_pos + self.num_neg if avg_factor_with_neg else self.num_pos + self.pos_priors = priors[pos_inds] + self.neg_priors = priors[neg_inds] + self.pos_is_gt = gt_flags[pos_inds] + + self.num_gts = gt_bboxes.shape[0] + self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + self.pos_gt_labels = assign_result.labels[pos_inds] + box_dim = 4 + if gt_bboxes.numel() == 0: + self.pos_gt_bboxes = gt_bboxes.view(-1, box_dim) + else: + if len(gt_bboxes.shape) < 2: + gt_bboxes = gt_bboxes.view(-1, box_dim) + self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds.long()] + + @property + def priors(self) -> Tensor: + """torch.Tensor: concatenated positive and negative priors.""" + return torch.cat([self.pos_priors, self.neg_priors]) + + @property + def bboxes(self) -> Tensor: + """torch.Tensor: concatenated positive and negative boxes.""" + return self.priors + + @property + def pos_bboxes(self) -> Tensor: + """Return positive box pairs.""" + return self.pos_priors + + @property + def neg_bboxes(self) -> Tensor: + """Return negative box pairs.""" + return self.neg_priors + + def to(self, device: str | torch.device) -> SamplingResult: + """Change the device of the data inplace. 
+ + Example: + >>> self = SamplingResult.random() + >>> print(f'self = {self.to(None)}') + >>> # xdoctest: +REQUIRES(--gpu) + >>> print(f'self = {self.to(0)}') + """ + _dict = self.__dict__ + for key, value in _dict.items(): + if isinstance(value, torch.Tensor): + _dict[key] = value.to(device) + return self + + @property + def info(self) -> dict: + """Returns a dictionary of info about the object.""" + return { + "pos_inds": self.pos_inds, + "neg_inds": self.neg_inds, + "pos_priors": self.pos_priors, + "neg_priors": self.neg_priors, + "pos_is_gt": self.pos_is_gt, + "num_gts": self.num_gts, + "pos_assigned_gt_inds": self.pos_assigned_gt_inds, + "num_pos": self.num_pos, + "num_neg": self.num_neg, + "avg_factor": self.avg_factor, + } From c7810c167f43bcacb9b735d0e41c7ae79bdecdda Mon Sep 17 00:00:00 2001 From: jaegukhyun Date: Thu, 11 Apr 2024 15:06:50 +0900 Subject: [PATCH 2/5] Migrate losses --- src/otx/algo/detection/heads/base_sampler.py | 23 +-- .../algo/detection/heads/custom_ssd_head.py | 2 +- .../detection/heads/delta_xywh_bbox_coder.py | 11 +- .../algo/detection/heads/iou2d_calculator.py | 5 +- .../algo/detection/heads/max_iou_assigner.py | 2 +- .../algo/detection/losses/weighted_loss.py | 155 +++++++++++++++ .../{structures => utils}/__init__.py | 0 src/otx/algo/detection/utils/bbox_overlaps.py | 184 ++++++++++++++++++ .../{structures => utils}/structures.py | 37 ---- 9 files changed, 346 insertions(+), 73 deletions(-) create mode 100644 src/otx/algo/detection/losses/weighted_loss.py rename src/otx/algo/detection/{structures => utils}/__init__.py (100%) create mode 100644 src/otx/algo/detection/utils/bbox_overlaps.py rename src/otx/algo/detection/{structures => utils}/structures.py (78%) diff --git a/src/otx/algo/detection/heads/base_sampler.py b/src/otx/algo/detection/heads/base_sampler.py index 8f8299732cc..1c75c310466 100644 --- a/src/otx/algo/detection/heads/base_sampler.py +++ b/src/otx/algo/detection/heads/base_sampler.py @@ -4,10 +4,9 @@ from abc import ABCMeta, abstractmethod import torch -from mmdet.models.task_modules.samplers.sampling_result import SamplingResult from mmengine.structures import InstanceData -from otx.algo.detection.structures.structures import AssignResult +from otx.algo.detection.utils.structures import AssignResult, SamplingResult class BaseSampler(metaclass=ABCMeta): @@ -72,26 +71,6 @@ def sample( Returns: :obj:`SamplingResult`: Sampling result. - - Example: - >>> from mmengine.structures import InstanceData - >>> from mmdet.models.task_modules.samplers import RandomSampler, - >>> from mmdet.models.task_modules.assigners import AssignResult - >>> from mmdet.models.task_modules.samplers. - ... sampling_result import ensure_rng, random_boxes - >>> rng = ensure_rng(None) - >>> assign_result = AssignResult.random(rng=rng) - >>> pred_instances = InstanceData() - >>> pred_instances.priors = random_boxes(assign_result.num_preds, - ... rng=rng) - >>> gt_instances = InstanceData() - >>> gt_instances.bboxes = random_boxes(assign_result.num_gts, - ... rng=rng) - >>> gt_instances.labels = torch.randint( - ... 
0, 5, (assign_result.num_gts,), dtype=torch.long) - >>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1, - >>> add_gt_as_proposals=False) - >>> self = self.sample(assign_result, pred_instances, gt_instances) """ gt_bboxes = gt_instances.bboxes priors = pred_instances.priors diff --git a/src/otx/algo/detection/heads/custom_ssd_head.py b/src/otx/algo/detection/heads/custom_ssd_head.py index 22fb788af4c..d7e963099da 100644 --- a/src/otx/algo/detection/heads/custom_ssd_head.py +++ b/src/otx/algo/detection/heads/custom_ssd_head.py @@ -7,7 +7,6 @@ from typing import TYPE_CHECKING import torch -from mmdet.models.losses import smooth_l1_loss from mmdet.models.utils import multi_apply from mmdet.registry import MODELS from torch import Tensor, nn @@ -17,6 +16,7 @@ from otx.algo.detection.heads.custom_anchor_generator import SSDAnchorGeneratorClustered from otx.algo.detection.heads.delta_xywh_bbox_coder import DeltaXYWHBBoxCoder from otx.algo.detection.heads.max_iou_assigner import MaxIoUAssigner +from otx.algo.detection.losses.weighted_loss import smooth_l1_loss if TYPE_CHECKING: from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptInstanceList diff --git a/src/otx/algo/detection/heads/delta_xywh_bbox_coder.py b/src/otx/algo/detection/heads/delta_xywh_bbox_coder.py index 79331cf0f4d..a46e838e06e 100644 --- a/src/otx/algo/detection/heads/delta_xywh_bbox_coder.py +++ b/src/otx/algo/detection/heads/delta_xywh_bbox_coder.py @@ -6,7 +6,6 @@ import numpy as np import torch -from mmdet.structures.bbox import HorizontalBoxes, get_box_tensor from torch import Tensor @@ -44,6 +43,7 @@ def __init__( ctr_clamp: int = 32, ) -> None: self.encode_size = encode_size + # TODO(Jaeguk): use_box_type should be deprecated. self.use_box_type = use_box_type self.means = target_means self.stds = target_stds @@ -63,8 +63,6 @@ def encode(self, bboxes: Tensor, gt_bboxes: Tensor) -> Tensor: Returns: torch.Tensor: Box transformation deltas """ - bboxes = get_box_tensor(bboxes) - gt_bboxes = get_box_tensor(gt_bboxes) return bbox2delta(bboxes, gt_bboxes, self.means, self.stds) def decode( @@ -94,8 +92,7 @@ def decode( Returns: torch.Tensor: Decoded boxes. 
""" - bboxes = get_box_tensor(bboxes) - decoded_bboxes = delta2bbox( + return delta2bbox( bboxes, pred_bboxes, self.means, @@ -107,10 +104,6 @@ def decode( self.ctr_clamp, ) - if self.use_box_type: - decoded_bboxes = HorizontalBoxes(decoded_bboxes) - return decoded_bboxes - def bbox2delta( proposals: Tensor, diff --git a/src/otx/algo/detection/heads/iou2d_calculator.py b/src/otx/algo/detection/heads/iou2d_calculator.py index 76d36bb28a7..bad8a5ea094 100644 --- a/src/otx/algo/detection/heads/iou2d_calculator.py +++ b/src/otx/algo/detection/heads/iou2d_calculator.py @@ -5,7 +5,8 @@ from __future__ import annotations import torch -from mmdet.structures.bbox import bbox_overlaps, get_box_tensor + +from otx.algo.detection.utils.bbox_overlaps import bbox_overlaps # This class and its supporting functions below lightly adapted from the mmdet BboxOverlaps2D available at: @@ -43,8 +44,6 @@ def __call__( Returns: Tensor: shape (m, n) if ``is_aligned `` is False else shape (m,) """ - bboxes1 = get_box_tensor(bboxes1) - bboxes2 = get_box_tensor(bboxes2) if bboxes2.size(-1) == 5: bboxes2 = bboxes2[..., :4] if bboxes1.size(-1) == 5: diff --git a/src/otx/algo/detection/heads/max_iou_assigner.py b/src/otx/algo/detection/heads/max_iou_assigner.py index 9097eecfc7a..f95f44585a1 100644 --- a/src/otx/algo/detection/heads/max_iou_assigner.py +++ b/src/otx/algo/detection/heads/max_iou_assigner.py @@ -11,7 +11,7 @@ from torch import Tensor from otx.algo.detection.heads.iou2d_calculator import BboxOverlaps2D -from otx.algo.detection.structures.structures import AssignResult +from otx.algo.detection.utils.structures import AssignResult if TYPE_CHECKING: from mmengine.structures import InstanceData diff --git a/src/otx/algo/detection/losses/weighted_loss.py b/src/otx/algo/detection/losses/weighted_loss.py new file mode 100644 index 00000000000..69fe1d2696d --- /dev/null +++ b/src/otx/algo/detection/losses/weighted_loss.py @@ -0,0 +1,155 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Weighted loss from mmdet.""" + +from __future__ import annotations + +import functools +from typing import Callable + +import torch +from torch import Tensor +from torch.nn import functional + + +def reduce_loss(loss: Tensor, reduction: str) -> Tensor: + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are "none", "mean" and "sum". + + Return: + Tensor: Reduced loss tensor. + """ + reduction_enum = functional._Reduction.get_enum(reduction) # noqa: SLF001 + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + if reduction_enum == 1: + return loss.mean() + if reduction_enum == 2: + return loss.sum() + msg = f"reduction_enum: {reduction_enum} is invalid" + raise ValueError(msg) + + +def weight_reduce_loss( + loss: Tensor, + weight: Tensor | None = None, + reduction: str = "mean", + avg_factor: float | None = None, +) -> Tensor: + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor | None): Element-wise weights. + Defaults to None. + reduction (str): Same as built-in losses of PyTorch. + Defaults to 'mean'. + avg_factor (float | None): Average factor when + computing the mean of losses. Defaults to None. + + Returns: + Tensor: Processed loss values. 
+ """ + # if weight is specified, apply element-wise weight + if weight is not None: + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + loss = reduce_loss(loss, reduction) + # if reduction is mean, then average the loss by avg_factor + elif reduction == "mean": + # Avoid causing ZeroDivisionError when avg_factor is 0.0, + # i.e., all labels of an image belong to ignore index. + eps = torch.finfo(torch.float32).eps + loss = loss.sum() / (avg_factor + eps) + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != "none": + msg = "avg_factor can not be used with reduction='sum'" + raise ValueError(msg) + return loss + + +def weighted_loss(loss_func: Callable) -> Callable: + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + avg_factor=None, **kwargs)`. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, avg_factor=2) + tensor(1.5000) + """ + + @functools.wraps(loss_func) + def wrapper( + pred: Tensor, + target: Tensor, + weight: Tensor | None = None, + reduction: str = "mean", + avg_factor: int | None = None, + **kwargs, + ) -> Tensor: + """Wrapper for weighted loss. + + Args: + pred (Tensor): The prediction. + target (Tensor): Target bboxes. + weight (Tensor | None): The weight of loss for each + prediction. Defaults to None. + reduction (str): Options are "none", "mean" and "sum". + Defaults to 'mean'. + avg_factor (int | None): Average factor that is used + to average the loss. Defaults to None. + + Returns: + Tensor: Loss tensor. + """ + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + return weight_reduce_loss(loss, weight, reduction, avg_factor) + + return wrapper + + +@weighted_loss +def smooth_l1_loss(pred: Tensor, target: Tensor, beta: float = 1.0) -> Tensor: + """Smooth L1 loss. + + Args: + pred (Tensor): The prediction. + target (Tensor): The learning target of the prediction. + beta (float): The threshold in the piecewise function. + Defaults to 1.0. + + Returns: + Tensor: Calculated loss + """ + if target.numel() == 0: + return pred.sum() * 0 + + diff = torch.abs(pred - target) + return torch.where(diff < beta, 0.5 * diff * diff / beta, diff - 0.5 * beta) diff --git a/src/otx/algo/detection/structures/__init__.py b/src/otx/algo/detection/utils/__init__.py similarity index 100% rename from src/otx/algo/detection/structures/__init__.py rename to src/otx/algo/detection/utils/__init__.py diff --git a/src/otx/algo/detection/utils/bbox_overlaps.py b/src/otx/algo/detection/utils/bbox_overlaps.py new file mode 100644 index 00000000000..d250e3102c1 --- /dev/null +++ b/src/otx/algo/detection/utils/bbox_overlaps.py @@ -0,0 +1,184 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+"""Overlap between bboxes calculation function.""" + +from __future__ import annotations + +import torch + + +def fp16_clamp(x: torch.Tensor, min_value: int | None = None, max_value: int | None = None) -> torch.Tensor: + """Clamp for cpu float16, tensor fp16 has no clamp implementation.""" + if not x.is_cuda and x.dtype == torch.float16: + return x.float().clamp(min_value, max_value).half() + + return x.clamp(min_value, max_value) + + +def bbox_overlaps( + bboxes1: torch.Tensor, + bboxes2: torch.Tensor, + mode: str = "iou", + is_aligned: bool = False, + eps: float = 1e-6, +) -> torch.Tensor: + """Calculate overlap between two set of bboxes. + + FP16 Contributed by https://github.com/open-mmlab/mmdetection/pull/4889 + Note: + Assume bboxes1 is M x 4, bboxes2 is N x 4, when mode is 'iou', + there are some new generated variable when calculating IOU + using bbox_overlaps function: + + 1) is_aligned is False + area1: M x 1 + area2: N x 1 + lt: M x N x 2 + rb: M x N x 2 + wh: M x N x 2 + overlap: M x N x 1 + union: M x N x 1 + ious: M x N x 1 + + Total memory: + S = (9 x N x M + N + M) * 4 Byte, + + When using FP16, we can reduce: + R = (9 x N x M + N + M) * 4 / 2 Byte + R large than (N + M) * 4 * 2 is always true when N and M >= 1. + Obviously, N + M <= N * M < 3 * N * M, when N >=2 and M >=2, + N + 1 < 3 * N, when N or M is 1. + + Given M = 40 (ground truth), N = 400000 (three anchor boxes + in per grid, FPN, R-CNNs), + R = 275 MB (one times) + + A special case (dense detection), M = 512 (ground truth), + R = 3516 MB = 3.43 GB + + When the batch size is B, reduce: + B x R + + Therefore, CUDA memory runs out frequently. + + Experiments on GeForce RTX 2080Ti (11019 MiB): + + | dtype | M | N | Use | Real | Ideal | + |:----:|:----:|:----:|:----:|:----:|:----:| + | FP32 | 512 | 400000 | 8020 MiB | -- | -- | + | FP16 | 512 | 400000 | 4504 MiB | 3516 MiB | 3516 MiB | + | FP32 | 40 | 400000 | 1540 MiB | -- | -- | + | FP16 | 40 | 400000 | 1264 MiB | 276MiB | 275 MiB | + + 2) is_aligned is True + area1: N x 1 + area2: N x 1 + lt: N x 2 + rb: N x 2 + wh: N x 2 + overlap: N x 1 + union: N x 1 + ious: N x 1 + + Total memory: + S = 11 x N * 4 Byte + + When using FP16, we can reduce: + R = 11 x N * 4 / 2 Byte + + So do the 'giou' (large than 'iou'). + + Time-wise, FP16 is generally faster than FP32. + + When gpu_assign_thr is not -1, it takes more time on cpu + but not reduce memory. + There, we can reduce half the memory and keep the speed. + + If ``is_aligned`` is ``False``, then calculate the overlaps between each + bbox of bboxes1 and bboxes2, otherwise the overlaps between each aligned + pair of bboxes1 and bboxes2. + + Args: + bboxes1 (Tensor): shape (B, m, 4) in format or empty. + bboxes2 (Tensor): shape (B, n, 4) in format or empty. + B indicates the batch dim, in shape (B1, B2, ..., Bn). + If ``is_aligned`` is ``True``, then m and n must be equal. + mode (str): "iou" (intersection over union), "iof" (intersection over + foreground) or "giou" (generalized intersection over union). + Default "iou". + is_aligned (bool, optional): If True, then m and n must be equal. + Default False. + eps (float, optional): A value added to the denominator for numerical + stability. Default 1e-6. 
+ + Returns: + Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,) + + Example: + >>> bboxes1 = torch.FloatTensor([ + >>> [0, 0, 10, 10], + >>> [10, 10, 20, 20], + >>> [32, 32, 38, 42], + >>> ]) + >>> bboxes2 = torch.FloatTensor([ + >>> [0, 0, 10, 20], + >>> [0, 10, 10, 19], + >>> [10, 10, 20, 20], + >>> ]) + >>> overlaps = bbox_overlaps(bboxes1, bboxes2) + >>> assert overlaps.shape == (3, 3) + >>> overlaps = bbox_overlaps(bboxes1, bboxes2, is_aligned=True) + >>> assert overlaps.shape == (3, ) + + Example: + >>> empty = torch.empty(0, 4) + >>> nonempty = torch.FloatTensor([[0, 0, 10, 9]]) + >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1) + >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0) + >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0) + """ + batch_shape = bboxes1.shape[:-2] + + rows = bboxes1.size(-2) + cols = bboxes2.size(-2) + + if rows * cols == 0: + if is_aligned: + return bboxes1.new((*batch_shape, rows)) + return bboxes1.new((*batch_shape, rows, cols)) + + area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1]) + area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1]) + + if is_aligned: + lt = torch.max(bboxes1[..., :2], bboxes2[..., :2]) # [B, rows, 2] + rb = torch.min(bboxes1[..., 2:], bboxes2[..., 2:]) # [B, rows, 2] + + wh = fp16_clamp(rb - lt, min_value=0) + overlap = wh[..., 0] * wh[..., 1] + + union = area1 + area2 - overlap if mode in ["iou", "giou"] else area1 + if mode == "giou": + enclosed_lt = torch.min(bboxes1[..., :2], bboxes2[..., :2]) + enclosed_rb = torch.max(bboxes1[..., 2:], bboxes2[..., 2:]) + else: + lt = torch.max(bboxes1[..., :, None, :2], bboxes2[..., None, :, :2]) # [B, rows, cols, 2] + rb = torch.min(bboxes1[..., :, None, 2:], bboxes2[..., None, :, 2:]) # [B, rows, cols, 2] + + wh = fp16_clamp(rb - lt, min_value=0) + overlap = wh[..., 0] * wh[..., 1] + + union = area1[..., None] + area2[..., None, :] - overlap if mode in ["iou", "giou"] else area1[..., None] + if mode == "giou": + enclosed_lt = torch.min(bboxes1[..., :, None, :2], bboxes2[..., None, :, :2]) + enclosed_rb = torch.max(bboxes1[..., :, None, 2:], bboxes2[..., None, :, 2:]) + + eps = union.new_tensor([eps]) + union = torch.max(union, eps) + ious = overlap / union + if mode in ["iou", "iof"]: + return ious + # calculate gious + enclose_wh = fp16_clamp(enclosed_rb - enclosed_lt, min_value=0) + enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1] + enclose_area = torch.max(enclose_area, eps) + return ious - (enclose_area - union) / enclose_area diff --git a/src/otx/algo/detection/structures/structures.py b/src/otx/algo/detection/utils/structures.py similarity index 78% rename from src/otx/algo/detection/structures/structures.py rename to src/otx/algo/detection/utils/structures.py index 4bab3c4f79c..5a3b57c011a 100644 --- a/src/otx/algo/detection/structures/structures.py +++ b/src/otx/algo/detection/utils/structures.py @@ -23,24 +23,6 @@ class AssignResult: assigned truth box. labels (Tensor): If specified, for each predicted box indicates the category label of the assigned truth box. - - Example: - >>> # An assign result between 4 predicted boxes and 9 true boxes - >>> # where only two boxes were assigned. 
- >>> num_gts = 9 - >>> max_overlaps = torch.LongTensor([0, .5, .9, 0]) - >>> gt_inds = torch.LongTensor([-1, 1, 2, 0]) - >>> labels = torch.LongTensor([0, 3, 4, 0]) - >>> self = AssignResult(num_gts, gt_inds, max_overlaps, labels) - >>> print(str(self)) # xdoctest: +IGNORE_WANT - - >>> # Force addition of gt labels (when adding gt as proposals) - >>> new_labels = torch.LongTensor([3, 4, 5]) - >>> self.add_gt_(new_labels) - >>> print(str(self)) # xdoctest: +IGNORE_WANT - """ def __init__(self, num_gts: int, gt_inds: Tensor, max_overlaps: Tensor, labels: Tensor) -> None: @@ -105,25 +87,6 @@ class SamplingResult: avg_factor_with_neg (bool): If True, ``avg_factor`` equal to the number of total priors; Otherwise, it is the number of positive priors. Defaults to True. - - Example: - >>> # xdoctest: +IGNORE_WANT - >>> from mmdet.models.task_modules.samplers.sampling_result import * # NOQA - >>> self = SamplingResult.random(rng=10) - >>> print(f'self = {self}') - self = """ def __init__( From 801394fe7ca19eeb449e11ea79040c16a715bc94 Mon Sep 17 00:00:00 2001 From: jaegukhyun Date: Thu, 11 Apr 2024 16:27:55 +0900 Subject: [PATCH 3/5] Decouple ssd_head --- src/otx/algo/detection/heads/anchor_head.py | 5 +- .../algo/detection/heads/custom_ssd_head.py | 49 +--- .../detection/losses/cross_entropy_loss.py | 272 ++++++++++++++++++ .../detection/mmconfigs/ssd_mobilenetv2.yaml | 2 - src/otx/algo/detection/utils/utils.py | 30 ++ 5 files changed, 319 insertions(+), 39 deletions(-) create mode 100644 src/otx/algo/detection/losses/cross_entropy_loss.py create mode 100644 src/otx/algo/detection/utils/utils.py diff --git a/src/otx/algo/detection/heads/anchor_head.py b/src/otx/algo/detection/heads/anchor_head.py index cba4104dd13..23057e50210 100644 --- a/src/otx/algo/detection/heads/anchor_head.py +++ b/src/otx/algo/detection/heads/anchor_head.py @@ -10,7 +10,7 @@ import torch from mmdet.models.task_modules.prior_generators import anchor_inside_flags -from mmdet.models.utils import images_to_levels, multi_apply, unmap +from mmdet.models.utils import images_to_levels, unmap from mmdet.registry import MODELS, TASK_UTILS from mmengine.structures import InstanceData from torch import Tensor, nn @@ -18,6 +18,7 @@ from otx.algo.detection.heads.base_head import BaseDenseHead from otx.algo.detection.heads.base_sampler import PseudoSampler from otx.algo.detection.heads.custom_anchor_generator import AnchorGenerator +from otx.algo.detection.utils.utils import multi_apply if TYPE_CHECKING: from mmdet.utils import InstanceList, OptConfigType, OptInstanceList, OptMultiConfig @@ -141,7 +142,7 @@ def forward_single(self, x: Tensor) -> tuple[Tensor, Tensor]: bbox_pred = self.conv_reg(x) return cls_score, bbox_pred - def forward(self, x: tuple[Tensor]) -> tuple[list[Tensor], list[Tensor]]: + def forward(self, x: tuple[Tensor]) -> tuple: """Forward features from the upstream network. 
Args: diff --git a/src/otx/algo/detection/heads/custom_ssd_head.py b/src/otx/algo/detection/heads/custom_ssd_head.py index d7e963099da..2861fe874d9 100644 --- a/src/otx/algo/detection/heads/custom_ssd_head.py +++ b/src/otx/algo/detection/heads/custom_ssd_head.py @@ -7,8 +7,6 @@ from typing import TYPE_CHECKING import torch -from mmdet.models.utils import multi_apply -from mmdet.registry import MODELS from torch import Tensor, nn from otx.algo.detection.heads.anchor_head import AnchorHead @@ -16,11 +14,12 @@ from otx.algo.detection.heads.custom_anchor_generator import SSDAnchorGeneratorClustered from otx.algo.detection.heads.delta_xywh_bbox_coder import DeltaXYWHBBoxCoder from otx.algo.detection.heads.max_iou_assigner import MaxIoUAssigner +from otx.algo.detection.losses.cross_entropy_loss import CrossEntropyLoss from otx.algo.detection.losses.weighted_loss import smooth_l1_loss +from otx.algo.detection.utils.utils import multi_apply if TYPE_CHECKING: - from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptInstanceList - from mmengine.config import Config + from mmengine.config import ConfigDict, InstanceData # This class and its supporting functions below lightly adapted from the mmdet SSDHead available at: @@ -39,12 +38,6 @@ class SSDHead(AnchorHead): > 0. Defaults to 256. use_depthwise (bool): Whether to use DepthwiseSeparableConv. Defaults to False. - conv_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct - and config conv layer. Defaults to None. - norm_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct - and config norm layer. Defaults to None. - act_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct - and config activation layer. Defaults to None. anchor_generator (:obj:`ConfigDict` or dict): Config dict for anchor generator. bbox_coder (:obj:`ConfigDict` or dict): Config of bounding box coder. @@ -63,21 +56,18 @@ class SSDHead(AnchorHead): def __init__( self, - anchor_generator: ConfigType, - bbox_coder: ConfigType, - init_cfg: MultiConfig, - act_cfg: ConfigType, + anchor_generator: ConfigDict | dict, + bbox_coder: ConfigDict | dict, + init_cfg: ConfigDict | dict | list[ConfigDict] | list[dict], + act_cfg: ConfigDict | dict, num_classes: int = 80, in_channels: tuple[int, ...] = (512, 1024, 512, 256, 256, 256), stacked_convs: int = 0, feat_channels: int = 256, use_depthwise: bool = False, - conv_cfg: ConfigType | None = None, - norm_cfg: ConfigType | None = None, reg_decoded_bbox: bool = False, - train_cfg: ConfigType | None = None, - test_cfg: ConfigType | None = None, - loss_cls: Config | dict | None = None, + train_cfg: ConfigDict | dict | None = None, + test_cfg: ConfigDict | dict | None = None, ) -> None: super(AnchorHead, self).__init__(init_cfg=init_cfg) self.num_classes = num_classes @@ -85,9 +75,7 @@ def __init__( self.stacked_convs = stacked_convs self.feat_channels = feat_channels self.use_depthwise = use_depthwise - self.conv_cfg = conv_cfg - self.norm_cfg = norm_cfg - self.act_cfg = act_cfg + self.act_cfg = act_cfg # TODO(Jaeguk): act_cfg will be deprecated after implementing export. 
self.cls_out_channels = num_classes + 1 # add background class anchor_generator.pop("type") @@ -98,14 +86,7 @@ def __init__( # heads but a list of int in SSDHead self.num_base_priors = self.prior_generator.num_base_priors - if loss_cls is None: - loss_cls = { - "type": "CrossEntropyLoss", - "use_sigmoid": False, - "reduction": "none", - "loss_weight": 1.0, - } - self.loss_cls = MODELS.build(loss_cls) + self.loss_cls = CrossEntropyLoss(use_sigmoid=False, reduction="none", loss_weight=1.0) self._init_layers() @@ -218,9 +199,9 @@ def loss_by_feat( self, cls_scores: list[Tensor], bbox_preds: list[Tensor], - batch_gt_instances: InstanceList, + batch_gt_instances: list[InstanceData], batch_img_metas: list[dict], - batch_gt_instances_ignore: OptInstanceList = None, + batch_gt_instances_ignore: list[InstanceData] | None = None, ) -> dict[str, list[Tensor]]: """Compute losses of the head. @@ -298,11 +279,9 @@ def _init_layers(self) -> None: self.cls_convs = nn.ModuleList() self.reg_convs = nn.ModuleList() - activation_config = self.act_cfg.copy() - activation_config.setdefault("inplace", True) for in_channel, num_base_priors in zip(self.in_channels, self.num_base_priors): if self.use_depthwise: - activation_layer = MODELS.build(activation_config) + activation_layer = nn.ReLU(inplace=True) self.reg_convs.append( nn.Sequential( diff --git a/src/otx/algo/detection/losses/cross_entropy_loss.py b/src/otx/algo/detection/losses/cross_entropy_loss.py new file mode 100644 index 00000000000..81c3be1a1b1 --- /dev/null +++ b/src/otx/algo/detection/losses/cross_entropy_loss.py @@ -0,0 +1,272 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Base Cross Entropy Loss implementation from mmdet.""" + +from __future__ import annotations + +import torch +from torch import nn + +from otx.algo.detection.losses.weighted_loss import weight_reduce_loss + + +# All of the methods and classes below come from mmdet, and are slightly modified. +# https://github.com/open-mmlab/mmdetection/blob/ecac3a77becc63f23d9f6980b2a36f86acd00a8a/mmdet/models/losses/cross_entropy_loss.py +def cross_entropy( + pred: torch.Tensor, + label: torch.Tensor, + weight: torch.Tensor | None = None, + reduction: str = "mean", + avg_factor: int | None = None, + class_weight: list[float] | None = None, + ignore_index: int = -100, + avg_non_ignore: bool = False, +) -> torch.Tensor: + """Calculate the CrossEntropy loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + label (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + reduction (str): The method used to reduce the loss. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (int): The label index to be ignored. + Default: -100. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. 
+ + Returns: + torch.Tensor: The calculated loss + """ + loss = nn.functional.cross_entropy(pred, label, weight=class_weight, reduction="none", ignore_index=ignore_index) + + # average loss over non-ignored elements + # pytorch's official cross_entropy average loss over non-ignored elements + # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 + if (avg_factor is None) and avg_non_ignore and reduction == "mean": + avg_factor = label.numel() - (label == ignore_index).sum().item() + + # apply weights and do the reduction + if weight is not None: + weight = weight.float() + return weight_reduce_loss(loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + + +def _expand_onehot_labels( + labels: torch.Tensor, + label_weights: torch.Tensor, + label_channels: int, + ignore_index: int, +) -> tuple[torch.Tensor, ...]: + """Expand onehot labels to match the size of prediction.""" + bin_labels = labels.new_full((labels.size(0), label_channels), 0) + valid_mask = (labels >= 0) & (labels != ignore_index) + inds = torch.nonzero(valid_mask & (labels < label_channels), as_tuple=False) + + if inds.numel() > 0: + bin_labels[inds, labels[inds]] = 1 + + valid_mask = valid_mask.view(-1, 1).expand(labels.size(0), label_channels).float() + bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels) + bin_label_weights *= valid_mask + + return bin_labels, bin_label_weights, valid_mask + + +def binary_cross_entropy( + pred: torch.Tensor, + label: torch.Tensor, + weight: torch.Tensor | None = None, + reduction: str = "mean", + avg_factor: int | None = None, + class_weight: list[float] | None = None, + ignore_index: int = -100, + avg_non_ignore: bool = False, +) -> torch.Tensor: + """Calculate the binary CrossEntropy loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, 1) or (N, ). + When the shape of pred is (N, 1), label will be expanded to + one-hot format, and when the shape of pred is (N, ), label + will not be expanded to one-hot format. + label (torch.Tensor): The learning label of the prediction, + with shape (N, ). + weight (torch.Tensor, None): Sample-wise loss weight. + reduction (str): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (int): The label index to be ignored. + Default: -100. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. + + Returns: + torch.Tensor: The calculated loss. + """ + if pred.dim() != label.dim(): + label, weight, valid_mask = _expand_onehot_labels(label, weight, pred.size(-1), ignore_index) + else: + # should mask out the ignored elements + valid_mask = ((label >= 0) & (label != ignore_index)).float() + # The inplace writing method will have a mismatched broadcast + # shape error if the weight and valid_mask dimensions + # are inconsistent such as (B,N,1) and (B,N,C). 
+ weight = weight * valid_mask if weight is not None else valid_mask + + # average loss over non-ignored elements + if (avg_factor is None) and avg_non_ignore and reduction == "mean": + avg_factor = valid_mask.sum().item() + + # weighted element-wise losses + weight = weight.float() + loss = nn.functional.binary_cross_entropy_with_logits( + pred, + label.float(), + pos_weight=class_weight, + reduction="none", + ) + # do the reduction for the weighted loss + return weight_reduce_loss(loss, weight, reduction=reduction, avg_factor=avg_factor) + + +def mask_cross_entropy( + pred: torch.Tensor, + target: torch.Tensor, + label: torch.Tensor, + class_weight: list[float] | None = None, + **kwargs, # noqa: ARG001 +) -> torch.Tensor: + """Calculate the CrossEntropy loss for masks. + + Args: + pred (torch.Tensor): The prediction with shape (N, C, *), C is the + number of classes. The trailing * indicates arbitrary shape. + target (torch.Tensor): The learning label of the prediction. + label (torch.Tensor): ``label`` indicates the class label of the mask + corresponding object. This will be used to select the mask in the + of the class which the object belongs to when the mask prediction + if not class-agnostic. + class_weight (list[float], None): The weight for each class. + + Returns: + torch.Tensor: The calculated loss + + Example: + >>> N, C = 3, 11 + >>> H, W = 2, 2 + >>> pred = torch.randn(N, C, H, W) * 1000 + >>> target = torch.rand(N, H, W) + >>> label = torch.randint(0, C, size=(N,)) + >>> reduction = 'mean' + >>> avg_factor = None + >>> class_weights = None + >>> loss = mask_cross_entropy(pred, target, label, reduction, + >>> avg_factor, class_weights) + >>> assert loss.shape == (1,) + """ + num_rois = pred.size()[0] + inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device) + pred_slice = pred[inds, label].squeeze(1) + return nn.functional.binary_cross_entropy_with_logits( + pred_slice, + target, + weight=class_weight, + reduction="mean", + )[None] + + +class CrossEntropyLoss(nn.Module): + """Base Cross Entropy Loss implementation from mmdet.""" + + def __init__( + self, + use_sigmoid: bool = False, + use_mask: bool = False, + reduction: str = "mean", + class_weight: list[float] | None = None, + loss_weight: float = 1.0, + avg_non_ignore: bool = False, + ): + """CrossEntropyLoss. + + Args: + use_sigmoid (bool): Whether the prediction uses sigmoid + of softmax. Defaults to False. + use_mask (bool): Whether to use mask cross entropy loss. + Defaults to False. + reduction (str): . Defaults to 'mean'. + Options are "none", "mean" and "sum". + class_weight (list[float], optional): Weight of each class. + Defaults to None. + loss_weight (float): Weight of the loss. Defaults to 1.0. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. 
+ """ + super().__init__() + self.use_sigmoid = use_sigmoid + self.use_mask = use_mask + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = class_weight + self.avg_non_ignore = avg_non_ignore + + if self.use_sigmoid: + self.cls_criterion = binary_cross_entropy + elif self.use_mask: + self.cls_criterion = mask_cross_entropy # type: ignore[assignment] + else: + self.cls_criterion = cross_entropy + + def extra_repr(self) -> str: + """Extra repr.""" + return f"avg_non_ignore={self.avg_non_ignore}" + + def forward( + self, + cls_score: torch.Tensor, + label: torch.Tensor, + weight: torch.Tensor | None = None, + avg_factor: int | None = None, + reduction_override: str | None = None, + ignore_index: int = -100, + **kwargs, + ) -> torch.Tensor: + """Forward function. + + Args: + cls_score (torch.Tensor): The prediction. + label (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, None): Sample-wise loss weight. + avg_factor (int, None): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, None): The method used to reduce the + loss. Options are "none", "mean" and "sum". + ignore_index (int): The label index to be ignored. + Default: -100. + + Returns: + torch.Tensor: The calculated loss. + """ + reduction = reduction_override if reduction_override else self.reduction + + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight, device=cls_score.device) + else: + class_weight = None + return self.loss_weight * self.cls_criterion( + cls_score, + label, + weight, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + ignore_index=ignore_index, + avg_non_ignore=self.avg_non_ignore, + **kwargs, + ) diff --git a/src/otx/algo/detection/mmconfigs/ssd_mobilenetv2.yaml b/src/otx/algo/detection/mmconfigs/ssd_mobilenetv2.yaml index 93b40f0df85..a6685ee4cac 100644 --- a/src/otx/algo/detection/mmconfigs/ssd_mobilenetv2.yaml +++ b/src/otx/algo/detection/mmconfigs/ssd_mobilenetv2.yaml @@ -50,8 +50,6 @@ bbox_head: - 96 - 320 use_depthwise: true - norm_cfg: - type: BN act_cfg: type: ReLU init_cfg: diff --git a/src/otx/algo/detection/utils/utils.py b/src/otx/algo/detection/utils/utils.py new file mode 100644 index 00000000000..c68150f8c3b --- /dev/null +++ b/src/otx/algo/detection/utils/utils.py @@ -0,0 +1,30 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Utils for otx detection algo.""" + +from __future__ import annotations + +from functools import partial +from typing import Callable + + +def multi_apply(func: Callable, *args, **kwargs) -> tuple: + """Apply function to a list of arguments. + + Note: + This function applies the ``func`` to multiple inputs and + map the multiple outputs of the ``func`` into different + list. Each list contains the same type of outputs corresponding + to different inputs. 
+ + Args: + func (Function): A function that will be applied to a list of + arguments + + Returns: + tuple(list): A tuple containing multiple list, each list contains \ + a kind of returned results by the function + """ + pfunc = partial(func, **kwargs) if kwargs else func + map_results = map(pfunc, *args) # type: ignore[call-overload] + return tuple(map(list, zip(*map_results))) From feddfbde1c9000db13fe6150fb00e787eee3034f Mon Sep 17 00:00:00 2001 From: jaegukhyun Date: Thu, 11 Apr 2024 17:00:43 +0900 Subject: [PATCH 4/5] Decoupling anchor head --- src/otx/algo/detection/heads/anchor_head.py | 20 +++--- .../algo/detection/heads/custom_ssd_head.py | 2 +- src/otx/algo/detection/utils/utils.py | 64 +++++++++++++++++++ 3 files changed, 74 insertions(+), 12 deletions(-) diff --git a/src/otx/algo/detection/heads/anchor_head.py b/src/otx/algo/detection/heads/anchor_head.py index 23057e50210..1f56adb373c 100644 --- a/src/otx/algo/detection/heads/anchor_head.py +++ b/src/otx/algo/detection/heads/anchor_head.py @@ -9,8 +9,6 @@ from typing import TYPE_CHECKING import torch -from mmdet.models.task_modules.prior_generators import anchor_inside_flags -from mmdet.models.utils import images_to_levels, unmap from mmdet.registry import MODELS, TASK_UTILS from mmengine.structures import InstanceData from torch import Tensor, nn @@ -18,10 +16,10 @@ from otx.algo.detection.heads.base_head import BaseDenseHead from otx.algo.detection.heads.base_sampler import PseudoSampler from otx.algo.detection.heads.custom_anchor_generator import AnchorGenerator -from otx.algo.detection.utils.utils import multi_apply +from otx.algo.detection.utils.utils import anchor_inside_flags, images_to_levels, multi_apply, unmap if TYPE_CHECKING: - from mmdet.utils import InstanceList, OptConfigType, OptInstanceList, OptMultiConfig + from mmengine import ConfigDict # This class and its supporting functions below lightly adapted from the mmdet AnchorHead available at: @@ -56,11 +54,11 @@ def __init__( bbox_coder: dict, loss_cls: dict, loss_bbox: dict, + train_cfg: ConfigDict | dict, feat_channels: int = 256, reg_decoded_bbox: bool = False, - train_cfg: OptConfigType = None, - test_cfg: OptConfigType = None, - init_cfg: OptMultiConfig = None, + test_cfg: ConfigDict | dict | None = None, + init_cfg: ConfigDict | dict | list[ConfigDict] | list[dict] | None = None, ) -> None: super().__init__(init_cfg=init_cfg) self.in_channels = in_channels @@ -302,9 +300,9 @@ def get_targets( self, anchor_list: list[list[Tensor]], valid_flag_list: list[list[Tensor]], - batch_gt_instances: InstanceList, + batch_gt_instances: list[InstanceData], batch_img_metas: list[dict], - batch_gt_instances_ignore: OptInstanceList = None, + batch_gt_instances_ignore: list[InstanceData] | None = None, unmap_outputs: bool = True, ) -> tuple: """Compute regression and classification targets for anchors in multiple images. @@ -461,9 +459,9 @@ def loss_by_feat( self, cls_scores: list[Tensor], bbox_preds: list[Tensor], - batch_gt_instances: InstanceList, + batch_gt_instances: list[InstanceData], batch_img_metas: list[dict], - batch_gt_instances_ignore: OptInstanceList = None, + batch_gt_instances_ignore: list[InstanceData] | None = None, ) -> dict: """Calculate the loss based on the features extracted by the detection head. 
diff --git a/src/otx/algo/detection/heads/custom_ssd_head.py b/src/otx/algo/detection/heads/custom_ssd_head.py index 2861fe874d9..fe0ccc4bf1a 100644 --- a/src/otx/algo/detection/heads/custom_ssd_head.py +++ b/src/otx/algo/detection/heads/custom_ssd_head.py @@ -60,13 +60,13 @@ def __init__( bbox_coder: ConfigDict | dict, init_cfg: ConfigDict | dict | list[ConfigDict] | list[dict], act_cfg: ConfigDict | dict, + train_cfg: ConfigDict | dict, num_classes: int = 80, in_channels: tuple[int, ...] = (512, 1024, 512, 256, 256, 256), stacked_convs: int = 0, feat_channels: int = 256, use_depthwise: bool = False, reg_decoded_bbox: bool = False, - train_cfg: ConfigDict | dict | None = None, test_cfg: ConfigDict | dict | None = None, ) -> None: super(AnchorHead, self).__init__(init_cfg=init_cfg) diff --git a/src/otx/algo/detection/utils/utils.py b/src/otx/algo/detection/utils/utils.py index c68150f8c3b..3abcb1178dc 100644 --- a/src/otx/algo/detection/utils/utils.py +++ b/src/otx/algo/detection/utils/utils.py @@ -7,6 +7,9 @@ from functools import partial from typing import Callable +import torch +from torch import Tensor + def multi_apply(func: Callable, *args, **kwargs) -> tuple: """Apply function to a list of arguments. @@ -28,3 +31,64 @@ def multi_apply(func: Callable, *args, **kwargs) -> tuple: pfunc = partial(func, **kwargs) if kwargs else func map_results = map(pfunc, *args) # type: ignore[call-overload] return tuple(map(list, zip(*map_results))) + + +def anchor_inside_flags( + flat_anchors: Tensor, + valid_flags: Tensor, + img_shape: tuple[int, ...], + allowed_border: int = 0, +) -> Tensor: + """Check whether the anchors are inside the border. + + Args: + flat_anchors (torch.Tensor): Flatten anchors, shape (n, 4). + valid_flags (torch.Tensor): An existing valid flags of anchors. + img_shape (tuple(int)): Shape of current image. + allowed_border (int): The border to allow the valid anchor. + Defaults to 0. + + Returns: + torch.Tensor: Flags indicating whether the anchors are inside a \ + valid range. + """ + img_h, img_w = img_shape[:2] + if allowed_border >= 0: + inside_flags = ( + valid_flags + & (flat_anchors[:, 0] >= -allowed_border) + & (flat_anchors[:, 1] >= -allowed_border) + & (flat_anchors[:, 2] < img_w + allowed_border) + & (flat_anchors[:, 3] < img_h + allowed_border) + ) + else: + inside_flags = valid_flags + return inside_flags + + +def images_to_levels(target: list[Tensor], num_levels: list[int]) -> list[Tensor]: + """Convert targets by image to targets by feature level. + + [target_img0, target_img1] -> [target_level0, target_level1, ...] 
+ """ + stacked_target = torch.stack(target, 0) + level_targets = [] + start = 0 + for n in num_levels: + end = start + n + # level_targets.append(target[:, start:end].squeeze(0)) + level_targets.append(stacked_target[:, start:end]) + start = end + return level_targets + + +def unmap(data: Tensor, count: int, inds: Tensor, fill: int = 0) -> Tensor: + """Unmap a subset of item (data) back to the original set of items (of size count).""" + if data.dim() == 1: + ret = data.new_full((count,), fill) + ret[inds.type(torch.bool)] = data + else: + new_size = (count,) + data.size()[1:] + ret = data.new_full(new_size, fill) + ret[inds.type(torch.bool), :] = data + return ret From d532ee69d030c45d3966d1e4e3c3c8199f8f6583 Mon Sep 17 00:00:00 2001 From: jaegukhyun Date: Thu, 11 Apr 2024 17:09:59 +0900 Subject: [PATCH 5/5] Fix unit tests --- .../algo/detection/heads/test_custom_ssd_head.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/unit/algo/detection/heads/test_custom_ssd_head.py b/tests/unit/algo/detection/heads/test_custom_ssd_head.py index 7d95e3f4591..e2f23e6a019 100644 --- a/tests/unit/algo/detection/heads/test_custom_ssd_head.py +++ b/tests/unit/algo/detection/heads/test_custom_ssd_head.py @@ -2,8 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 """Test of CustomSSDHead.""" -from mmdet.models.losses.cross_entropy_loss import CrossEntropyLoss from otx.algo.detection.heads.custom_ssd_head import SSDHead +from otx.algo.detection.losses.cross_entropy_loss import CrossEntropyLoss class TestSSDHead: @@ -25,6 +25,20 @@ def test_init(self, mocker) -> None: "target_means": [0.0, 0.0, 0.0, 0.0], "target_stds": [0.1, 0.1, 0.1, 0.1], }, + train_cfg={ + "assigner": { + "type": "MaxIoUAssigner", + "pos_iou_thr": 0.4, + "neg_iou_thr": 0.4, + }, + "smoothl1_beta": 1.0, + "allowed_border": -1, + "pos_weight": -1, + "neg_pos_ratio": 3, + "debug": False, + "use_giou": False, + "use_focal": False, + }, ) assert isinstance(self.head.loss_cls, CrossEntropyLoss)