Skip to content

Commit

Permalink
Add unit-test for algo/detection (#3177)
Browse files Browse the repository at this point in the history
Add algo/detection unit-tests
  • Loading branch information
sungmanc authored Mar 21, 2024
1 parent 3445eaf commit 2161e28
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 1 deletion.
1 change: 0 additions & 1 deletion src/otx/algo/detection/losses/cross_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def cross_sigmoid_focal_loss(
targets = F.one_hot(targets, num_classes=inputs_size + 1)
targets = targets[:, :inputs_size]
calculate_loss_func = py_sigmoid_focal_loss

loss = calculate_loss_func(
inputs,
targets,
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/algo/detection/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Test of custom loss modules of OTX Detection task."""
47 changes: 47 additions & 0 deletions tests/unit/algo/detection/losses/test_cross_focal_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from unittest.mock import patch

import pytest
import torch
from otx.algo.detection.losses.cross_focal_loss import CrossSigmoidFocalLoss, OrdinaryFocalLoss


class TestCrossFocalLoss:
@pytest.fixture()
def mock_tensor(self) -> torch.Tensor:
return torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)

@pytest.fixture()
def mock_targets(self) -> torch.Tensor:
return torch.tensor([0, 1], dtype=torch.long)

@pytest.fixture()
def mock_weights(self) -> torch.Tensor:
return torch.tensor([0.5, 0.5], dtype=torch.float32)

@pytest.fixture()
def mock_valid_label_mask(self) -> torch.Tensor:
return torch.tensor([1, 1], dtype=torch.float32)

def test_cross_focal_forward(self, mock_tensor, mock_targets, mock_weights, mock_valid_label_mask):
with patch(
"otx.algo.detection.losses.cross_focal_loss.py_sigmoid_focal_loss",
return_value=torch.tensor(0.5),
) as mock_loss_func:
loss_fn = CrossSigmoidFocalLoss()

loss = loss_fn(
pred=mock_tensor,
targets=mock_targets,
weight=mock_weights,
valid_label_mask=mock_valid_label_mask,
reduction_override="mean",
)
mock_loss_func.assert_called()
assert loss.item() == pytest.approx(0.5, 0.1), "Loss did not match expected value."

def test_ordinary_focal_forward(self, mock_tensor, mock_targets, mock_weights):
loss_fn = OrdinaryFocalLoss(gamma=1.5)

loss = loss_fn(inputs=mock_tensor, targets=mock_targets, label_weights=mock_weights)

assert loss >= 0, "Loss should be non-negative."
25 changes: 25 additions & 0 deletions tests/unit/algo/detection/test_atss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Test of OTX SSD architecture."""

import pytest
from otx.algo.detection.atss import ATSS, ATSSR50FPN
from otx.algo.utils.support_otx_v1 import OTXv1Helper


class TestATSS:
@pytest.mark.parametrize(
"model",
[
ATSS(num_classes=2, variant="mobilenetv2"),
ATSS(num_classes=2, variant="r50_fpn"),
ATSS(num_classes=2, variant="resnext101"),
ATSSR50FPN(num_classes=2),
],
)
def test(self, model, mocker) -> None:
mock_load_ckpt = mocker.patch.object(OTXv1Helper, "load_det_ckpt")
model.load_from_otx_v1_ckpt({})
mock_load_ckpt.assert_called_once_with({}, "model.model.")

assert isinstance(model._export_parameters, dict)
20 changes: 20 additions & 0 deletions tests/unit/algo/detection/test_rtmdet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Test of OTX SSD architecture."""

import pytest
from otx.algo.detection.rtmdet import RTMDet
from otx.algo.utils.support_otx_v1 import OTXv1Helper


class TestRTMDet:
@pytest.fixture()
def fxt_model(self) -> RTMDet:
return RTMDet(num_classes=3, variant="tiny")

def test(self, fxt_model, mocker) -> None:
mock_load_ckpt = mocker.patch.object(OTXv1Helper, "load_det_ckpt")
fxt_model.load_from_otx_v1_ckpt({})
mock_load_ckpt.assert_called_once_with({}, "model.model.")

assert isinstance(fxt_model._export_parameters, dict)
26 changes: 26 additions & 0 deletions tests/unit/algo/detection/test_yolox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""Test of OTX SSD architecture."""

import pytest
from otx.algo.detection.yolox import YoloX, YoloXTiny
from otx.algo.utils.support_otx_v1 import OTXv1Helper


class TestYOLOX:
@pytest.mark.parametrize(
"model",
[
YoloX(num_classes=2, variant="l"),
YoloX(num_classes=2, variant="s"),
YoloX(num_classes=2, variant="tiny"),
YoloX(num_classes=2, variant="x"),
YoloXTiny(num_classes=2),
],
)
def test(self, model, mocker) -> None:
mock_load_ckpt = mocker.patch.object(OTXv1Helper, "load_det_ckpt")
model.load_from_otx_v1_ckpt({})
mock_load_ckpt.assert_called_once_with({}, "model.model.")

assert isinstance(model._export_parameters, dict)

0 comments on commit 2161e28

Please sign in to comment.