Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decoupling mmdet structures Part 1. #3301

Merged
merged 5 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 14 additions & 18 deletions src/otx/algo/detection/heads/anchor_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,17 @@
from typing import TYPE_CHECKING

import torch
from mmdet.models.task_modules.prior_generators import anchor_inside_flags
from mmdet.models.utils import images_to_levels, multi_apply, unmap
from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures.bbox import BaseBoxes, cat_boxes, get_box_tensor
from mmengine.structures import InstanceData
from torch import Tensor, nn

from otx.algo.detection.heads.base_head import BaseDenseHead
from otx.algo.detection.heads.base_sampler import PseudoSampler
from otx.algo.detection.heads.custom_anchor_generator import AnchorGenerator
from otx.algo.detection.utils.utils import anchor_inside_flags, images_to_levels, multi_apply, unmap

if TYPE_CHECKING:
from mmdet.utils import InstanceList, OptConfigType, OptInstanceList, OptMultiConfig
from mmengine import ConfigDict


# This class and its supporting functions below lightly adapted from the mmdet AnchorHead available at:
Expand Down Expand Up @@ -56,11 +54,11 @@ def __init__(
bbox_coder: dict,
loss_cls: dict,
loss_bbox: dict,
train_cfg: ConfigDict | dict,
feat_channels: int = 256,
reg_decoded_bbox: bool = False,
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
init_cfg: OptMultiConfig = None,
test_cfg: ConfigDict | dict | None = None,
init_cfg: ConfigDict | dict | list[ConfigDict] | list[dict] | None = None,
) -> None:
super().__init__(init_cfg=init_cfg)
self.in_channels = in_channels
Expand Down Expand Up @@ -142,7 +140,7 @@ def forward_single(self, x: Tensor) -> tuple[Tensor, Tensor]:
bbox_pred = self.conv_reg(x)
return cls_score, bbox_pred

def forward(self, x: tuple[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
def forward(self, x: tuple[Tensor]) -> tuple:
"""Forward features from the upstream network.

Args:
Expand Down Expand Up @@ -199,7 +197,7 @@ def get_anchors(

def _get_targets_single(
self,
flat_anchors: Tensor | BaseBoxes,
flat_anchors: Tensor,
valid_flags: Tensor,
gt_instances: InstanceData,
img_meta: dict,
Expand All @@ -209,7 +207,7 @@ def _get_targets_single(
"""Compute regression and classification targets for anchors in a single image.

Args:
flat_anchors (Tensor or :obj:`BaseBoxes`): Multi-level anchors
flat_anchors (Tensor): Multi-level anchors
of the image, which are concatenated into a single tensor
or box type of shape (num_anchors, 4)
valid_flags (Tensor): Multi level valid flags of the image,
Expand Down Expand Up @@ -277,7 +275,6 @@ def _get_targets_single(
pos_bbox_targets = self.bbox_coder.encode(sampling_result.pos_priors, sampling_result.pos_gt_bboxes)
else:
pos_bbox_targets = sampling_result.pos_gt_bboxes
pos_bbox_targets = get_box_tensor(pos_bbox_targets)
bbox_targets[pos_inds, :] = pos_bbox_targets
bbox_weights[pos_inds, :] = 1.0

Expand All @@ -303,9 +300,9 @@ def get_targets(
self,
anchor_list: list[list[Tensor]],
valid_flag_list: list[list[Tensor]],
batch_gt_instances: InstanceList,
batch_gt_instances: list[InstanceData],
batch_img_metas: list[dict],
batch_gt_instances_ignore: OptInstanceList = None,
batch_gt_instances_ignore: list[InstanceData] | None = None,
unmap_outputs: bool = True,
) -> tuple:
"""Compute regression and classification targets for anchors in multiple images.
Expand Down Expand Up @@ -364,7 +361,7 @@ def get_targets(
concat_anchor_list = []
concat_valid_flag_list = []
for i in range(num_imgs):
concat_anchor_list.append(cat_boxes(anchor_list[i]))
concat_anchor_list.append(torch.cat(anchor_list[i]))
concat_valid_flag_list.append(torch.cat(valid_flag_list[i]))

# compute targets for each image
Expand Down Expand Up @@ -455,17 +452,16 @@ def loss_by_feat_single(
# decodes the already encoded coordinates to absolute format.
anchors = anchors.reshape(-1, anchors.size(-1))
bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
bbox_pred = get_box_tensor(bbox_pred)
loss_bbox = self.loss_bbox(bbox_pred, bbox_targets, bbox_weights, avg_factor=avg_factor)
return loss_cls, loss_bbox

def loss_by_feat(
self,
cls_scores: list[Tensor],
bbox_preds: list[Tensor],
batch_gt_instances: InstanceList,
batch_gt_instances: list[InstanceData],
batch_img_metas: list[dict],
batch_gt_instances_ignore: OptInstanceList = None,
batch_gt_instances_ignore: list[InstanceData] | None = None,
) -> dict:
"""Calculate the loss based on the features extracted by the detection head.

Expand Down Expand Up @@ -504,7 +500,7 @@ def loss_by_feat(
# anchor number of multi levels
num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
# concat all level anchors and flags to a single tensor
concat_anchor_list = [cat_boxes(anchor) for anchor in anchor_list]
concat_anchor_list = [torch.cat(anchor) for anchor in anchor_list]
all_anchor_list = images_to_levels(concat_anchor_list, num_level_anchors)

losses_cls, losses_bbox = multi_apply(
Expand Down
34 changes: 4 additions & 30 deletions src/otx/algo/detection/heads/base_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
from abc import ABCMeta, abstractmethod

import torch
from mmdet.models.task_modules.assigners import AssignResult
from mmdet.models.task_modules.samplers.sampling_result import SamplingResult
from mmdet.structures.bbox import BaseBoxes, cat_boxes
from mmengine.structures import InstanceData

from otx.algo.detection.utils.structures import AssignResult, SamplingResult


class BaseSampler(metaclass=ABCMeta):
"""Base class of samplers.
Expand Down Expand Up @@ -72,26 +71,6 @@ def sample(

Returns:
:obj:`SamplingResult`: Sampling result.

Example:
>>> from mmengine.structures import InstanceData
>>> from mmdet.models.task_modules.samplers import RandomSampler,
>>> from mmdet.models.task_modules.assigners import AssignResult
>>> from mmdet.models.task_modules.samplers.
... sampling_result import ensure_rng, random_boxes
>>> rng = ensure_rng(None)
>>> assign_result = AssignResult.random(rng=rng)
>>> pred_instances = InstanceData()
>>> pred_instances.priors = random_boxes(assign_result.num_preds,
... rng=rng)
>>> gt_instances = InstanceData()
>>> gt_instances.bboxes = random_boxes(assign_result.num_gts,
... rng=rng)
>>> gt_instances.labels = torch.randint(
... 0, 5, (assign_result.num_gts,), dtype=torch.long)
>>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1,
>>> add_gt_as_proposals=False)
>>> self = self.sample(assign_result, pred_instances, gt_instances)
"""
gt_bboxes = gt_instances.bboxes
priors = pred_instances.priors
Expand All @@ -101,13 +80,8 @@ def sample(

gt_flags = priors.new_zeros((priors.shape[0],), dtype=torch.uint8)
if self.add_gt_as_proposals and len(gt_bboxes) > 0:
# When `gt_bboxes` and `priors` are all box type, convert
# `gt_bboxes` type to `priors` type.
if isinstance(gt_bboxes, BaseBoxes) and isinstance(priors, BaseBoxes):
gt_bboxes_ = gt_bboxes.convert_to(type(priors))
else:
gt_bboxes_ = gt_bboxes
priors = cat_boxes([gt_bboxes_, priors], dim=0)
gt_bboxes_ = gt_bboxes
priors = torch.cat([gt_bboxes_, priors], dim=0)
assign_result.add_gt_(gt_labels)
gt_ones = priors.new_ones(gt_bboxes_.shape[0], dtype=torch.uint8)
gt_flags = torch.cat([gt_ones, gt_flags])
Expand Down
51 changes: 15 additions & 36 deletions src/otx/algo/detection/heads/custom_ssd_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,19 @@
from typing import TYPE_CHECKING

import torch
from mmdet.models.losses import smooth_l1_loss
from mmdet.models.utils import multi_apply
from mmdet.registry import MODELS
from torch import Tensor, nn

from otx.algo.detection.heads.anchor_head import AnchorHead
from otx.algo.detection.heads.base_sampler import PseudoSampler
from otx.algo.detection.heads.custom_anchor_generator import SSDAnchorGeneratorClustered
from otx.algo.detection.heads.delta_xywh_bbox_coder import DeltaXYWHBBoxCoder
from otx.algo.detection.heads.max_iou_assigner import MaxIoUAssigner
from otx.algo.detection.losses.cross_entropy_loss import CrossEntropyLoss
from otx.algo.detection.losses.weighted_loss import smooth_l1_loss
from otx.algo.detection.utils.utils import multi_apply

if TYPE_CHECKING:
from mmdet.utils import ConfigType, InstanceList, MultiConfig, OptInstanceList
from mmengine.config import Config
from mmengine.config import ConfigDict, InstanceData


# This class and its supporting functions below lightly adapted from the mmdet SSDHead available at:
Expand All @@ -39,12 +38,6 @@ class SSDHead(AnchorHead):
> 0. Defaults to 256.
use_depthwise (bool): Whether to use DepthwiseSeparableConv.
Defaults to False.
conv_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct
and config conv layer. Defaults to None.
norm_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct
and config norm layer. Defaults to None.
act_cfg (:obj:`ConfigDict` or dict, Optional): Dictionary to construct
and config activation layer. Defaults to None.
anchor_generator (:obj:`ConfigDict` or dict): Config dict for anchor
generator.
bbox_coder (:obj:`ConfigDict` or dict): Config of bounding box coder.
Expand All @@ -63,31 +56,26 @@ class SSDHead(AnchorHead):

def __init__(
self,
anchor_generator: ConfigType,
bbox_coder: ConfigType,
init_cfg: MultiConfig,
act_cfg: ConfigType,
anchor_generator: ConfigDict | dict,
bbox_coder: ConfigDict | dict,
init_cfg: ConfigDict | dict | list[ConfigDict] | list[dict],
act_cfg: ConfigDict | dict,
train_cfg: ConfigDict | dict,
num_classes: int = 80,
in_channels: tuple[int, ...] = (512, 1024, 512, 256, 256, 256),
stacked_convs: int = 0,
feat_channels: int = 256,
use_depthwise: bool = False,
conv_cfg: ConfigType | None = None,
norm_cfg: ConfigType | None = None,
reg_decoded_bbox: bool = False,
train_cfg: ConfigType | None = None,
test_cfg: ConfigType | None = None,
loss_cls: Config | dict | None = None,
test_cfg: ConfigDict | dict | None = None,
) -> None:
super(AnchorHead, self).__init__(init_cfg=init_cfg)
self.num_classes = num_classes
self.in_channels = in_channels
self.stacked_convs = stacked_convs
self.feat_channels = feat_channels
self.use_depthwise = use_depthwise
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.act_cfg = act_cfg # TODO(Jaeguk): act_cfg will be deprecated after implementing export.

self.cls_out_channels = num_classes + 1 # add background class
anchor_generator.pop("type")
Expand All @@ -98,14 +86,7 @@ def __init__(
# heads but a list of int in SSDHead
self.num_base_priors = self.prior_generator.num_base_priors

if loss_cls is None:
loss_cls = {
"type": "CrossEntropyLoss",
"use_sigmoid": False,
"reduction": "none",
"loss_weight": 1.0,
}
self.loss_cls = MODELS.build(loss_cls)
self.loss_cls = CrossEntropyLoss(use_sigmoid=False, reduction="none", loss_weight=1.0)

self._init_layers()

Expand Down Expand Up @@ -218,9 +199,9 @@ def loss_by_feat(
self,
cls_scores: list[Tensor],
bbox_preds: list[Tensor],
batch_gt_instances: InstanceList,
batch_gt_instances: list[InstanceData],
batch_img_metas: list[dict],
batch_gt_instances_ignore: OptInstanceList = None,
batch_gt_instances_ignore: list[InstanceData] | None = None,
) -> dict[str, list[Tensor]]:
"""Compute losses of the head.

Expand Down Expand Up @@ -298,11 +279,9 @@ def _init_layers(self) -> None:
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()

activation_config = self.act_cfg.copy()
activation_config.setdefault("inplace", True)
for in_channel, num_base_priors in zip(self.in_channels, self.num_base_priors):
if self.use_depthwise:
activation_layer = MODELS.build(activation_config)
activation_layer = nn.ReLU(inplace=True)

self.reg_convs.append(
nn.Sequential(
Expand Down
25 changes: 9 additions & 16 deletions src/otx/algo/detection/heads/delta_xywh_bbox_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import numpy as np
import torch
from mmdet.structures.bbox import BaseBoxes, HorizontalBoxes, get_box_tensor
from torch import Tensor


Expand Down Expand Up @@ -44,40 +43,39 @@ def __init__(
ctr_clamp: int = 32,
) -> None:
self.encode_size = encode_size
# TODO(Jaeguk): use_box_type should be deprecated.
self.use_box_type = use_box_type
self.means = target_means
self.stds = target_stds
self.clip_border = clip_border
self.add_ctr_clamp = add_ctr_clamp
self.ctr_clamp = ctr_clamp

def encode(self, bboxes: Tensor | BaseBoxes, gt_bboxes: Tensor | BaseBoxes) -> Tensor:
def encode(self, bboxes: Tensor, gt_bboxes: Tensor) -> Tensor:
"""Get box regression transformation deltas that can be used to transform the bboxes into the gt_bboxes.

Args:
bboxes (torch.Tensor or :obj:`BaseBoxes`): Source boxes,
bboxes (torch.Tensor): Source boxes,
e.g., object proposals.
gt_bboxes (torch.Tensor or :obj:`BaseBoxes`): Target of the
gt_bboxes (torch.Tensor): Target of the
transformation, e.g., ground-truth boxes.

Returns:
torch.Tensor: Box transformation deltas
"""
bboxes = get_box_tensor(bboxes)
gt_bboxes = get_box_tensor(gt_bboxes)
return bbox2delta(bboxes, gt_bboxes, self.means, self.stds)

def decode(
self,
bboxes: Tensor | BaseBoxes,
bboxes: Tensor,
pred_bboxes: Tensor,
max_shape: tuple[int, ...] | Tensor | tuple[tuple[int, ...], ...] | None = None,
wh_ratio_clip: float = 16 / 1000,
) -> Tensor | BaseBoxes:
) -> Tensor:
"""Apply transformation `pred_bboxes` to `boxes`.

Args:
bboxes (torch.Tensor or :obj:`BaseBoxes`): Basic boxes. Shape
bboxes (torch.Tensor): Basic boxes. Shape
(B, N, 4) or (N, 4)
pred_bboxes (Tensor): Encoded offsets with respect to each roi.
Has shape (B, N, num_classes * 4) or (B, N, 4) or
Expand All @@ -92,10 +90,9 @@ def decode(
width and height.

Returns:
Union[torch.Tensor, :obj:`BaseBoxes`]: Decoded boxes.
torch.Tensor: Decoded boxes.
"""
bboxes = get_box_tensor(bboxes)
decoded_bboxes = delta2bbox(
return delta2bbox(
bboxes,
pred_bboxes,
self.means,
Expand All @@ -107,10 +104,6 @@ def decode(
self.ctr_clamp,
)

if self.use_box_type:
decoded_bboxes = HorizontalBoxes(decoded_bboxes)
return decoded_bboxes


def bbox2delta(
proposals: Tensor,
Expand Down
Loading
Loading