Decoupling mmdet structures part2 (#3315)
* Decouple anchor generator

* Decouple base head

* Decouple SSD class

* Fix pre-commit
jaegukhyun authored Apr 16, 2024
1 parent 5fe4088 commit e45e8d3
Showing 4 changed files with 187 additions and 61 deletions.
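The change follows one pattern throughout: mmdet's alias types (`InstanceList`, `OptInstanceList`, `OptMultiConfig`, `SampleList`) and its box-structure helpers are replaced with explicit unions over `mmengine` types and plain tensor code. A minimal sketch of the typing side of that pattern, using the substitutions visible in the diffs below (the `InitCfg` alias and the stand-in function body are illustrative, not part of the commit):

from mmengine import ConfigDict
from mmengine.structures import InstanceData

# Substitutions applied across the changed files:
#   InstanceList     -> list[InstanceData]
#   OptInstanceList  -> list[InstanceData] | None
#   SampleList       -> list[InstanceData]
#   OptMultiConfig   -> ConfigDict | list[ConfigDict] | dict | list[dict] | None
InitCfg = ConfigDict | list[ConfigDict] | dict | list[dict] | None  # hypothetical shorthand

def get_positive_infos() -> list[InstanceData] | None:  # was: -> InstanceList
    return None  # stand-in body; the real method reads sampling results
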
56 changes: 32 additions & 24 deletions src/otx/algo/detection/heads/base_head.py
@@ -10,18 +10,15 @@
from typing import TYPE_CHECKING

import torch
from mmdet.models.utils import filter_scores_and_topk, select_single_mlvl, unpack_gt_instances
from mmdet.structures.bbox import cat_boxes, get_box_tensor, get_box_wh, scale_boxes
from mmcv.ops import batched_nms
from mmengine.model import constant_init
from mmengine.structures import InstanceData
from torch import Tensor, nn

from otx.algo.detection.ops.nms import batched_nms
from otx.algo.detection.utils.utils import filter_scores_and_topk, select_single_mlvl, unpack_gt_instances

if TYPE_CHECKING:
from mmdet.structures import SampleList
from mmdet.utils import InstanceList, OptInstanceList, OptMultiConfig
from mmengine.config import ConfigDict
from mmengine import ConfigDict


# This class and its supporting functions below lightly adapted from the mmdet BaseDenseHead available at:
@@ -63,7 +60,7 @@ class BaseDenseHead(nn.Module):
loss_and_predict(): forward() -> loss_by_feat() -> predict_by_feat()
"""

def __init__(self, init_cfg: OptMultiConfig = None) -> None:
def __init__(self, init_cfg: ConfigDict | list[ConfigDict] | dict | list[dict] | None = None) -> None:
super().__init__()

self._is_init = False
@@ -83,7 +80,7 @@ def init_weights(self) -> None:
if hasattr(m, "conv_offset"):
constant_init(m.conv_offset, 0)

def get_positive_infos(self) -> InstanceList:
def get_positive_infos(self) -> list[InstanceData] | None:
"""Get positive information from sampling results.
Returns:
@@ -106,7 +103,7 @@ def get_positive_infos(self) -> InstanceList:
positive_infos.append(pos_info)
return positive_infos

def loss(self, x: tuple[Tensor], batch_data_samples: SampleList) -> dict:
def loss(self, x: tuple[Tensor], batch_data_samples: list[InstanceData]) -> dict:
"""Perform forward propagation and loss calculation of the detection head.
Args:
@@ -132,18 +129,18 @@ def loss_by_feat(
self,
cls_scores: list[Tensor],
bbox_preds: list[Tensor],
batch_gt_instances: InstanceList,
batch_gt_instances: list[InstanceData],
batch_img_metas: list[dict],
batch_gt_instances_ignore: OptInstanceList = None,
batch_gt_instances_ignore: list[InstanceData] | None = None,
) -> dict:
"""Calculate the loss based on the features extracted by the detection head."""

def loss_and_predict(
self,
x: tuple[Tensor],
batch_data_samples: SampleList,
batch_data_samples: list[InstanceData],
proposal_cfg: ConfigDict | None = None,
) -> tuple[dict, InstanceList]:
) -> tuple[dict, list[InstanceData]]:
"""Perform forward propagation of the head, then calculate loss and predictions.
Args:
@@ -173,7 +170,12 @@ def loss_and_predict(
predictions = self.predict_by_feat(cls_scores, bbox_preds, batch_img_metas=batch_img_metas, cfg=proposal_cfg)
return losses, predictions

def predict(self, x: tuple[Tensor], batch_data_samples: SampleList, rescale: bool = False) -> InstanceList:
def predict(
self,
x: tuple[Tensor],
batch_data_samples: list[InstanceData],
rescale: bool = False,
) -> list[InstanceData]:
"""Perform forward propagation of the detection head and predict detection results.
Args:
@@ -204,7 +206,7 @@ def predict_by_feat(
cfg: ConfigDict | None = None,
rescale: bool = False,
with_nms: bool = True,
) -> InstanceList:
) -> list[InstanceData]:
"""Transform a batch of output features extracted from the head into bbox results.
Note: When score_factors is not None, the cls_scores are
@@ -242,8 +244,6 @@ def predict_by_feat(
- bboxes (Tensor): Has a shape (num_instances, 4),
the last dimension 4 arrange as (x1, y1, x2, y2).
"""
with_score_factors = score_factors is not None

num_levels = len(cls_scores)

featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
@@ -259,7 +259,7 @@ def predict_by_feat(
img_meta = batch_img_metas[img_id]
cls_score_list = select_single_mlvl(cls_scores, img_id, detach=True)
bbox_pred_list = select_single_mlvl(bbox_preds, img_id, detach=True)
if with_score_factors:
if score_factors is not None:
score_factor_list = select_single_mlvl(score_factors, img_id, detach=True)
else:
score_factor_list = [None for _ in range(num_levels)]
@@ -370,8 +370,13 @@ def _predict_by_feat_single(
# `nms_pre` than before.
score_thr = cfg.get("score_thr", 0)

results = filter_scores_and_topk(scores, score_thr, nms_pre, {"bbox_pred": bbox_pred, "priors": priors})
scores, labels, keep_idxs, filtered_results = results
filtered_results: dict
scores, labels, keep_idxs, filtered_results = filter_scores_and_topk( # type: ignore[assignment]
scores,
score_thr,
nms_pre,
{"bbox_pred": bbox_pred, "priors": priors},
)

bbox_pred = filtered_results["bbox_pred"] # noqa: PLW2901
priors = filtered_results["priors"] # noqa: PLW2901
@@ -388,7 +393,7 @@ def _predict_by_feat_single(
mlvl_score_factors.append(score_factor)

bbox_pred = torch.cat(mlvl_bbox_preds)
priors = cat_boxes(mlvl_valid_priors)
priors = torch.cat(mlvl_valid_priors)
bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape)

results = InstanceData()
@@ -438,21 +443,24 @@ def _bbox_post_process(
"""
if rescale:
scale_factor = [1 / s for s in img_meta["scale_factor"]]
results.bboxes = scale_boxes(results.bboxes, scale_factor)
results.bboxes = results.bboxes * results.bboxes.new_tensor(scale_factor).repeat(
(1, int(results.bboxes.size(-1) / 2)),
)

if hasattr(results, "score_factors"):
score_factors = results.pop("score_factors")
results.scores = results.scores * score_factors

# filter small size bboxes
if cfg.get("min_bbox_size", -1) >= 0:
w, h = get_box_wh(results.bboxes)
w = results.bboxes[:, 2] - results.bboxes[:, 0]
h = results.bboxes[:, 3] - results.bboxes[:, 1]
valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
if not valid_mask.all():
results = results[valid_mask]

if with_nms and results.bboxes.numel() > 0:
bboxes = get_box_tensor(results.bboxes)
bboxes = results.bboxes
det_bboxes, keep_idxs = batched_nms(bboxes, results.scores, results.labels, cfg.nms)
results = results[keep_idxs]
# some nms would reweight the score, such as softnms
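With `cat_boxes`, `scale_boxes`, `get_box_wh`, and `get_box_tensor` gone, `_bbox_post_process` operates on plain `(N, 4)` xyxy tensors: `cat_boxes` becomes `torch.cat`, and `get_box_tensor` is a no-op once boxes are no longer wrapped. A self-contained sketch of what the other two helpers reduce to (function names here are illustrative, not OTX API):

import torch
from torch import Tensor

def scale_boxes_plain(bboxes: Tensor, scale_factor: list[float]) -> Tensor:
    # Was mmdet's scale_boxes: tile (sx, sy) across the last dim and multiply.
    repeats = int(bboxes.size(-1) / 2)  # 2 for (x1, y1, x2, y2) boxes
    return bboxes * bboxes.new_tensor(scale_factor).repeat((1, repeats))

def box_wh_plain(bboxes: Tensor) -> tuple[Tensor, Tensor]:
    # Was mmdet's get_box_wh: widths and heights of xyxy boxes.
    return bboxes[:, 2] - bboxes[:, 0], bboxes[:, 3] - bboxes[:, 1]

boxes = torch.tensor([[0.0, 0.0, 10.0, 20.0]])
print(scale_boxes_plain(boxes, [0.5, 0.5]))  # tensor([[0., 0., 5., 10.]])
print(box_wh_plain(boxes))                   # (tensor([10.]), tensor([20.]))
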
11 changes: 1 addition & 10 deletions src/otx/algo/detection/heads/custom_anchor_generator.py
@@ -10,7 +10,6 @@
import numpy as np
import torch
from mmdet.registry import TASK_UTILS
from mmdet.structures.bbox import HorizontalBoxes
from torch.nn.modules.utils import _pair


@@ -44,8 +43,6 @@ class AnchorGenerator:
float is given, they will be used to shift the centers of anchors.
center_offset (float): The offset of center in proportion to anchors'
width and height. By default it is 0 in V2.0.
use_box_type (bool): Whether to warp anchors with the box type data
structure. Defaults to False.
Examples:
>>> from mmdet.models.task_modules.
@@ -78,7 +75,6 @@ def __init__(
scales_per_octave: int | None = None,
centers: list[tuple[float, float]] | None = None,
center_offset: float = 0.0,
use_box_type: bool = False,
) -> None:
# check center and center_offset
if center_offset != 0 and centers is None:
@@ -112,7 +108,6 @@ def __init__(
self.centers = centers
self.center_offset = center_offset
self.base_anchors = self.gen_base_anchors()
self.use_box_type = use_box_type

@property
def num_base_anchors(self) -> list[int]:
@@ -278,12 +273,9 @@ def single_level_grid_priors(
# shifted anchors (K, A, 4), reshape to (K*A, 4)

all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
all_anchors = all_anchors.view(-1, 4)
# first A rows correspond to A anchors of (0, 0) in feature map,
# then (0, 1), (0, 2), ...
if self.use_box_type:
all_anchors = HorizontalBoxes(all_anchors)
return all_anchors
return all_anchors.view(-1, 4)

def sparse_priors(
self,
@@ -506,7 +498,6 @@ def __init__(

self.center_offset = 0
self.gen_base_anchors()
self.use_box_type = False

def gen_base_anchors(self) -> None: # type: ignore[override]
"""Generate base anchor for SSD."""
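With `use_box_type` and `HorizontalBoxes` removed, grid priors are plain tensors end to end. A minimal usage sketch, assuming the constructor arguments shown above and the mmdet-style `single_level_grid_priors(featmap_size, level_idx, ...)` signature:

import torch
from otx.algo.detection.heads.custom_anchor_generator import AnchorGenerator

generator = AnchorGenerator(strides=[16], ratios=[1.0], scales=[8])
anchors = generator.single_level_grid_priors(
    featmap_size=(2, 2),  # feature-map H, W
    level_idx=0,
    device=torch.device("cpu"),
)
# Plain (H * W * num_base_anchors, 4) tensor; the HorizontalBoxes wrapping
# that use_box_type=True used to apply no longer exists.
print(anchors.shape)  # torch.Size([4, 4]) with one base anchor per location
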
61 changes: 35 additions & 26 deletions src/otx/algo/detection/ssd.py
@@ -31,8 +31,8 @@
if TYPE_CHECKING:
import torch
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from mmdet.structures import DetDataSample, OptSampleList, SampleList
from mmdet.utils import ConfigType, InstanceList, OptConfigType, OptMultiConfig
from mmengine import ConfigDict
from mmengine.structures import InstanceData
from omegaconf import DictConfig
from torch import Tensor, device

@@ -51,12 +51,12 @@ class SingleStageDetector(nn.Module):

def __init__(
self,
backbone: ConfigType,
bbox_head: OptConfigType = None,
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
data_preprocessor: OptConfigType = None,
init_cfg: OptMultiConfig = None,
backbone: ConfigDict | dict,
bbox_head: ConfigDict | dict,
train_cfg: ConfigDict | dict | None = None,
test_cfg: ConfigDict | dict | None = None,
data_preprocessor: ConfigDict | dict | None = None,
init_cfg: ConfigDict | list[ConfigDict] | dict | list[dict] | None = None,
) -> None:
super().__init__()
self._is_init = False
@@ -156,17 +156,17 @@ def init_weights(self) -> None:
def forward(
self,
inputs: torch.Tensor,
data_samples: OptSampleList = None,
data_samples: list[InstanceData] | None = None,
mode: str = "tensor",
) -> dict[str, torch.Tensor] | list[DetDataSample] | tuple[torch.Tensor] | torch.Tensor:
) -> dict[str, torch.Tensor] | list[InstanceData] | tuple[torch.Tensor] | torch.Tensor:
"""The unified entry for a forward process in both training and test.
The method should accept three modes: "tensor", "predict" and "loss":
- "tensor": Forward the whole network and return tensor or tuple of
tensor without any post-processing, same as a common nn.Module.
- "predict": Forward and return the predictions, which are fully
processed to a list of :obj:`DetDataSample`.
processed to a list of :obj:`InstanceData`.
- "loss": Forward and return a dict of losses according to the given
inputs and data samples.
@@ -176,7 +176,7 @@ def forward(
Args:
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
data_samples (list[:obj:`DetDataSample`], optional): A batch of
data_samples (list[:obj:`InstanceData`], optional): A batch of
data samples that contain annotations and predictions.
Defaults to None.
mode (str): Return what kind of value. Defaults to 'tensor'.
@@ -185,7 +185,7 @@ def forward(
The return type depends on ``mode``.
- If ``mode="tensor"``, return a tensor or a tuple of tensor.
- If ``mode="predict"``, return a list of :obj:`DetDataSample`.
- If ``mode="predict"``, return a list of :obj:`InstanceData`.
- If ``mode="loss"``, return a dict of tensor.
"""
if mode == "loss":
@@ -201,14 +201,14 @@ def forward(
def loss(
self,
batch_inputs: Tensor,
batch_data_samples: SampleList,
batch_data_samples: list[InstanceData],
) -> dict | list:
"""Calculate losses from a batch of inputs and data samples.
Args:
batch_inputs (Tensor): Input images of shape (N, C, H, W).
These should usually be mean centered and std scaled.
batch_data_samples (list[:obj:`DetDataSample`]): The batch
batch_data_samples (list[:obj:`InstanceData`]): The batch
data samples. It usually includes information such
as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
@@ -218,20 +218,25 @@ def loss(
x = self.extract_feat(batch_inputs)
return self.bbox_head.loss(x, batch_data_samples)

def predict(self, batch_inputs: Tensor, batch_data_samples: SampleList, rescale: bool = True) -> SampleList:
def predict(
self,
batch_inputs: Tensor,
batch_data_samples: list[InstanceData],
rescale: bool = True,
) -> list[InstanceData]:
"""Predict results from a batch of inputs and data samples with post-processing.
Args:
batch_inputs (Tensor): Inputs with shape (N, C, H, W).
batch_data_samples (List[:obj:`DetDataSample`]): The Data
batch_data_samples (List[:obj:`InstanceData`]): The Data
Samples. It usually includes information such as
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
rescale (bool): Whether to rescale the results.
Defaults to True.
Returns:
list[:obj:`DetDataSample`]: Detection results of the
input images. Each DetDataSample usually contain
list[:obj:`InstanceData`]: Detection results of the
input images. Each InstanceData usually contains
'pred_instances', and the ``pred_instances`` usually
contains the following keys.
@@ -249,13 +254,13 @@ def predict(self, batch_inputs: Tensor, batch_data_samples: SampleList, rescale:
def _forward(
self,
batch_inputs: Tensor,
batch_data_samples: OptSampleList = None,
batch_data_samples: list[InstanceData] | None = None,
) -> tuple[list[Tensor], list[Tensor]]:
"""Network forward process.
Args:
batch_inputs (Tensor): Inputs with shape (N, C, H, W).
batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
batch_data_samples (list[:obj:`InstanceData`]): Each item contains
the meta information of each image and corresponding
annotations.
Expand All @@ -280,18 +285,22 @@ def extract_feat(self, batch_inputs: Tensor) -> tuple[Tensor]:
x = self.neck(x)
return x

def add_pred_to_datasample(self, data_samples: SampleList, results_list: InstanceList) -> SampleList:
"""Add predictions to `DetDataSample`.
def add_pred_to_datasample(
self,
data_samples: list[InstanceData],
results_list: list[InstanceData],
) -> list[InstanceData]:
"""Add predictions to `InstanceData`.
Args:
data_samples (list[:obj:`DetDataSample`], optional): A batch of
data_samples (list[:obj:`InstanceData`], optional): A batch of
data samples that contain annotations and predictions.
results_list (list[:obj:`InstanceData`]): Detection results of
each image.
Returns:
list[:obj:`DetDataSample`]: Detection results of the
input images. Each DetDataSample usually contain
list[:obj:`InstanceData`]: Detection results of the
input images. Each InstanceData usually contains
'pred_instances', and the ``pred_instances`` usually
contains the following keys.
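For reference, the `mode` switch that the `forward` docstring above describes dispatches as follows; a condensed sketch of that control flow (a standalone function rather than the method itself):

from mmengine.structures import InstanceData
from torch import Tensor

def forward_dispatch(detector, inputs: Tensor, data_samples: list[InstanceData] | None, mode: str = "tensor"):
    # Mirrors SingleStageDetector.forward's three modes.
    if mode == "loss":
        return detector.loss(inputs, data_samples)      # dict of loss tensors
    if mode == "predict":
        return detector.predict(inputs, data_samples)   # list[InstanceData]
    if mode == "tensor":
        return detector._forward(inputs, data_samples)  # raw head outputs
    msg = f"Invalid mode {mode!r}. Only 'tensor', 'predict' and 'loss' are supported."
    raise RuntimeError(msg)
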