-
Notifications
You must be signed in to change notification settings - Fork 444
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add unit-test for otx.engine.engine & otx.core.utils (#3158)
* Add unit-test for engine & core.utils * Fix import test
- Loading branch information
Showing
7 changed files
with
676 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import os | ||
|
||
import pytest | ||
from mmpretrain.registry import MODELS | ||
from omegaconf import DictConfig | ||
from otx.core.utils.build import ( | ||
build_mm_model, | ||
get_classification_layers, | ||
get_default_num_async_infer_requests, | ||
modify_num_classes, | ||
) | ||
|
||
|
||
@pytest.fixture() | ||
def fxt_mm_config() -> DictConfig: | ||
return DictConfig( | ||
{ | ||
"backbone": { | ||
"version": "b0", | ||
"pretrained": True, | ||
"type": "OTXEfficientNet", | ||
}, | ||
"head": { | ||
"in_channels": 1280, | ||
"loss": { | ||
"loss_weight": 1.0, | ||
"type": "CrossEntropyLoss", | ||
}, | ||
"num_classes": 1000, | ||
"topk": (1, 5), | ||
"type": "LinearClsHead", | ||
}, | ||
"neck": { | ||
"type": "GlobalAveragePooling", | ||
}, | ||
"data_preprocessor": { | ||
"mean": [123.678, 116.28, 103.53], | ||
"std": [58.395, 57.12, 57.375], | ||
"to_rgb": False, | ||
"type": "ClsDataPreprocessor", | ||
}, | ||
"type": "ImageClassifier", | ||
}, | ||
) | ||
|
||
|
||
def test_build_mm_model(fxt_mm_config, mocker) -> None: | ||
model = build_mm_model(config=fxt_mm_config, model_registry=MODELS) | ||
assert model.__class__.__name__ == "ImageClassifier" | ||
|
||
mock_load_checkpoint = mocker.patch("mmengine.runner.load_checkpoint") | ||
model = build_mm_model(config=fxt_mm_config, model_registry=MODELS, load_from="path/to/weights.pth") | ||
mock_load_checkpoint.assert_called_once_with(model, "path/to/weights.pth", map_location="cpu") | ||
|
||
|
||
def test_get_default_num_async_infer_requests() -> None: | ||
# Test the get_default_num_async_infer_requests function. | ||
|
||
# Mock os.cpu_count() to return a specific value | ||
original_cpu_count = os.cpu_count | ||
os.cpu_count = lambda: 4 | ||
|
||
# Call the function and check the return value | ||
assert get_default_num_async_infer_requests() == 2 | ||
|
||
# Restore the original os.cpu_count() function | ||
os.cpu_count = original_cpu_count | ||
|
||
# Check the warning message | ||
with pytest.warns(UserWarning, match="Set the default number of OpenVINO inference requests"): | ||
get_default_num_async_infer_requests() | ||
|
||
|
||
def test_get_classification_layers(fxt_mm_config) -> None: | ||
expected_result = { | ||
"head.fc.weight": {"stride": 1, "num_extra_classes": 0}, | ||
"head.fc.bias": {"stride": 1, "num_extra_classes": 0}, | ||
} | ||
|
||
result = get_classification_layers(fxt_mm_config, MODELS) | ||
assert result == expected_result | ||
|
||
|
||
def test_modify_num_classes(): | ||
config = DictConfig({"num_classes": 10, "model": {"num_classes": 5}}) | ||
num_classes = 7 | ||
modify_num_classes(config, num_classes) | ||
assert config["num_classes"] == num_classes | ||
assert config["model"]["num_classes"] == num_classes | ||
|
||
config = DictConfig({"num_classes": 10, "model": {"num_classes": 5}}) | ||
num_classes = 7 | ||
modify_num_classes(config, num_classes) | ||
assert config["num_classes"] == num_classes | ||
assert config["model"]["num_classes"] == num_classes | ||
|
||
config = DictConfig({"model": {"layers": [{"units": 64}]}}) | ||
num_classes = 7 | ||
modify_num_classes(config, num_classes) | ||
assert "num_classes" not in config | ||
assert config["model"]["layers"][0]["units"] == 64 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
from otx.core.utils.cache import TrainerArgumentsCache | ||
|
||
|
||
class TestTrainerArgumentsCache: | ||
def test_trainer_arguments_cache_init(self): | ||
cache = TrainerArgumentsCache(max_epochs=100, val_check_interval=0) | ||
assert cache.args == {"max_epochs": 100, "val_check_interval": 0} | ||
|
||
def test_trainer_arguments_cache_update(self): | ||
cache = TrainerArgumentsCache(max_epochs=100, val_check_interval=0) | ||
cache.update(max_epochs=1, val_check_interval=1.0) | ||
assert cache.args == {"max_epochs": 1, "val_check_interval": 1.0} | ||
|
||
cache.update(val_check_interval=None) | ||
assert cache.args == {"max_epochs": 1, "val_check_interval": 1.0} | ||
|
||
def test_trainer_arguments_cache_requires_update(self): | ||
cache = TrainerArgumentsCache(max_epochs=100, val_check_interval=0) | ||
assert not cache.requires_update(max_epochs=100, val_check_interval=0) | ||
assert cache.requires_update(max_epochs=1, val_check_interval=1.0) | ||
|
||
def test_trainer_arguments_cache_get_trainer_constructor_args(self): | ||
cache = TrainerArgumentsCache() | ||
args = cache.get_trainer_constructor_args() | ||
assert isinstance(args, set) | ||
assert "max_epochs" in args | ||
assert "val_check_interval" in args |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from omegaconf import OmegaConf | ||
from otx.core.utils.config import inplace_num_classes, to_list, to_tuple | ||
|
||
|
||
def test_to_tuple() -> None: | ||
input_dict = { | ||
"a": [1, 2, 3], | ||
"b": { | ||
"c": (4, 5), | ||
"d": [6, 7, 8], | ||
}, | ||
"e": { | ||
"f": { | ||
"g": [9, 10], | ||
}, | ||
}, | ||
} | ||
|
||
expected_output = { | ||
"a": (1, 2, 3), | ||
"b": { | ||
"c": (4, 5), | ||
"d": (6, 7, 8), | ||
}, | ||
"e": { | ||
"f": { | ||
"g": (9, 10), | ||
}, | ||
}, | ||
} | ||
|
||
assert to_tuple(input_dict) == expected_output | ||
|
||
|
||
def test_to_list() -> None: | ||
input_dict = {} | ||
expected_output = {} | ||
assert to_list(input_dict) == expected_output | ||
|
||
input_dict = {"a": (1, 2, 3), "b": {"c": (4, 5)}} | ||
expected_output = {"a": [1, 2, 3], "b": {"c": [4, 5]}} | ||
assert to_list(input_dict) == expected_output | ||
|
||
input_dict = {"a": [1, 2, 3], "b": {"c": [4, 5]}} | ||
expected_output = {"a": [1, 2, 3], "b": {"c": [4, 5]}} | ||
assert to_list(input_dict) == expected_output | ||
|
||
input_dict = {"a": (1, 2, 3), "b": [4, 5]} | ||
expected_output = {"a": [1, 2, 3], "b": [4, 5]} | ||
assert to_list(input_dict) == expected_output | ||
|
||
input_dict = {"a": {"b": (1, 2, 3)}, "c": {"d": {"e": (4, 5)}}} | ||
expected_output = {"a": {"b": [1, 2, 3]}, "c": {"d": {"e": [4, 5]}}} | ||
assert to_list(input_dict) == expected_output | ||
|
||
input_dict = {"a": {"b": [1, 2, 3]}, "c": {"d": {"e": [4, 5]}}} | ||
expected_output = {"a": {"b": [1, 2, 3]}, "c": {"d": {"e": [4, 5]}}} | ||
assert to_list(input_dict) == expected_output | ||
|
||
input_dict = {"a": {"b": (1, 2, 3)}, "c": {"d": [4, 5]}} | ||
expected_output = {"a": {"b": [1, 2, 3]}, "c": {"d": [4, 5]}} | ||
assert to_list(input_dict) == expected_output | ||
|
||
|
||
def test_inplace_num_classes() -> None: | ||
cfg = OmegaConf.create({"num_classes": 10, "model": {"num_classes": 5}}) | ||
inplace_num_classes(cfg, 20) | ||
assert cfg.num_classes == 20 | ||
assert cfg.model.num_classes == 20 | ||
|
||
cfg = OmegaConf.create([{"num_classes": 10}, {"num_classes": 5}]) | ||
inplace_num_classes(cfg, 20) | ||
assert cfg[0].num_classes == 20 | ||
assert cfg[1].num_classes == 20 | ||
|
||
cfg = OmegaConf.create({"model": {"num_classes": 10, "layers": [{"num_classes": 5}]}}) | ||
inplace_num_classes(cfg, 20) | ||
assert cfg.model.num_classes == 20 | ||
assert cfg.model.layers[0].num_classes == 20 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import inspect | ||
from pathlib import Path | ||
|
||
import otx | ||
import pytest | ||
from otx.core.utils.imports import get_otx_root_path | ||
|
||
|
||
def test_get_otx_root_path(mocker): | ||
root_path = get_otx_root_path() | ||
assert isinstance(root_path, Path) | ||
otx_path = inspect.getfile(otx) | ||
assert root_path == Path(otx_path).parent | ||
|
||
with mocker.patch("importlib.import_module", return_value=None) and pytest.raises(ModuleNotFoundError): | ||
get_otx_root_path() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
from otx.core.config.data import SamplerConfig | ||
from otx.core.utils.instantiators import ( | ||
instantiate_callbacks, | ||
instantiate_loggers, | ||
instantiate_sampler, | ||
partial_instantiate_class, | ||
) | ||
|
||
|
||
def test_instantiate_callbacks() -> None: | ||
callbacks_cfg = [ | ||
{ | ||
"class_path": "lightning.pytorch.callbacks.ModelCheckpoint", | ||
"init_args": { | ||
"save_last": True, | ||
"save_top_k": 1, | ||
"monitor": "val_loss", | ||
"mode": "min", | ||
}, | ||
}, | ||
{ | ||
"class_path": "lightning.pytorch.callbacks.EarlyStopping", | ||
"init_args": { | ||
"monitor": "val_loss", | ||
"mode": "min", | ||
"patience": 3, | ||
}, | ||
}, | ||
] | ||
|
||
callbacks = instantiate_callbacks(callbacks_cfg=callbacks_cfg) | ||
assert len(callbacks) == 2 | ||
assert callbacks[0].__class__.__name__ == "ModelCheckpoint" | ||
assert callbacks[1].__class__.__name__ == "EarlyStopping" | ||
|
||
callbacks = instantiate_callbacks(callbacks_cfg=[]) | ||
assert len(callbacks) == 0 | ||
|
||
|
||
def test_instantiate_loggers() -> None: | ||
logger_cfg = [ | ||
{ | ||
"class_path": "lightning.pytorch.loggers.TensorBoardLogger", | ||
"init_args": { | ||
"save_dir": "logs", | ||
"name": "tb_logs", | ||
}, | ||
}, | ||
] | ||
|
||
loggers = instantiate_loggers(logger_cfg=logger_cfg) | ||
assert len(loggers) == 1 | ||
assert loggers[0].__class__.__name__ == "TensorBoardLogger" | ||
|
||
loggers = instantiate_loggers(logger_cfg=None) | ||
assert len(loggers) == 0 | ||
|
||
|
||
def test_partial_instantiate_class() -> None: | ||
init = { | ||
"class_path": "torch.optim.SGD", | ||
"init_args": { | ||
"lr": 0.0049, | ||
"momentum": 0.9, | ||
"weight_decay": 0.0001, | ||
}, | ||
} | ||
|
||
partial = partial_instantiate_class(init=init) | ||
assert len(partial) == 1 | ||
assert partial[0].__class__.__name__ == "partial" | ||
assert partial[0].func.__name__ == "SGD" | ||
assert partial[0].keywords == init["init_args"] | ||
assert partial[0].args == () | ||
|
||
partial = partial_instantiate_class(init=None) | ||
assert partial is None | ||
|
||
|
||
def test_instantiate_sampler(mocker) -> None: | ||
sampler_cfg = SamplerConfig( | ||
class_path="torch.utils.data.WeightedRandomSampler", | ||
init_args={ | ||
"num_samples": 10, | ||
"replacement": True, | ||
}, | ||
) | ||
|
||
mock_dataset = mocker.MagicMock() | ||
sampler = instantiate_sampler(sampler_config=sampler_cfg, dataset=mock_dataset) | ||
assert sampler.__class__.__name__ == "WeightedRandomSampler" | ||
assert sampler.num_samples == sampler_cfg.init_args["num_samples"] | ||
assert sampler.replacement == sampler_cfg.init_args["replacement"] |
Oops, something went wrong.