From 2161e28f7a318509a4608f40582e3ab69e5dba04 Mon Sep 17 00:00:00 2001 From: Sungman Cho Date: Thu, 21 Mar 2024 17:28:46 +0900 Subject: [PATCH] Add unit-test for algo/detection (#3177) Add algo/detection unit-tests --- .../algo/detection/losses/cross_focal_loss.py | 1 - tests/unit/algo/detection/losses/__init__.py | 3 ++ .../detection/losses/test_cross_focal_loss.py | 47 +++++++++++++++++++ tests/unit/algo/detection/test_atss.py | 25 ++++++++++ tests/unit/algo/detection/test_rtmdet.py | 20 ++++++++ tests/unit/algo/detection/test_yolox.py | 26 ++++++++++ 6 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 tests/unit/algo/detection/losses/__init__.py create mode 100644 tests/unit/algo/detection/losses/test_cross_focal_loss.py create mode 100644 tests/unit/algo/detection/test_atss.py create mode 100644 tests/unit/algo/detection/test_rtmdet.py create mode 100644 tests/unit/algo/detection/test_yolox.py diff --git a/src/otx/algo/detection/losses/cross_focal_loss.py b/src/otx/algo/detection/losses/cross_focal_loss.py index abb0668109a..6931fdf38ca 100644 --- a/src/otx/algo/detection/losses/cross_focal_loss.py +++ b/src/otx/algo/detection/losses/cross_focal_loss.py @@ -42,7 +42,6 @@ def cross_sigmoid_focal_loss( targets = F.one_hot(targets, num_classes=inputs_size + 1) targets = targets[:, :inputs_size] calculate_loss_func = py_sigmoid_focal_loss - loss = calculate_loss_func( inputs, targets, diff --git a/tests/unit/algo/detection/losses/__init__.py b/tests/unit/algo/detection/losses/__init__.py new file mode 100644 index 00000000000..8c0686e6e61 --- /dev/null +++ b/tests/unit/algo/detection/losses/__init__.py @@ -0,0 +1,3 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Test of custom loss modules of OTX Detection task.""" diff --git a/tests/unit/algo/detection/losses/test_cross_focal_loss.py b/tests/unit/algo/detection/losses/test_cross_focal_loss.py new file mode 100644 index 00000000000..5173cc57889 --- /dev/null +++ b/tests/unit/algo/detection/losses/test_cross_focal_loss.py @@ -0,0 +1,47 @@ +from unittest.mock import patch + +import pytest +import torch +from otx.algo.detection.losses.cross_focal_loss import CrossSigmoidFocalLoss, OrdinaryFocalLoss + + +class TestCrossFocalLoss: + @pytest.fixture() + def mock_tensor(self) -> torch.Tensor: + return torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) + + @pytest.fixture() + def mock_targets(self) -> torch.Tensor: + return torch.tensor([0, 1], dtype=torch.long) + + @pytest.fixture() + def mock_weights(self) -> torch.Tensor: + return torch.tensor([0.5, 0.5], dtype=torch.float32) + + @pytest.fixture() + def mock_valid_label_mask(self) -> torch.Tensor: + return torch.tensor([1, 1], dtype=torch.float32) + + def test_cross_focal_forward(self, mock_tensor, mock_targets, mock_weights, mock_valid_label_mask): + with patch( + "otx.algo.detection.losses.cross_focal_loss.py_sigmoid_focal_loss", + return_value=torch.tensor(0.5), + ) as mock_loss_func: + loss_fn = CrossSigmoidFocalLoss() + + loss = loss_fn( + pred=mock_tensor, + targets=mock_targets, + weight=mock_weights, + valid_label_mask=mock_valid_label_mask, + reduction_override="mean", + ) + mock_loss_func.assert_called() + assert loss.item() == pytest.approx(0.5, 0.1), "Loss did not match expected value." + + def test_ordinary_focal_forward(self, mock_tensor, mock_targets, mock_weights): + loss_fn = OrdinaryFocalLoss(gamma=1.5) + + loss = loss_fn(inputs=mock_tensor, targets=mock_targets, label_weights=mock_weights) + + assert loss >= 0, "Loss should be non-negative." diff --git a/tests/unit/algo/detection/test_atss.py b/tests/unit/algo/detection/test_atss.py new file mode 100644 index 00000000000..787eb5cf5cb --- /dev/null +++ b/tests/unit/algo/detection/test_atss.py @@ -0,0 +1,25 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Test of OTX SSD architecture.""" + +import pytest +from otx.algo.detection.atss import ATSS, ATSSR50FPN +from otx.algo.utils.support_otx_v1 import OTXv1Helper + + +class TestATSS: + @pytest.mark.parametrize( + "model", + [ + ATSS(num_classes=2, variant="mobilenetv2"), + ATSS(num_classes=2, variant="r50_fpn"), + ATSS(num_classes=2, variant="resnext101"), + ATSSR50FPN(num_classes=2), + ], + ) + def test(self, model, mocker) -> None: + mock_load_ckpt = mocker.patch.object(OTXv1Helper, "load_det_ckpt") + model.load_from_otx_v1_ckpt({}) + mock_load_ckpt.assert_called_once_with({}, "model.model.") + + assert isinstance(model._export_parameters, dict) diff --git a/tests/unit/algo/detection/test_rtmdet.py b/tests/unit/algo/detection/test_rtmdet.py new file mode 100644 index 00000000000..15949f76701 --- /dev/null +++ b/tests/unit/algo/detection/test_rtmdet.py @@ -0,0 +1,20 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Test of OTX SSD architecture.""" + +import pytest +from otx.algo.detection.rtmdet import RTMDet +from otx.algo.utils.support_otx_v1 import OTXv1Helper + + +class TestRTMDet: + @pytest.fixture() + def fxt_model(self) -> RTMDet: + return RTMDet(num_classes=3, variant="tiny") + + def test(self, fxt_model, mocker) -> None: + mock_load_ckpt = mocker.patch.object(OTXv1Helper, "load_det_ckpt") + fxt_model.load_from_otx_v1_ckpt({}) + mock_load_ckpt.assert_called_once_with({}, "model.model.") + + assert isinstance(fxt_model._export_parameters, dict) diff --git a/tests/unit/algo/detection/test_yolox.py b/tests/unit/algo/detection/test_yolox.py new file mode 100644 index 00000000000..4bc205d3e24 --- /dev/null +++ b/tests/unit/algo/detection/test_yolox.py @@ -0,0 +1,26 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Test of OTX SSD architecture.""" + +import pytest +from otx.algo.detection.yolox import YoloX, YoloXTiny +from otx.algo.utils.support_otx_v1 import OTXv1Helper + + +class TestYOLOX: + @pytest.mark.parametrize( + "model", + [ + YoloX(num_classes=2, variant="l"), + YoloX(num_classes=2, variant="s"), + YoloX(num_classes=2, variant="tiny"), + YoloX(num_classes=2, variant="x"), + YoloXTiny(num_classes=2), + ], + ) + def test(self, model, mocker) -> None: + mock_load_ckpt = mocker.patch.object(OTXv1Helper, "load_det_ckpt") + model.load_from_otx_v1_ckpt({}) + mock_load_ckpt.assert_called_once_with({}, "model.model.") + + assert isinstance(model._export_parameters, dict)