diff --git a/src/otx/algo/detection/ssd.py b/src/otx/algo/detection/ssd.py index aef2e8d2a20..6f37dfa5e21 100644 --- a/src/otx/algo/detection/ssd.py +++ b/src/otx/algo/detection/ssd.py @@ -358,7 +358,6 @@ def __init__( ) self.image_size = (1, 3, 864, 864) self.tile_image_size = self.image_size - self._register_load_state_dict_pre_hook(self._set_anchors_hook) def _create_model(self) -> nn.Module: from mmdet.models.data_preprocessors import ( @@ -410,6 +409,10 @@ def setup(self, stage: str) -> None: anchor_generator.widths = new_anchors[0] anchor_generator.heights = new_anchors[1] anchor_generator.gen_base_anchors() + self.hparams["ssd_anchors"] = { + "heights": anchor_generator.heights, + "widths": anchor_generator.widths, + } def _get_new_anchors(self, dataset: OTXDataset, anchor_generator: SSDAnchorGeneratorClustered) -> tuple | None: """Get new anchors for SSD from OTXDataset.""" @@ -521,19 +524,6 @@ def get_classification_layers( classification_layers[prefix + key] = {"use_bg": use_bg, "num_anchors": num_anchors} return classification_layers - def state_dict(self, *args, **kwargs) -> dict[str, Any]: - """Return state dictionary of model entity with anchor information. - - Returns: - A dictionary containing SSD state. - - """ - state_dict = super().state_dict(*args, **kwargs) - anchor_generator = self.model.bbox_head.anchor_generator - anchors = {"heights": anchor_generator.heights, "widths": anchor_generator.widths} - state_dict["model.model.anchors"] = anchors - return state_dict - def load_state_dict_pre_hook(self, state_dict: dict[str, torch.Tensor], prefix: str, *args, **kwargs) -> None: """Modify input state_dict according to class name matching before weight loading.""" model2ckpt = self.map_class_names(self.model_classes, self.ckpt_classes) @@ -588,15 +578,16 @@ def _exporter(self) -> OTXModelExporter: output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, ) - def _set_anchors_hook(self, state_dict: dict[str, Any], *args, **kwargs) -> None: - """Pre hook for pop anchor statistics from checkpoint state_dict.""" - anchors = state_dict.pop("model.model.anchors", None) - if anchors is not None: + def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: + """Callback on load checkpoint.""" + if (hparams := checkpoint.get("hyper_parameters")) and (anchors := hparams.get("ssd_anchors", None)): anchor_generator = self.model.bbox_head.anchor_generator anchor_generator.widths = anchors["widths"] anchor_generator.heights = anchors["heights"] anchor_generator.gen_base_anchors() + return super().on_load_checkpoint(checkpoint) + def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict: """Load the previous OTX ckpt according to OTX2.0.""" return OTXv1Helper.load_ssd_ckpt(state_dict, add_prefix) diff --git a/src/otx/core/data/dataset/anomaly/dataset.py b/src/otx/core/data/dataset/anomaly/dataset.py index 1d9d6bc443a..8a7a90d5a69 100644 --- a/src/otx/core/data/dataset/anomaly/dataset.py +++ b/src/otx/core/data/dataset/anomaly/dataset.py @@ -26,6 +26,7 @@ from otx.core.data.entity.base import ImageInfo from otx.core.data.mem_cache import NULL_MEM_CACHE_HANDLER, MemCacheHandlerBase from otx.core.types.image import ImageColorChannel +from otx.core.types.label import LabelInfo from otx.core.types.task import OTXTaskType @@ -53,6 +54,7 @@ def __init__( image_color_channel, stack_images, ) + self.label_info = LabelInfo(label_names=["Normal", "Anomaly"], label_groups=[["Normal", "Anomaly"]]) def _get_item_impl( self, diff --git a/src/otx/core/data/dataset/visual_prompting.py b/src/otx/core/data/dataset/visual_prompting.py index 4ece94158e4..775546b51ed 100644 --- a/src/otx/core/data/dataset/visual_prompting.py +++ b/src/otx/core/data/dataset/visual_prompting.py @@ -26,6 +26,7 @@ ZeroShotVisualPromptingBatchDataEntity, ZeroShotVisualPromptingDataEntity, ) +from otx.core.types.label import NullLabelInfo from otx.core.utils.mask_util import polygon_to_bitmap from .base import OTXDataset, Transforms @@ -61,6 +62,8 @@ def __init__( # if using only point prompt self.prob = 0.0 + self.label_info = NullLabelInfo() + def _get_item_impl(self, index: int) -> VisualPromptingDataEntity | None: item = self.dm_subset.get(id=self.ids[index], subset=self.dm_subset.name) img = item.media_as(dmImage) @@ -189,6 +192,8 @@ def __init__( # if using only point prompt self.prob = 0.0 + self.label_info = NullLabelInfo() + def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None: item = self.dm_subset.get(id=self.ids[index], subset=self.dm_subset.name) img = item.media_as(dmImage) diff --git a/src/otx/core/metrics/fmeasure.py b/src/otx/core/metrics/fmeasure.py index 56bcb9853b7..e761eccefdd 100644 --- a/src/otx/core/metrics/fmeasure.py +++ b/src/otx/core/metrics/fmeasure.py @@ -657,7 +657,7 @@ def __init__( self._f_measure_per_nms: dict | None = None self._best_confidence_threshold: float | None = None self._best_nms_threshold: float | None = None - self._f_measure = 0.0 + self._f_measure = float("-inf") self.reset() diff --git a/src/otx/core/model/anomaly.py b/src/otx/core/model/anomaly.py index c10cdf50438..edad71489ae 100644 --- a/src/otx/core/model/anomaly.py +++ b/src/otx/core/model/anomaly.py @@ -33,8 +33,6 @@ from otx.core.types.task import OTXTaskType if TYPE_CHECKING: - from collections import OrderedDict - from anomalib.metrics import AnomalibMetricCollection from anomalib.metrics.threshold import BaseThreshold from lightning.pytorch import Trainer @@ -159,6 +157,22 @@ def __init__(self) -> None: self.image_metrics: AnomalibMetricCollection self.pixel_metrics: AnomalibMetricCollection + def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None: + """Callback on saving checkpoint.""" + super().on_save_checkpoint(checkpoint) # type: ignore[misc] + + attrs = ["_task_type", "_input_size", "mean_values", "scale_values", "image_threshold", "pixel_threshold"] + + checkpoint["anomaly"] = {key: getattr(self, key, None) for key in attrs} + + def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: + """Callback on loading checkpoint.""" + super().on_load_checkpoint(checkpoint) # type: ignore[misc] + + if anomaly_attrs := checkpoint.get("anomaly", None): + for key, value in anomaly_attrs.items(): + setattr(self, key, value) + @property def input_size(self) -> tuple[int, int]: """Returns the input size of the model. @@ -238,7 +252,7 @@ def trainable_model(self) -> str | None: def setup(self, stage: str | None = None) -> None: """Setup the model.""" super().setup(stage) # type: ignore[misc] - if hasattr(self.trainer, "datamodule") and hasattr(self.trainer.datamodule, "config"): + if stage == "fit" and hasattr(self.trainer, "datamodule") and hasattr(self.trainer.datamodule, "config"): if hasattr(self.trainer.datamodule.config, "test_subset"): self._extract_mean_scale_from_transforms(self.trainer.datamodule.config.test_subset.transforms) elif hasattr(self.trainer.datamodule.config, "val_subset"): @@ -327,24 +341,6 @@ def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[torch. return optimizer(params=params) return super().configure_optimizers() # type: ignore[misc] - def state_dict(self) -> dict[str, Any]: - """Return state dictionary of model entity with meta information. - - Returns: - A dictionary containing datamodule state. - - """ - state_dict = super().state_dict() # type: ignore[misc] - # This is defined in OTXModel - state_dict["label_info"] = self.label_info # type: ignore[attr-defined] - return state_dict - - def load_state_dict(self, ckpt: OrderedDict[str, Any], *args, **kwargs) -> None: - """Pass the checkpoint to the anomaly model.""" - ckpt = ckpt.get("state_dict", ckpt) - ckpt.pop("label_info", None) # [TODO](ashwinvaidya17): Revisit this method when OTXModel is the lightning model - return super().load_state_dict(ckpt, *args, **kwargs) # type: ignore[misc] - def forward( self, inputs: AnomalyModelInputs, diff --git a/src/otx/core/model/base.py b/src/otx/core/model/base.py index 7dfd1632af2..a4bf4943e0e 100644 --- a/src/otx/core/model/base.py +++ b/src/otx/core/model/base.py @@ -27,6 +27,8 @@ from torch.optim.sgd import SGD from torchmetrics import Metric, MetricCollection +from otx import __version__ +from otx.core.config.data import TileConfig from otx.core.data.entity.base import ( OTXBatchLossEntity, T_OTXBatchDataEntity, @@ -113,6 +115,8 @@ def __init__( self.torch_compile = torch_compile self._explain_mode = False + self._tile_config: TileConfig | None = None + # this line allows to access init params with 'self.hparams' attribute # also ensures init params will be stored in ckpt self.save_hyperparameters(logger=False, ignore=["model", "optimizer", "scheduler", "metric"]) @@ -336,16 +340,54 @@ def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwa self.log(log_metric_name, value, sync_dist=True, prog_bar=True) - def state_dict(self) -> dict[str, Any]: - """Return state dictionary of model entity with meta information. + def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None: + """Callback on saving checkpoint.""" + super().on_save_checkpoint(checkpoint) - Returns: - A dictionary containing datamodule state. + checkpoint["label_info"] = self.label_info + checkpoint["otx_version"] = __version__ - """ - state_dict = super().state_dict() - state_dict["label_info"] = self.label_info - return state_dict + if self._tile_config: + checkpoint["tile_config"] = self._tile_config + + def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: + """Callback on loading checkpoint.""" + super().on_load_checkpoint(checkpoint) + + if ckpt_label_info := checkpoint.get("label_info", None): + self._label_info = ckpt_label_info + + if ckpt_tile_config := checkpoint.get("tile_config", None): + self._tile_config = ckpt_tile_config + + def load_state_dict_incrementally(self, ckpt: dict[str, Any], *args, **kwargs) -> None: + """Load state dict incrementally.""" + ckpt_label_info: LabelInfo | None = ckpt.get("label_info", None) + + if ckpt_label_info is None: + msg = "Checkpoint should have `label_info`." + raise ValueError(msg, ckpt_label_info) + + if ckpt_label_info != self.label_info: + msg = ( + "Load model state dictionary incrementally: " + f"Label info from checkpoint: {ckpt_label_info} -> " + f"Label info from training data: {self.label_info}" + ) + logger.info(msg) + self.register_load_state_dict_pre_hook( + self.label_info.label_names, + ckpt_label_info.label_names, + ) + + # Model weights + state_dict: dict[str, Any] = ckpt.get("state_dict", None) + + if ckpt_label_info is None: + msg = "Checkpoint should have `state_dict`." + raise ValueError(msg, ckpt_label_info) + + self.load_state_dict(state_dict, *args, **kwargs) def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None: """Load state dictionary from checkpoint state dictionary. @@ -364,23 +406,6 @@ def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None: else: state_dict = ckpt - ckpt_label_info = state_dict.pop("label_info", None) - - if ckpt_label_info and self.label_info is None: - msg = ( - "`state_dict` to load has `label_info`, but the current model has no `label_info`. " - "It is recommended to set proper `label_info` for the incremental learning case." - ) - warnings.warn(msg, stacklevel=2) - if ckpt_label_info and self.label_info and ckpt_label_info != self.label_info: - logger.warning( - f"Data classes from checkpoint: {ckpt_label_info.label_names} -> " - f"Data classes from training data: {self.label_info.label_names}", - ) - self.register_load_state_dict_pre_hook( - self.label_info.label_names, - ckpt_label_info.label_names, - ) return super().load_state_dict(state_dict, *args, **kwargs) def load_from_otx_v1_ckpt(self, ckpt: dict[str, Any]) -> dict: @@ -698,6 +723,20 @@ def patch_optimizer_and_scheduler_for_hpo(self) -> None: if not isinstance(self.scheduler_callable, PicklableLRSchedulerCallable): self.scheduler_callable = PicklableLRSchedulerCallable(self.scheduler_callable) + @property + def tile_config(self) -> TileConfig: + """Get tiling configurations.""" + if self._tile_config is None: + msg = "This task type does not support tiling." + raise RuntimeError(msg) + + return self._tile_config + + @tile_config.setter + def tile_config(self, tile_config: TileConfig) -> None: + """Set tiling configurations.""" + self._tile_config = tile_config + class OVModel(OTXModel, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity]): """Base class for the OpenVINO model. diff --git a/src/otx/core/model/detection.py b/src/otx/core/model/detection.py index 412f5d88bc0..8f4a609931c 100644 --- a/src/otx/core/model/detection.py +++ b/src/otx/core/model/detection.py @@ -14,12 +14,11 @@ from openvino.model_api.tilers import DetectionTiler from torchvision import tv_tensors -from otx.algo.explain.explain_algo import get_feature_vector from otx.core.config.data import TileConfig from otx.core.data.entity.base import OTXBatchLossEntity from otx.core.data.entity.detection import DetBatchDataEntity, DetBatchPredEntity from otx.core.data.entity.tile import TileBatchDetDataEntity -from otx.core.metrics import MetricInput +from otx.core.metrics import MetricCallable, MetricInput from otx.core.metrics.mean_ap import MeanAPCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel from otx.core.schedulers import LRSchedulerListCallable @@ -37,8 +36,6 @@ from torch import nn from torchmetrics import Metric - from otx.core.metrics import MetricCallable - class OTXDetectionModel(OTXModel[DetBatchDataEntity, DetBatchPredEntity, TileBatchDetDataEntity]): """Base class for the detection models used in OTX.""" @@ -58,7 +55,7 @@ def __init__( metric=metric, torch_compile=torch_compile, ) - self.tile_config = TileConfig() + self._tile_config = TileConfig() def forward_tiles(self, inputs: TileBatchDetDataEntity) -> DetBatchPredEntity: """Unpack detection tiles. @@ -170,14 +167,32 @@ def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwa class ExplainableOTXDetModel(OTXDetectionModel): """OTX detection model which can attach a XAI (Explainable AI) branch.""" + def __init__( + self, + num_classes: int, + optimizer: OptimizerCallable = DefaultOptimizerCallable, + scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, + metric: MetricCallable = MeanAPCallable, + torch_compile: bool = False, + ) -> None: + super().__init__( + num_classes=num_classes, + optimizer=optimizer, + scheduler=scheduler, + metric=metric, + torch_compile=torch_compile, + ) + + from otx.algo.explain.explain_algo import get_feature_vector + + self.model.feature_vector_fn = get_feature_vector + self.model.explain_fn = self.get_explain_fn() + def forward_explain( self, inputs: DetBatchDataEntity, ) -> DetBatchPredEntity: """Model forward function.""" - self.model.feature_vector_fn = get_feature_vector - self.model.explain_fn = self.get_explain_fn() - # If customize_inputs is overridden outputs = ( self._forward_explain_detection(self.model, **self._customize_inputs(inputs)) diff --git a/src/otx/core/model/instance_segmentation.py b/src/otx/core/model/instance_segmentation.py index e75967562fa..cc6936a8719 100644 --- a/src/otx/core/model/instance_segmentation.py +++ b/src/otx/core/model/instance_segmentation.py @@ -16,7 +16,6 @@ from openvino.model_api.tilers import InstanceSegmentationTiler from torchvision import tv_tensors -from otx.algo.explain.explain_algo import MaskRCNNExplainAlgo, get_feature_vector from otx.core.config.data import TileConfig from otx.core.data.entity.base import OTXBatchLossEntity from otx.core.data.entity.instance_segmentation import InstanceSegBatchDataEntity, InstanceSegBatchPredEntity @@ -67,7 +66,7 @@ def __init__( metric=metric, torch_compile=torch_compile, ) - self.tile_config = TileConfig() + self._tile_config = TileConfig() def forward_tiles(self, inputs: TileBatchInstSegDataEntity) -> InstanceSegBatchPredEntity: """Unpack instance segmentation tiles. @@ -209,14 +208,32 @@ def _convert_pred_entity_to_compute_metric( class ExplainableOTXInstanceSegModel(OTXInstanceSegModel): """OTX Instance Segmentation model which can attach a XAI (Explainable AI) branch.""" + def __init__( + self, + num_classes: int, + optimizer: OptimizerCallable = DefaultOptimizerCallable, + scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable, + metric: MetricCallable = MaskRLEMeanAPCallable, + torch_compile: bool = False, + ) -> None: + super().__init__( + num_classes=num_classes, + optimizer=optimizer, + scheduler=scheduler, + metric=metric, + torch_compile=torch_compile, + ) + + from otx.algo.explain.explain_algo import get_feature_vector + + self.model.feature_vector_fn = get_feature_vector + self.model.explain_fn = self.get_explain_fn() + def forward_explain( self, inputs: InstanceSegBatchDataEntity, ) -> InstanceSegBatchPredEntity: """Model forward function.""" - self.model.feature_vector_fn = get_feature_vector - self.model.explain_fn = self.get_explain_fn() - # If customize_inputs is overridden outputs = ( self._forward_explain_inst_seg(self.model, **self._customize_inputs(inputs)) @@ -271,6 +288,8 @@ def _forward_explain_inst_seg( def get_explain_fn(self) -> Callable: """Returns explain function.""" + from otx.algo.explain.explain_algo import MaskRCNNExplainAlgo + explainer = MaskRCNNExplainAlgo(num_classes=self.num_classes) return explainer.func diff --git a/src/otx/core/model/visual_prompting.py b/src/otx/core/model/visual_prompting.py index 7426f69c2aa..2307474a1ef 100644 --- a/src/otx/core/model/visual_prompting.py +++ b/src/otx/core/model/visual_prompting.py @@ -189,6 +189,7 @@ def __init__( metric=metric, torch_compile=torch_compile, ) + self._label_info = NullLabelInfo() @property def _exporter(self) -> OTXModelExporter: @@ -298,6 +299,7 @@ def __init__( metric=metric, torch_compile=torch_compile, ) + self._label_info = NullLabelInfo() @property def _exporter(self) -> OTXModelExporter: diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py index f1e636c60ba..4a90faa3ea2 100644 --- a/src/otx/engine/engine.py +++ b/src/otx/engine/engine.py @@ -180,6 +180,7 @@ def train( metric: MetricCallable | None = None, run_hpo: bool = False, hpo_config: HpoConfig | None = None, + checkpoint: PathLike | None = None, **kwargs, ) -> dict[str, Any]: """Trains the model using the provided LightningModule and OTXDataModule. @@ -199,6 +200,7 @@ def train( metric callable. It will temporarilly change the evaluation metric for the validation and test. run_hpo (bool, optional): If True, optimizer hyper parameters before training a model. hpo_config (HpoConfig | None, optional): Configuration for HPO. + checkpoint (PathLike | None, optional): Path to the checkpoint file. Defaults to None. **kwargs: Additional keyword arguments for pl.Trainer configuration. Returns: @@ -234,6 +236,8 @@ def train( otx train --data_root --config ``` """ + checkpoint = checkpoint if checkpoint is not None else self.checkpoint + if run_hpo: if hpo_config is None: hpo_config = HpoConfig() @@ -241,7 +245,7 @@ def train( if best_config is not None: update_hyper_parameter(self, best_config) if best_trial_weight is not None: - self.checkpoint = best_trial_weight + checkpoint = best_trial_weight resume = True if seed is not None: @@ -258,7 +262,7 @@ def train( ) fit_kwargs: dict[str, Any] = {} - # NOTE Model's label info should be converted datamodule's label info before ckpt loading + # NOTE: Model's label info should be converted datamodule's label info before ckpt loading # This is due to smart weight loading check label name as well as number of classes. if self.model.label_info != self.datamodule.label_info: # TODO (vinnamki): Revisit label_info logic to make it cleaner @@ -269,12 +273,17 @@ def train( logging.warning(msg) self.model.label_info = self.datamodule.label_info - if resume: - fit_kwargs["ckpt_path"] = self.checkpoint - elif self.checkpoint is not None: - loaded_checkpoint = torch.load(self.checkpoint) - # loaded checkpoint have keys (OTX1.5): model, config, labels, input_size, VERSION - self.model.load_state_dict(loaded_checkpoint) + if resume and checkpoint: + # NOTE: If both `resume` and `checkpoint` are provided, + # load the entire model state from the checkpoint using the pl.Trainer's API. + fit_kwargs["ckpt_path"] = checkpoint + elif not resume and checkpoint: + # NOTE: If `resume` is not enabled but `checkpoint` is provided, + # load the model state from the checkpoint incrementally. + # This means only the model weights are loaded. If there is a mismatch in label_info, + # perform incremental weight loading for the model's classification layer. + ckpt = torch.load(checkpoint) + self.model.load_state_dict_incrementally(ckpt) with override_metric_callable(model=self.model, new_metric_callable=metric) as model: self.trainer.fit( @@ -333,20 +342,6 @@ def test( otx test --config --checkpoint ``` """ - # NOTE Model's label info should be converted datamodule's label info before ckpt loading - # This is due to smart weight loading check label name as well as number of classes. - if self.model.label_info != self.datamodule.label_info: - # TODO (vinnamki): Revisit label_info logic to make it cleaner - msg = ( - "Model label_info is not equal to the Datamodule label_info. " - f"It will be overriden: {self.model.label_info} => {self.datamodule.label_info}" - ) - logging.warning(msg) - self.model.label_info = self.datamodule.label_info - - # TODO (vinnamki): This should be changed to raise an error if not equivalent in case of test - # raise ValueError() - model = self.model checkpoint = checkpoint if checkpoint is not None else self.checkpoint datamodule = datamodule if datamodule is not None else self.datamodule @@ -366,8 +361,18 @@ def test( # NOTE, trainer.test takes only lightning based checkpoint. # So, it can't take the OTX1.x checkpoint. if checkpoint is not None and not is_ir_ckpt: - loaded_checkpoint = torch.load(checkpoint) - model.load_state_dict(loaded_checkpoint) + model_cls = self.model.__class__ + model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint) + + if model.label_info != self.datamodule.label_info: + msg = ( + "To launch a test pipeline, the label information should be same " + "between the training and testing datasets. " + "Please check whether you use the same dataset: " + f"model.label_info={model.label_info}, " + f"datamodule.label_info={self.datamodule.label_info}" + ) + raise ValueError(msg) self._build_trainer(**kwargs) @@ -423,20 +428,6 @@ def predict( """ from otx.algo.utils.xai_utils import process_saliency_maps_in_pred_entity - # NOTE Model's label info should be converted datamodule's label info before ckpt loading - # This is due to smart weight loading check label name as well as number of classes. - if self.model.label_info != self.datamodule.label_info: - # TODO (vinnamki): Revisit label_info logic to make it cleaner - msg = ( - "Model label_info is not equal to the Datamodule label_info. " - f"It will be overriden: {self.model.label_info} => {self.datamodule.label_info}" - ) - logging.warning(msg) - self.model.label_info = self.datamodule.label_info - - # TODO (vinnamki): This should be changed to raise an error if not equivalent in case of test - # raise ValueError() - model = self.model checkpoint = checkpoint if checkpoint is not None else self.checkpoint @@ -451,8 +442,18 @@ def predict( datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test") if checkpoint is not None and not is_ir_ckpt: - loaded_checkpoint = torch.load(checkpoint) - model.load_state_dict(loaded_checkpoint) + model_cls = self.model.__class__ + model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint) + + if model.label_info != self.datamodule.label_info: + msg = ( + "To launch a predict pipeline, the label information should be same " + "between the training and testing datasets. " + "Please check whether you use the same dataset: " + f"model.label_info={model.label_info}, " + f"datamodule.label_info={self.datamodule.label_info}" + ) + raise ValueError(msg) self._build_trainer(**kwargs) @@ -516,11 +517,12 @@ def export( --checkpoint --export_precision FP16 --export_format ONNX ``` """ - ckpt_path = str(checkpoint) if checkpoint is not None else self.checkpoint - if ckpt_path is None: + checkpoint = checkpoint if checkpoint is not None else self.checkpoint + + if checkpoint is None: msg = "To make export, checkpoint must be specified." raise RuntimeError(msg) - is_ir_ckpt = Path(ckpt_path).suffix in [".xml"] + is_ir_ckpt = Path(checkpoint).suffix in [".xml"] if is_ir_ckpt and export_format != OTXExportFormatType.EXPORTABLE_CODE: msg = ( @@ -538,10 +540,9 @@ def export( ) if not is_ir_ckpt: + model_cls = self.model.__class__ + self.model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, map_location="cpu") self.model.eval() - loaded_checkpoint = torch.load(ckpt_path) - self.model.label_info = loaded_checkpoint["state_dict"]["label_info"] - self.model.load_state_dict(loaded_checkpoint) self.model.explain_mode = explain exported_model_path = self.model.export( @@ -679,11 +680,19 @@ def explain( datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test") model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info) - model.label_info = datamodule.label_info - if checkpoint is not None and not is_ir_ckpt: - loaded_checkpoint = torch.load(checkpoint) - model.load_state_dict(loaded_checkpoint) + model_cls = model.__class__ + model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint) + + if model.label_info != self.datamodule.label_info: + msg = ( + "To launch a explain pipeline, the label information should be same " + "between the training and testing datasets. " + "Please check whether you use the same dataset: " + f"model.label_info={model.label_info}, " + f"datamodule.label_info={self.datamodule.label_info}" + ) + raise ValueError(msg) model.explain_mode = True @@ -726,7 +735,7 @@ def from_config( Defaults to None. If work_dir is None, use the work_dir from the configuration file. kwargs: Arguments that can override the engine's arguments. - Returns:s + Returns: Engine: An instance of the Engine class. Example: @@ -765,10 +774,24 @@ def from_config( ) warn(msg, stacklevel=1) + if (datamodule := instantiated_config.get("data")) is None: + msg = "Cannot instantiate datamodule from config." + raise ValueError(msg) + if not isinstance(datamodule, OTXDataModule): + raise TypeError(datamodule) + + if (model := instantiated_config.get("model")) is None: + msg = "Cannot instantiate model from config." + raise ValueError(msg) + if not isinstance(model, OTXModel): + raise TypeError(model) + + model.label_info = datamodule.label_info + return cls( work_dir=instantiated_config.get("work_dir", work_dir), - datamodule=instantiated_config.get("data"), - model=instantiated_config.get("model"), + datamodule=datamodule, + model=model, **engine_kwargs, ) diff --git a/tests/conftest.py b/tests/conftest.py index 2f115d2aac4..3b9717b53ae 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -251,7 +251,7 @@ def fxt_clean_up_mem_cache(): # TODO(Jaeguk): Add cpu param when OTX can run integration test parallelly for each task. -@pytest.fixture(params=[pytest.param("gpu", marks=pytest.mark.gpu)]) +@pytest.fixture(scope="module", params=[pytest.param("gpu", marks=pytest.mark.gpu)]) def fxt_accelerator(request: pytest.FixtureRequest) -> str: return request.param diff --git a/tests/integration/cli/test_cli.py b/tests/integration/cli/test_cli.py index bfb17020f0f..f1a588c3fbe 100644 --- a/tests/integration/cli/test_cli.py +++ b/tests/integration/cli/test_cli.py @@ -4,7 +4,7 @@ from pathlib import Path -import numpy as np +import cv2 import pytest import yaml from otx.core.types.task import OTXTaskType @@ -13,34 +13,19 @@ from tests.utils import run_main -@pytest.mark.parametrize( - "recipe", - pytest.RECIPE_LIST, +@pytest.fixture( + params=pytest.RECIPE_LIST, ids=lambda x: "/".join(Path(x).parts[-2:]), ) -def test_otx_e2e( - recipe: str, - tmp_path: Path, +def fxt_trained_model( fxt_accelerator: str, fxt_target_dataset_per_task: dict, fxt_cli_override_command_per_task: dict, fxt_open_subprocess: bool, -) -> None: - """ - Test OTX CLI e2e commands. - - - 'otx train' with 2 epochs training - - 'otx test' with output checkpoint from 'otx train' - - 'otx export' with output checkpoint from 'otx train' - - 'otx test' with the exported to ONNX/IR model - - Args: - recipe (str): The recipe to use for training. (eg. 'classification/otx_mobilenet_v3_large.yaml') - tmp_path (Path): The temporary path for storing the training outputs. - - Returns: - None - """ + request: pytest.FixtureRequest, + tmp_path, +): + recipe = request.param task = recipe.split("/")[-2] model_name = recipe.split("/")[-1].split(".")[0] @@ -64,6 +49,34 @@ def test_otx_e2e( run_main(command_cfg=command_cfg, open_subprocess=fxt_open_subprocess) + return recipe, task, model_name, tmp_path_train + + +def test_otx_e2e( + fxt_trained_model, + fxt_accelerator: str, + fxt_target_dataset_per_task: dict, + fxt_cli_override_command_per_task: dict, + fxt_open_subprocess: bool, + tmp_path: Path, +) -> None: + """ + Test OTX CLI e2e commands. + + - 'otx train' with 2 epochs training + - 'otx test' with output checkpoint from 'otx train' + - 'otx export' with output checkpoint from 'otx train' + - 'otx test' with the exported to ONNX/IR model + + Args: + recipe (str): The recipe to use for training. (eg. 'classification/otx_mobilenet_v3_large.yaml') + tmp_path (Path): The temporary path for storing the training outputs. + + Returns: + None + """ + recipe, task, model_name, tmp_path_train = fxt_trained_model + outputs_dir = tmp_path_train / "outputs" latest_dir = max( (p for p in outputs_dir.iterdir() if p.is_dir() and p.name != ".latest"), @@ -79,9 +92,8 @@ def test_otx_e2e( assert "data" in train_output_config assert "engine" in train_output_config assert (latest_dir / "csv").exists() - assert (latest_dir / "checkpoints").exists() - ckpt_files = list((latest_dir / "checkpoints").glob(pattern="epoch_*.ckpt")) - assert len(ckpt_files) > 0 + ckpt_file = latest_dir / "best_checkpoint.ckpt" + assert ckpt_file.exists() # 2) otx test tmp_path_test = tmp_path / f"otx_test_{model_name}" @@ -98,7 +110,7 @@ def test_otx_e2e( fxt_accelerator, *fxt_cli_override_command_per_task[task], "--checkpoint", - str(ckpt_files[-1]), + str(ckpt_file), ] run_main(command_cfg=command_cfg, open_subprocess=fxt_open_subprocess) @@ -154,7 +166,7 @@ def test_otx_e2e( str(tmp_path_test / "outputs" / fmt), *overrides, "--checkpoint", - str(ckpt_files[-1]), + str(ckpt_file), "--export_format", f"{fmt}", ] @@ -232,7 +244,7 @@ def test_otx_e2e( str(tmp_path_test / "outputs" / fmt), *fxt_cli_override_command_per_task[task], "--checkpoint", - str(ckpt_files[-1]), + str(ckpt_file), "--export_format", f"{fmt}", "--explain", @@ -250,18 +262,13 @@ def test_otx_e2e( assert (fmt_latest_dir / f"{format_to_file[fmt]}").exists() -@pytest.mark.parametrize( - "recipe", - pytest.RECIPE_LIST, - ids=lambda x: "/".join(Path(x).parts[-2:]), -) def test_otx_explain_e2e( - recipe: str, - tmp_path: Path, + fxt_trained_model, fxt_accelerator: str, fxt_target_dataset_per_task: dict, fxt_cli_override_command_per_task: dict, fxt_open_subprocess: bool, + tmp_path: Path, ) -> None: """ Test OTX CLI explain e2e command. @@ -273,13 +280,16 @@ def test_otx_explain_e2e( Returns: None """ - if "tile" in recipe: - pytest.skip("Explain is not supported for tiling yet.") - import cv2 + recipe, task, model_name, tmp_path_train = fxt_trained_model - task = recipe.split("/")[-2] - model_name = recipe.split("/")[-1].split(".")[0] + outputs_dir = tmp_path_train / "outputs" + latest_dir = outputs_dir / ".latest" + ckpt_file = latest_dir / "train" / "best_checkpoint.ckpt" + assert ckpt_file.exists() + + if "tile" in recipe: + pytest.skip("Explain is not supported for tiling yet.") if ("_cls" not in task) and (task not in ["detection", "instance_segmentation"]): pytest.skip("Supported only for classification, detection and instance segmentation task.") @@ -302,10 +312,10 @@ def test_otx_explain_e2e( fxt_accelerator, "--seed", "0", - "--deterministic", - "True", "--dump", "True", + "--checkpoint", + str(ckpt_file), *fxt_cli_override_command_per_task[task], ] @@ -322,47 +332,6 @@ def test_otx_explain_e2e( assert sal_map.shape[0] > 0 assert sal_map.shape[1] > 0 - sal_diff_thresh = 3 - reference_sal_vals = { - # Classification - "multi_label_cls_efficientnet_v2_light": ( - np.array([66, 97, 84, 33, 42, 79, 0], dtype=np.uint8), - "Slide6_class_0_saliency_map.png", - ), - "h_label_cls_efficientnet_v2_light": ( - np.array([152, 193, 144, 132, 149, 204, 217], dtype=np.uint8), - "092_class_5_saliency_map.png", - ), - # Detection - "detection_yolox_tiny": ( - np.array([111, 163, 141, 141, 146, 147, 158, 169, 184, 193], dtype=np.uint8), - "Slide3_class_0_saliency_map.png", - ), - "detection_ssd_mobilenetv2": ( - np.array([135, 80, 74, 34, 27, 32, 47, 42, 32, 34], dtype=np.uint8), - "Slide3_class_0_saliency_map.png", - ), - "detection_atss_mobilenetv2": ( - np.array([22, 62, 64, 0, 27, 60, 59, 53, 37, 45], dtype=np.uint8), - "Slide3_class_0_saliency_map.png", - ), - # Instance Segmentation - "instance_segmentation_maskrcnn_efficientnetb2b": ( - np.array([54, 54, 54, 54, 0, 0, 0, 54, 0, 0], dtype=np.uint8), - "Slide3_class_0_saliency_map.png", - ), - } - test_case_name = task + "_" + model_name - if test_case_name in reference_sal_vals: - actual_sal_vals = cv2.imread(str(latest_dir / "saliency_map" / reference_sal_vals[test_case_name][1])) - if test_case_name == "instance_segmentation_maskrcnn_efficientnetb2b": - # Take corner values due to map sparsity of InstSeg - actual_sal_vals = (actual_sal_vals[-10:, -1, -1]).astype(np.uint16) - else: - actual_sal_vals = (actual_sal_vals[:10, 0, 0]).astype(np.uint16) - ref_sal_vals = reference_sal_vals[test_case_name][0] - assert np.max(np.abs(actual_sal_vals - ref_sal_vals) <= sal_diff_thresh) - # @pytest.mark.skipif(len(pytest.RECIPE_OV_LIST) < 1, reason="No OV recipe found.") @pytest.mark.parametrize( diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index a6223ab4661..0c7d1ed56be 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -103,7 +103,7 @@ def fxt_rtmdet_tiny_config(fxt_asset_dir: Path) -> MMConfig: # [TODO]: This is a temporary approach. -@pytest.fixture() +@pytest.fixture(scope="module") def fxt_target_dataset_per_task() -> dict: return { "multi_class_cls": "tests/assets/classification_dataset", @@ -123,7 +123,7 @@ def fxt_target_dataset_per_task() -> dict: } -@pytest.fixture() +@pytest.fixture(scope="module") def fxt_cli_override_command_per_task() -> dict: return { "multi_class_cls": [], diff --git a/tests/unit/algo/detection/conftest.py b/tests/unit/algo/detection/conftest.py new file mode 100644 index 00000000000..3d5cd06fbf1 --- /dev/null +++ b/tests/unit/algo/detection/conftest.py @@ -0,0 +1,34 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Test of custom algo modules of OTX Detection task.""" +import pytest +from otx.core.config.data import DataModuleConfig, SubsetConfig +from otx.core.data.module import OTXDataModule +from otx.core.types.task import OTXTaskType +from torchvision.transforms.v2 import Resize + + +@pytest.fixture() +def fxt_data_module(): + return OTXDataModule( + task=OTXTaskType.DETECTION, + config=DataModuleConfig( + data_format="coco_instances", + data_root="tests/assets/car_tree_bug", + train_subset=SubsetConfig( + batch_size=2, + subset_name="train", + transforms=[Resize(320)], + ), + val_subset=SubsetConfig( + batch_size=2, + subset_name="val", + transforms=[Resize(320)], + ), + test_subset=SubsetConfig( + batch_size=2, + subset_name="test", + transforms=[Resize(320)], + ), + ), + ) diff --git a/tests/unit/algo/detection/test_ssd.py b/tests/unit/algo/detection/test_ssd.py index 9a21a1a570d..53466f1806e 100644 --- a/tests/unit/algo/detection/test_ssd.py +++ b/tests/unit/algo/detection/test_ssd.py @@ -2,7 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 """Test of OTX SSD architecture.""" +from pathlib import Path + import pytest +from lightning import Trainer from otx.algo.detection.ssd import SSD @@ -11,16 +14,25 @@ class TestSSD: def fxt_model(self) -> SSD: return SSD(num_classes=3, variant="mobilenetv2") - def test_save_and_load_anchors(self, fxt_model) -> None: - anchor_widths = fxt_model.model.bbox_head.anchor_generator.widths - anchor_heights = fxt_model.model.bbox_head.anchor_generator.heights - state_dict = fxt_model.state_dict() - assert anchor_widths == state_dict["model.model.anchors"]["widths"] - assert anchor_heights == state_dict["model.model.anchors"]["heights"] + @pytest.fixture() + def fxt_checkpoint(self, fxt_model, fxt_data_module, tmpdir, monkeypatch: pytest.MonkeyPatch): + trainer = Trainer(max_steps=0) + + monkeypatch.setattr(trainer.strategy, "_lightning_module", fxt_model) + monkeypatch.setattr(trainer, "datamodule", fxt_data_module) + monkeypatch.setattr(fxt_model, "_trainer", trainer) + fxt_model.setup("fit") + + fxt_model.hparams["ssd_anchors"]["widths"][0][0] = 40 + fxt_model.hparams["ssd_anchors"]["heights"][0][0] = 50 + + checkpoint_path = Path(tmpdir) / "checkpoint.ckpt" + trainer.save_checkpoint(checkpoint_path) + + return checkpoint_path - state_dict["model.model.anchors"]["widths"][0][0] = 40 - state_dict["model.model.anchors"]["heights"][0][0] = 50 + def test_save_and_load_anchors(self, fxt_checkpoint) -> None: + loaded_model = SSD.load_from_checkpoint(checkpoint_path=fxt_checkpoint) - fxt_model.load_state_dict(state_dict) - assert fxt_model.model.bbox_head.anchor_generator.widths[0][0] == 40 - assert fxt_model.model.bbox_head.anchor_generator.heights[0][0] == 50 + assert loaded_model.model.bbox_head.anchor_generator.widths[0][0] == 40 + assert loaded_model.model.bbox_head.anchor_generator.heights[0][0] == 50 diff --git a/tests/unit/core/model/test_base.py b/tests/unit/core/model/test_base.py index 83891686f43..558ae998691 100644 --- a/tests/unit/core/model/test_base.py +++ b/tests/unit/core/model/test_base.py @@ -34,7 +34,9 @@ def test_smart_weight_loading(self, mocker) -> None: "model.head.bias": {"stride": 1, "num_extra_classes": 0}, } current_model.label_info = ["car", "bus", "truck"] - current_model.load_state_dict(prev_state_dict) + current_model.load_state_dict_incrementally( + {"state_dict": prev_state_dict, "label_info": prev_model.label_info}, + ) curr_state_dict = current_model.state_dict() indices = torch.Tensor([0, 2]).to(torch.int32) diff --git a/tests/unit/core/model/test_detection.py b/tests/unit/core/model/test_detection.py index 6c1afb2b3ab..806caf72d02 100644 --- a/tests/unit/core/model/test_detection.py +++ b/tests/unit/core/model/test_detection.py @@ -127,14 +127,3 @@ def test_reset_restore_model_forward(self, otx_model): otx_model._restore_model_forward() assert otx_model.original_model_forward is None assert str(otx_model.model.forward) == str(initial_model_forward) - - def test_export_parameters(self, otx_model): - otx_model.image_size = (1, 64, 64, 3) - otx_model.explain_mode = False - parameters = otx_model._export_parameters - assert isinstance(parameters, dict) - assert "output_names" in parameters - - otx_model.explain_mode = True - parameters = otx_model._export_parameters - assert parameters["output_names"] == ["feature_vector", "saliency_map"] diff --git a/tests/unit/core/model/test_inst_segmentation.py b/tests/unit/core/model/test_inst_segmentation.py index fefd8717c86..8091654d767 100644 --- a/tests/unit/core/model/test_inst_segmentation.py +++ b/tests/unit/core/model/test_inst_segmentation.py @@ -63,14 +63,3 @@ def test_reset_restore_model_forward(self, otx_model): otx_model._restore_model_forward() assert otx_model.original_model_forward is None assert str(otx_model.model.forward) == str(initial_model_forward) - - def test_export_parameters(self, otx_model): - otx_model.image_size = (1, 64, 64, 3) - otx_model.explain_mode = False - parameters = otx_model._export_parameters - assert isinstance(parameters, dict) - assert "output_names" in parameters - - otx_model.explain_mode = True - parameters = otx_model._export_parameters - assert parameters["output_names"] == ["feature_vector", "saliency_map"] diff --git a/tests/unit/engine/test_engine.py b/tests/unit/engine/test_engine.py index e5fe19fc5b8..2aa8449a663 100644 --- a/tests/unit/engine/test_engine.py +++ b/tests/unit/engine/test_engine.py @@ -2,16 +2,17 @@ # SPDX-License-Identifier: Apache-2.0 from pathlib import Path -from unittest.mock import create_autospec import pytest from otx.algo.classification.efficientnet_b0 import EfficientNetB0ForMulticlassCls from otx.algo.classification.torchvision_model import OTXTVModel from otx.core.config.device import DeviceConfig -from otx.core.model.base import OVModel +from otx.core.model.base import OTXModel, OVModel from otx.core.types.export import OTXExportFormatType +from otx.core.types.label import NullLabelInfo from otx.core.types.precision import OTXPrecisionType from otx.engine import Engine +from pytest_mock import MockerFixture @pytest.fixture() @@ -73,13 +74,30 @@ def test_training_with_override_args(self, fxt_engine, mocker) -> None: assert fxt_engine._cache.args["max_epochs"] == 100 mock_seed_everything.assert_called_once_with(1234, workers=True) - def test_training_with_checkpoint(self, fxt_engine, mocker) -> None: - mock_torch_load = mocker.patch("torch.load") - mocker.patch("otx.engine.engine.OTXModel.load_state_dict") + @pytest.mark.parametrize("resume", [True, False]) + def test_training_with_checkpoint(self, fxt_engine, resume: bool, mocker: MockerFixture, tmpdir) -> None: + checkpoint = "path/to/checkpoint.ckpt" + + mock_trainer = mocker.patch("otx.engine.engine.Trainer") + mock_trainer.return_value.default_root_dir = Path(tmpdir) + mock_trainer_fit = mock_trainer.return_value.fit + + mock_torch_load = mocker.patch("otx.engine.engine.torch.load") + mock_load_state_dict_incrementally = mocker.patch.object(fxt_engine.model, "load_state_dict_incrementally") + + trained_checkpoint = Path(tmpdir) / "best.ckpt" + trained_checkpoint.touch() + mock_trainer.return_value.checkpoint_callback.best_model_path = trained_checkpoint + + fxt_engine.train(resume=resume, checkpoint=checkpoint) - fxt_engine.checkpoint = "path/to/checkpoint" - fxt_engine.train() - mock_torch_load.assert_called_once_with("path/to/checkpoint") + if resume: + assert mock_trainer_fit.call_args.kwargs.get("ckpt_path") == checkpoint + else: + assert "ckpt_path" not in mock_trainer_fit.call_args.kwargs + + mock_torch_load.assert_called_once() + mock_load_state_dict_incrementally.assert_called_once() def test_training_with_run_hpo(self, fxt_engine, mocker) -> None: mocker.patch("pathlib.Path.symlink_to") @@ -93,94 +111,97 @@ def test_training_with_run_hpo(self, fxt_engine, mocker) -> None: mock_update_hyper_parameter.assert_called_once_with(fxt_engine, {}) assert mock_fit.call_args[1]["ckpt_path"] == "hpo/best/checkpoint" - def test_training_with_resume(self, fxt_engine, mocker) -> None: - mocker.patch("pathlib.Path.symlink_to") - mock_fit = mocker.patch("otx.engine.engine.Trainer.fit") - - fxt_engine.checkpoint = "path/to/checkpoint" - fxt_engine.train(resume=True) - assert mock_fit.call_args[1]["ckpt_path"] == "path/to/checkpoint" - - def test_testing_after_training(self, fxt_engine, mocker) -> None: - mocker.patch("otx.engine.engine.OTXModel.load_state_dict") + @pytest.mark.parametrize( + "checkpoint", + [ + "path/to/checkpoint.ckpt", + "path/to/checkpoint.xml", + ], + ) + def test_test(self, fxt_engine, checkpoint, mocker: MockerFixture) -> None: mock_test = mocker.patch("otx.engine.engine.Trainer.test") - mock_torch_load = mocker.patch("torch.load") + _ = mocker.patch("otx.engine.engine.AutoConfigurator.update_ov_subset_pipeline") + mock_get_ov_model = mocker.patch("otx.engine.engine.AutoConfigurator.get_ov_model") + mock_load_from_checkpoint = mocker.patch.object(fxt_engine.model.__class__, "load_from_checkpoint") - # Fetch Checkpoint - fxt_engine.checkpoint = "path/to/checkpoint" - fxt_engine.test() - mock_torch_load.assert_called_once_with("path/to/checkpoint") - mock_test.assert_called_once() + ext = Path(checkpoint).suffix - fxt_engine.test(checkpoint="path/to/new/checkpoint") - mock_torch_load.assert_called_with("path/to/new/checkpoint") + if ext == ".ckpt": + mock_model = mocker.create_autospec(OTXModel) - def test_testing_with_ov_model(self, fxt_engine, mocker) -> None: - mock_test = mocker.patch("otx.engine.engine.Trainer.test") - mock_torch_load = mocker.patch("torch.load") - mocker.patch("otx.engine.engine.AutoConfigurator.update_ov_subset_pipeline") - mocker.patch("otx.engine.engine.AutoConfigurator.get_ov_model") + mock_load_from_checkpoint.return_value = mock_model + else: + mock_model = mocker.create_autospec(OVModel) - fxt_engine.test(checkpoint="path/to/model.xml") - mock_test.assert_called_once() - mock_torch_load.assert_not_called() + mock_get_ov_model.return_value = mock_model - fxt_engine.model = create_autospec(OVModel) - fxt_engine.test(checkpoint="path/to/model.xml") + # Correct label_info from the checkpoint + mock_model.label_info = fxt_engine.datamodule.label_info + fxt_engine.test(checkpoint=checkpoint) + mock_test.assert_called_once() - def test_prediction_after_training(self, fxt_engine, mocker) -> None: - mocker.patch("otx.engine.engine.OTXModel.load_state_dict") + mock_model.label_info = NullLabelInfo() + # Incorrect label_info from the checkpoint + with pytest.raises( + ValueError, + match="To launch a test pipeline, the label information should be same (.*)", + ): + fxt_engine.test(checkpoint=checkpoint) + + @pytest.mark.parametrize("explain", [True, False]) + @pytest.mark.parametrize( + "checkpoint", + [ + "path/to/checkpoint.ckpt", + "path/to/checkpoint.xml", + ], + ) + def test_predict(self, fxt_engine, checkpoint, explain, mocker: MockerFixture) -> None: mock_predict = mocker.patch("otx.engine.engine.Trainer.predict") - mock_torch_load = mocker.patch("torch.load") + _ = mocker.patch("otx.engine.engine.AutoConfigurator.update_ov_subset_pipeline") + mock_get_ov_model = mocker.patch("otx.engine.engine.AutoConfigurator.get_ov_model") + mock_load_from_checkpoint = mocker.patch.object(fxt_engine.model.__class__, "load_from_checkpoint") + mock_process_saliency_maps = mocker.patch("otx.algo.utils.xai_utils.process_saliency_maps_in_pred_entity") - # Fetch Checkpoint - fxt_engine.checkpoint = "path/to/checkpoint" - fxt_engine.predict() - mock_torch_load.assert_called_once_with("path/to/checkpoint") - mock_predict.assert_called_once() + ext = Path(checkpoint).suffix - fxt_engine.predict(checkpoint="path/to/new/checkpoint") - mock_torch_load.assert_called_with("path/to/new/checkpoint") + if ext == ".ckpt": + mock_model = mocker.create_autospec(OTXModel) - fxt_engine.model = create_autospec(OVModel) - fxt_engine.predict(checkpoint="path/to/model.xml") + mock_load_from_checkpoint.return_value = mock_model + else: + mock_model = mocker.create_autospec(OVModel) - def test_prediction_with_ov_model(self, fxt_engine, mocker) -> None: - mock_predict = mocker.patch("otx.engine.engine.Trainer.predict") - mock_torch_load = mocker.patch("torch.load") - mocker.patch("otx.engine.engine.AutoConfigurator.update_ov_subset_pipeline") - mocker.patch("otx.engine.engine.AutoConfigurator.get_ov_model") + mock_get_ov_model.return_value = mock_model - fxt_engine.predict(checkpoint="path/to/model.xml") + # Correct label_info from the checkpoint + mock_model.label_info = fxt_engine.datamodule.label_info + fxt_engine.predict(checkpoint=checkpoint, explain=explain) mock_predict.assert_called_once() - mock_torch_load.assert_not_called() - - def test_prediction_explain_mode(self, fxt_engine, mocker) -> None: - mocker.patch("otx.engine.engine.OTXModel.load_state_dict") - mock_explain = mocker.patch("otx.algo.utils.xai_utils.process_saliency_maps_in_pred_entity") - mock_predict = mocker.patch("otx.engine.engine.Trainer.predict") - mock_torch_load = mocker.patch("torch.load") + assert mock_process_saliency_maps.called == explain - # Fetch Checkpoint - fxt_engine.checkpoint = "path/to/checkpoint" - fxt_engine.predict(explain=True) - mock_torch_load.assert_called_once_with("path/to/checkpoint") - mock_explain.assert_called_once() - mock_predict.assert_called_once() + mock_model.label_info = NullLabelInfo() + # Incorrect label_info from the checkpoint + with pytest.raises( + ValueError, + match="To launch a predict pipeline, the label information should be same (.*)", + ): + fxt_engine.predict(checkpoint=checkpoint) def test_exporting(self, fxt_engine, mocker) -> None: with pytest.raises(RuntimeError, match="To make export, checkpoint must be specified."): fxt_engine.export() - mocker.patch("otx.engine.engine.OTXModel.load_state_dict") - mocker.patch("otx.engine.engine.OTXModel.label_info") mock_export = mocker.patch("otx.engine.engine.OTXModel.export") - mock_torch_load = mocker.patch("torch.load") + + mock_load_from_checkpoint = mocker.patch.object(fxt_engine.model.__class__, "load_from_checkpoint") + mock_load_from_checkpoint.return_value = fxt_engine.model # Fetch Checkpoint - fxt_engine.checkpoint = "path/to/checkpoint" + checkpoint = "path/to/checkpoint.ckpt" + fxt_engine.checkpoint = checkpoint fxt_engine.export() - mock_torch_load.assert_called_once_with("path/to/checkpoint") + mock_load_from_checkpoint.assert_called_once_with(checkpoint_path=checkpoint, map_location="cpu") mock_export.assert_called_once_with( output_dir=Path(fxt_engine.work_dir), base_name="exported_model", @@ -242,32 +263,47 @@ def test_optimizing_model(self, fxt_engine, mocker) -> None: fxt_engine.optimize(export_demo_package=True) mocker_export.assert_called_once() - def test_explain(self, fxt_engine, mocker) -> None: - mocker.patch("otx.engine.engine.OTXModel.load_state_dict") - mock_process_explain = mocker.patch("otx.algo.utils.xai_utils.process_saliency_maps_in_pred_entity") - - mock_torch_load = mocker.patch("torch.load") + @pytest.mark.parametrize("dump", [True, False]) + @pytest.mark.parametrize( + "checkpoint", + [ + "path/to/checkpoint.ckpt", + "path/to/checkpoint.xml", + ], + ) + def test_explain(self, fxt_engine, checkpoint, dump, mocker) -> None: mock_predict = mocker.patch("otx.engine.engine.Trainer.predict") + _ = mocker.patch("otx.engine.engine.AutoConfigurator.update_ov_subset_pipeline") + mock_get_ov_model = mocker.patch("otx.engine.engine.AutoConfigurator.get_ov_model") + mock_load_from_checkpoint = mocker.patch.object(fxt_engine.model.__class__, "load_from_checkpoint") + mock_process_saliency_maps = mocker.patch("otx.algo.utils.xai_utils.process_saliency_maps_in_pred_entity") + mock_dump_saliency_maps = mocker.patch("otx.algo.utils.xai_utils.dump_saliency_maps") - fxt_engine.explain(checkpoint="path/to/checkpoint") - mock_torch_load.assert_called_once_with("path/to/checkpoint") - mock_predict.assert_called_once() - mock_process_explain.assert_called_once() + ext = Path(checkpoint).suffix - mock_dump_saliency_maps = mocker.patch("otx.algo.utils.xai_utils.dump_saliency_maps") - fxt_engine.explain(checkpoint="path/to/checkpoint", dump=True) - mock_torch_load.assert_called_with("path/to/checkpoint") - mock_predict.assert_called() - mock_process_explain.assert_called() - mock_dump_saliency_maps.assert_called_once() + if ext == ".ckpt": + mock_model = mocker.create_autospec(OTXModel) - mock_ov_pipeline = mocker.patch("otx.engine.engine.AutoConfigurator.update_ov_subset_pipeline") - mock_ov_model = mocker.patch("otx.engine.engine.AutoConfigurator.get_ov_model") - fxt_engine.explain(checkpoint="path/to/model.xml") - mock_predict.assert_called() - mock_process_explain.assert_called() - mock_ov_model.assert_called_once() - mock_ov_pipeline.assert_called_once() + mock_load_from_checkpoint.return_value = mock_model + else: + mock_model = mocker.create_autospec(OVModel) + + mock_get_ov_model.return_value = mock_model + + # Correct label_info from the checkpoint + mock_model.label_info = fxt_engine.datamodule.label_info + fxt_engine.explain(checkpoint=checkpoint, dump=dump) + mock_predict.assert_called_once() + mock_process_saliency_maps.assert_called_once() + assert mock_dump_saliency_maps.called == dump + + mock_model.label_info = NullLabelInfo() + # Incorrect label_info from the checkpoint + with pytest.raises( + ValueError, + match="To launch a explain pipeline, the label information should be same (.*)", + ): + fxt_engine.explain(checkpoint=checkpoint) def test_from_config_with_model_name(self, tmp_path) -> None: model_name = "efficientnet_b0_light"