Refactor checkpoint logic (#3302)
* Fix

Signed-off-by: Kim, Vinnam <[email protected]>

* Remove commented line

Signed-off-by: Kim, Vinnam <[email protected]>

* Fix test_otx_e2e

Signed-off-by: Kim, Vinnam <[email protected]>

* Fix

Signed-off-by: Kim, Vinnam <[email protected]>

* Fix test

Signed-off-by: Kim, Vinnam <[email protected]>

* Remove deterministic from test_otx_explain_e2e

Signed-off-by: Kim, Vinnam <[email protected]>

* Revert scope="module"

Signed-off-by: Kim, Vinnam <[email protected]>

* Fix

Signed-off-by: Kim, Vinnam <[email protected]>

---------

Signed-off-by: Kim, Vinnam <[email protected]>
vinnamkim authored Apr 15, 2024
1 parent a03c294 commit fb69fcb
Showing 19 changed files with 469 additions and 346 deletions.
27 changes: 9 additions & 18 deletions src/otx/algo/detection/ssd.py
@@ -358,7 +358,6 @@ def __init__(
         )
         self.image_size = (1, 3, 864, 864)
         self.tile_image_size = self.image_size
-        self._register_load_state_dict_pre_hook(self._set_anchors_hook)
 
     def _create_model(self) -> nn.Module:
         from mmdet.models.data_preprocessors import (
@@ -410,6 +409,10 @@ def setup(self, stage: str) -> None:
             anchor_generator.widths = new_anchors[0]
             anchor_generator.heights = new_anchors[1]
             anchor_generator.gen_base_anchors()
+            self.hparams["ssd_anchors"] = {
+                "heights": anchor_generator.heights,
+                "widths": anchor_generator.widths,
+            }
 
     def _get_new_anchors(self, dataset: OTXDataset, anchor_generator: SSDAnchorGeneratorClustered) -> tuple | None:
         """Get new anchors for SSD from OTXDataset."""
@@ -521,19 +524,6 @@ def get_classification_layers(
             classification_layers[prefix + key] = {"use_bg": use_bg, "num_anchors": num_anchors}
         return classification_layers
 
-    def state_dict(self, *args, **kwargs) -> dict[str, Any]:
-        """Return state dictionary of model entity with anchor information.
-
-        Returns:
-            A dictionary containing SSD state.
-        """
-        state_dict = super().state_dict(*args, **kwargs)
-        anchor_generator = self.model.bbox_head.anchor_generator
-        anchors = {"heights": anchor_generator.heights, "widths": anchor_generator.widths}
-        state_dict["model.model.anchors"] = anchors
-        return state_dict
-
     def load_state_dict_pre_hook(self, state_dict: dict[str, torch.Tensor], prefix: str, *args, **kwargs) -> None:
         """Modify input state_dict according to class name matching before weight loading."""
         model2ckpt = self.map_class_names(self.model_classes, self.ckpt_classes)
@@ -588,15 +578,16 @@ def _exporter(self) -> OTXModelExporter:
             output_names=["feature_vector", "saliency_map"] if self.explain_mode else None,
         )
 
-    def _set_anchors_hook(self, state_dict: dict[str, Any], *args, **kwargs) -> None:
-        """Pre hook for pop anchor statistics from checkpoint state_dict."""
-        anchors = state_dict.pop("model.model.anchors", None)
-        if anchors is not None:
+    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
+        """Callback on load checkpoint."""
+        if (hparams := checkpoint.get("hyper_parameters")) and (anchors := hparams.get("ssd_anchors", None)):
             anchor_generator = self.model.bbox_head.anchor_generator
             anchor_generator.widths = anchors["widths"]
             anchor_generator.heights = anchors["heights"]
             anchor_generator.gen_base_anchors()
 
+        return super().on_load_checkpoint(checkpoint)
+
     def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
         """Load the previous OTX ckpt according to OTX2.0."""
         return OTXv1Helper.load_ssd_ckpt(state_dict, add_prefix)
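
Net effect in this file: the clustered SSD anchors now round-trip through Lightning's `hyper_parameters` checkpoint entry instead of being smuggled into the weight `state_dict` via a pre-hook. A minimal sketch of that mechanism, assuming standard Lightning behavior (the module and the `ssd_anchors` values below are hypothetical, not OTX code):

```python
from typing import Any

import lightning.pytorch as pl
from torch import nn


class DemoSSDLike(pl.LightningModule):
    """Hypothetical module showing the hparams round-trip used by the SSD refactor."""

    def __init__(self, hidden: int = 8) -> None:
        super().__init__()
        self.layer = nn.Linear(4, hidden)
        # Captures init args in self.hparams; Lightning serializes self.hparams
        # into checkpoint["hyper_parameters"] at save time.
        self.save_hyperparameters()

    def setup(self, stage: str) -> None:
        if stage == "fit":
            # Derived statistics (the clustered anchors in this commit) are stashed
            # in hparams so they ride along with every checkpoint automatically.
            self.hparams["ssd_anchors"] = {"widths": [10, 20], "heights": [12, 24]}

    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        # Mirrors SSD.on_load_checkpoint above: pull the anchors back out of
        # hyper_parameters when the checkpoint is restored.
        if (hparams := checkpoint.get("hyper_parameters")) and (anchors := hparams.get("ssd_anchors")):
            print("Rebuilding anchors from:", anchors)
```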
2 changes: 2 additions & 0 deletions src/otx/core/data/dataset/anomaly/dataset.py
@@ -26,6 +26,7 @@
 from otx.core.data.entity.base import ImageInfo
 from otx.core.data.mem_cache import NULL_MEM_CACHE_HANDLER, MemCacheHandlerBase
 from otx.core.types.image import ImageColorChannel
+from otx.core.types.label import LabelInfo
 from otx.core.types.task import OTXTaskType
@@ -53,6 +54,7 @@ def __init__(
             image_color_channel,
             stack_images,
         )
+        self.label_info = LabelInfo(label_names=["Normal", "Anomaly"], label_groups=[["Normal", "Anomaly"]])
 
     def _get_item_impl(
         self,
5 changes: 5 additions & 0 deletions src/otx/core/data/dataset/visual_prompting.py
@@ -26,6 +26,7 @@
     ZeroShotVisualPromptingBatchDataEntity,
     ZeroShotVisualPromptingDataEntity,
 )
+from otx.core.types.label import NullLabelInfo
 from otx.core.utils.mask_util import polygon_to_bitmap
 
 from .base import OTXDataset, Transforms
@@ -61,6 +62,8 @@ def __init__(
             # if using only point prompt
             self.prob = 0.0
 
+        self.label_info = NullLabelInfo()
+
     def _get_item_impl(self, index: int) -> VisualPromptingDataEntity | None:
         item = self.dm_subset.get(id=self.ids[index], subset=self.dm_subset.name)
         img = item.media_as(dmImage)
@@ -189,6 +192,8 @@ def __init__(
             # if using only point prompt
             self.prob = 0.0
 
+        self.label_info = NullLabelInfo()
+
     def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None:
         item = self.dm_subset.get(id=self.ids[index], subset=self.dm_subset.name)
         img = item.media_as(dmImage)
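
Both dataset edits pin the label schema at construction time: anomaly datasets always expose the fixed Normal/Anomaly pair, while the (zero-shot) visual prompting datasets declare an empty schema. A rough standalone sketch of the idea, using plain dataclasses as stand-ins for OTX's `LabelInfo`/`NullLabelInfo` (field names are assumed from this diff; the real classes live in `otx.core.types.label`):

```python
from dataclasses import dataclass, field


@dataclass
class LabelInfo:
    """Stand-in for otx.core.types.label.LabelInfo; fields assumed from this diff."""

    label_names: list[str]
    label_groups: list[list[str]]


@dataclass
class NullLabelInfo(LabelInfo):
    """Empty schema for label-free tasks such as visual prompting."""

    label_names: list[str] = field(default_factory=list)
    label_groups: list[list[str]] = field(default_factory=list)


# Anomaly datasets always expose the same two classes:
anomaly_labels = LabelInfo(label_names=["Normal", "Anomaly"], label_groups=[["Normal", "Anomaly"]])
```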
2 changes: 1 addition & 1 deletion src/otx/core/metrics/fmeasure.py
@@ -657,7 +657,7 @@ def __init__(
         self._f_measure_per_nms: dict | None = None
         self._best_confidence_threshold: float | None = None
         self._best_nms_threshold: float | None = None
-        self._f_measure = 0.0
+        self._f_measure = float("-inf")
 
         self.reset()
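
Presumably the `float("-inf")` seed is there so the first computed F-measure always replaces the initial sentinel, even when it is legitimately 0.0. A tiny illustration with hypothetical values:

```python
# With a 0.0 seed, a legitimate best score of 0.0 is never recorded:
best = 0.0
for score in [0.0, 0.0]:
    if score > best:  # never true, so the associated best threshold stays unset
        best = score

# With a -inf seed, the first candidate always replaces the sentinel:
best = float("-inf")
for score in [0.0, 0.0]:
    if score > best:
        best = score  # records 0.0 on the first iteration
```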
38 changes: 17 additions & 21 deletions src/otx/core/model/anomaly.py
@@ -33,8 +33,6 @@
 from otx.core.types.task import OTXTaskType
 
 if TYPE_CHECKING:
-    from collections import OrderedDict
-
     from anomalib.metrics import AnomalibMetricCollection
     from anomalib.metrics.threshold import BaseThreshold
     from lightning.pytorch import Trainer
@@ -159,6 +157,22 @@ def __init__(self) -> None:
         self.image_metrics: AnomalibMetricCollection
         self.pixel_metrics: AnomalibMetricCollection
 
+    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
+        """Callback on saving checkpoint."""
+        super().on_save_checkpoint(checkpoint)  # type: ignore[misc]
+
+        attrs = ["_task_type", "_input_size", "mean_values", "scale_values", "image_threshold", "pixel_threshold"]
+
+        checkpoint["anomaly"] = {key: getattr(self, key, None) for key in attrs}
+
+    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
+        """Callback on loading checkpoint."""
+        super().on_load_checkpoint(checkpoint)  # type: ignore[misc]
+
+        if anomaly_attrs := checkpoint.get("anomaly", None):
+            for key, value in anomaly_attrs.items():
+                setattr(self, key, value)
+
     @property
     def input_size(self) -> tuple[int, int]:
         """Returns the input size of the model.
@@ -238,7 +252,7 @@ def trainable_model(self) -> str | None:
     def setup(self, stage: str | None = None) -> None:
         """Setup the model."""
         super().setup(stage)  # type: ignore[misc]
-        if hasattr(self.trainer, "datamodule") and hasattr(self.trainer.datamodule, "config"):
+        if stage == "fit" and hasattr(self.trainer, "datamodule") and hasattr(self.trainer.datamodule, "config"):
             if hasattr(self.trainer.datamodule.config, "test_subset"):
                 self._extract_mean_scale_from_transforms(self.trainer.datamodule.config.test_subset.transforms)
             elif hasattr(self.trainer.datamodule.config, "val_subset"):
@@ -327,24 +341,6 @@ def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[torch.
             return optimizer(params=params)
         return super().configure_optimizers()  # type: ignore[misc]
 
-    def state_dict(self) -> dict[str, Any]:
-        """Return state dictionary of model entity with meta information.
-
-        Returns:
-            A dictionary containing datamodule state.
-        """
-        state_dict = super().state_dict()  # type: ignore[misc]
-        # This is defined in OTXModel
-        state_dict["label_info"] = self.label_info  # type: ignore[attr-defined]
-        return state_dict
-
-    def load_state_dict(self, ckpt: OrderedDict[str, Any], *args, **kwargs) -> None:
-        """Pass the checkpoint to the anomaly model."""
-        ckpt = ckpt.get("state_dict", ckpt)
-        ckpt.pop("label_info", None)  # [TODO](ashwinvaidya17): Revisit this method when OTXModel is the lightning model
-        return super().load_state_dict(ckpt, *args, **kwargs)  # type: ignore[misc]
-
     def forward(
         self,
         inputs: AnomalyModelInputs,
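
The anomaly model now piggybacks its task metadata on the Lightning checkpoint dict rather than overriding `state_dict`/`load_state_dict`. A minimal sketch of the attribute round-trip pattern, as a standalone stand-in (the class and the attribute names in `_ckpt_attrs` are hypothetical):

```python
from typing import Any


class CheckpointAttrsMixin:
    """Sketch of the save/restore pattern used by the anomaly module above."""

    _ckpt_attrs = ("threshold", "mean_values")  # hypothetical attribute names

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        # Attributes that are missing on the instance serialize as None
        # instead of raising, matching the getattr(..., None) in the diff.
        checkpoint["anomaly"] = {k: getattr(self, k, None) for k in self._ckpt_attrs}

    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        # Older checkpoints without the "anomaly" entry are simply skipped.
        if attrs := checkpoint.get("anomaly"):
            for key, value in attrs.items():
                setattr(self, key, value)
```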
89 changes: 64 additions & 25 deletions src/otx/core/model/base.py
@@ -27,6 +27,8 @@
 from torch.optim.sgd import SGD
 from torchmetrics import Metric, MetricCollection
 
+from otx import __version__
+from otx.core.config.data import TileConfig
 from otx.core.data.entity.base import (
     OTXBatchLossEntity,
     T_OTXBatchDataEntity,
@@ -113,6 +115,8 @@ def __init__(
         self.torch_compile = torch_compile
         self._explain_mode = False
 
+        self._tile_config: TileConfig | None = None
+
         # this line allows to access init params with 'self.hparams' attribute
         # also ensures init params will be stored in ckpt
         self.save_hyperparameters(logger=False, ignore=["model", "optimizer", "scheduler", "metric"])
@@ -336,16 +340,54 @@ def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwa
 
         self.log(log_metric_name, value, sync_dist=True, prog_bar=True)
 
-    def state_dict(self) -> dict[str, Any]:
-        """Return state dictionary of model entity with meta information.
-
-        Returns:
-            A dictionary containing datamodule state.
-
-        """
-        state_dict = super().state_dict()
-        state_dict["label_info"] = self.label_info
-        return state_dict
+    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
+        """Callback on saving checkpoint."""
+        super().on_save_checkpoint(checkpoint)
+
+        checkpoint["label_info"] = self.label_info
+        checkpoint["otx_version"] = __version__
+
+        if self._tile_config:
+            checkpoint["tile_config"] = self._tile_config
+
+    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
+        """Callback on loading checkpoint."""
+        super().on_load_checkpoint(checkpoint)
+
+        if ckpt_label_info := checkpoint.get("label_info", None):
+            self._label_info = ckpt_label_info
+
+        if ckpt_tile_config := checkpoint.get("tile_config", None):
+            self._tile_config = ckpt_tile_config
+
+    def load_state_dict_incrementally(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
+        """Load state dict incrementally."""
+        ckpt_label_info: LabelInfo | None = ckpt.get("label_info", None)
+
+        if ckpt_label_info is None:
+            msg = "Checkpoint should have `label_info`."
+            raise ValueError(msg, ckpt_label_info)
+
+        if ckpt_label_info != self.label_info:
+            msg = (
+                "Load model state dictionary incrementally: "
+                f"Label info from checkpoint: {ckpt_label_info} -> "
+                f"Label info from training data: {self.label_info}"
+            )
+            logger.info(msg)
+            self.register_load_state_dict_pre_hook(
+                self.label_info.label_names,
+                ckpt_label_info.label_names,
+            )
+
+        # Model weights
+        state_dict: dict[str, Any] | None = ckpt.get("state_dict", None)
+
+        if state_dict is None:
+            msg = "Checkpoint should have `state_dict`."
+            raise ValueError(msg)
+
+        self.load_state_dict(state_dict, *args, **kwargs)
 
     def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
         """Load state dictionary from checkpoint state dictionary.
@@ -364,23 +406,6 @@ def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
         else:
             state_dict = ckpt
 
-        ckpt_label_info = state_dict.pop("label_info", None)
-
-        if ckpt_label_info and self.label_info is None:
-            msg = (
-                "`state_dict` to load has `label_info`, but the current model has no `label_info`. "
-                "It is recommended to set proper `label_info` for the incremental learning case."
-            )
-            warnings.warn(msg, stacklevel=2)
-        if ckpt_label_info and self.label_info and ckpt_label_info != self.label_info:
-            logger.warning(
-                f"Data classes from checkpoint: {ckpt_label_info.label_names} -> "
-                f"Data classes from training data: {self.label_info.label_names}",
-            )
-            self.register_load_state_dict_pre_hook(
-                self.label_info.label_names,
-                ckpt_label_info.label_names,
-            )
         return super().load_state_dict(state_dict, *args, **kwargs)

def load_from_otx_v1_ckpt(self, ckpt: dict[str, Any]) -> dict:
@@ -698,6 +723,20 @@ def patch_optimizer_and_scheduler_for_hpo(self) -> None:
         if not isinstance(self.scheduler_callable, PicklableLRSchedulerCallable):
             self.scheduler_callable = PicklableLRSchedulerCallable(self.scheduler_callable)
 
+    @property
+    def tile_config(self) -> TileConfig:
+        """Get tiling configurations."""
+        if self._tile_config is None:
+            msg = "This task type does not support tiling."
+            raise RuntimeError(msg)
+
+        return self._tile_config
+
+    @tile_config.setter
+    def tile_config(self, tile_config: TileConfig) -> None:
+        """Set tiling configurations."""
+        self._tile_config = tile_config
+
 
 class OVModel(OTXModel, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity]):
     """Base class for the OpenVINO model.
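
With this change the base model writes `label_info`, `otx_version`, and (when tiling is configured) `tile_config` into the checkpoint dict, and `load_state_dict_incrementally` consumes the same keys to remap class heads. A rough usage sketch under those assumptions (the checkpoint path and the `model` variable are hypothetical):

```python
import torch

# Hypothetical checkpoint produced by OTXModel.on_save_checkpoint above.
ckpt = torch.load("otx_model.ckpt", map_location="cpu")

print(ckpt["otx_version"])      # OTX version that wrote the checkpoint
print(ckpt["label_info"])       # label schema the weights were trained with
print(ckpt.get("tile_config"))  # present only if tiling was configured

# Class-incremental fine-tuning: the model registers a pre-hook to remap head
# weights when its current label_info differs from the checkpoint's.
# model.load_state_dict_incrementally(ckpt)
```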
31 changes: 23 additions & 8 deletions src/otx/core/model/detection.py
@@ -14,12 +14,11 @@
 from openvino.model_api.tilers import DetectionTiler
 from torchvision import tv_tensors
 
-from otx.algo.explain.explain_algo import get_feature_vector
 from otx.core.config.data import TileConfig
 from otx.core.data.entity.base import OTXBatchLossEntity
 from otx.core.data.entity.detection import DetBatchDataEntity, DetBatchPredEntity
 from otx.core.data.entity.tile import TileBatchDetDataEntity
-from otx.core.metrics import MetricInput
+from otx.core.metrics import MetricCallable, MetricInput
 from otx.core.metrics.mean_ap import MeanAPCallable
 from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel
 from otx.core.schedulers import LRSchedulerListCallable
@@ -37,8 +36,6 @@
     from torch import nn
     from torchmetrics import Metric
 
-    from otx.core.metrics import MetricCallable
-
 
 class OTXDetectionModel(OTXModel[DetBatchDataEntity, DetBatchPredEntity, TileBatchDetDataEntity]):
     """Base class for the detection models used in OTX."""
@@ -58,7 +55,7 @@ def __init__(
             metric=metric,
             torch_compile=torch_compile,
         )
-        self.tile_config = TileConfig()
+        self._tile_config = TileConfig()
 
     def forward_tiles(self, inputs: TileBatchDetDataEntity) -> DetBatchPredEntity:
         """Unpack detection tiles.
@@ -170,14 +167,32 @@ def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwa
 class ExplainableOTXDetModel(OTXDetectionModel):
     """OTX detection model which can attach a XAI (Explainable AI) branch."""
 
+    def __init__(
+        self,
+        num_classes: int,
+        optimizer: OptimizerCallable = DefaultOptimizerCallable,
+        scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
+        metric: MetricCallable = MeanAPCallable,
+        torch_compile: bool = False,
+    ) -> None:
+        super().__init__(
+            num_classes=num_classes,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            metric=metric,
+            torch_compile=torch_compile,
+        )
+
+        from otx.algo.explain.explain_algo import get_feature_vector
+
+        self.model.feature_vector_fn = get_feature_vector
+        self.model.explain_fn = self.get_explain_fn()
+
     def forward_explain(
         self,
         inputs: DetBatchDataEntity,
     ) -> DetBatchPredEntity:
         """Model forward function."""
-        self.model.feature_vector_fn = get_feature_vector
-        self.model.explain_fn = self.get_explain_fn()
-
         # If customize_inputs is overridden
         outputs = (
             self._forward_explain_detection(self.model, **self._customize_inputs(inputs))
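
Moving the XAI wiring from `forward_explain` into `__init__` makes it a one-time setup cost instead of reassigning the hooks on every explain call. A toy sketch of the resulting shape (class and function names are hypothetical stand-ins, not OTX code):

```python
import torch
from torch import nn


def get_feature_vector(feats: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for otx.algo.explain.explain_algo.get_feature_vector."""
    return feats.mean(dim=(-2, -1))  # global average pool: (N, C, H, W) -> (N, C)


class ExplainableToyDet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Conv2d(3, 8, 3, padding=1)
        # After this commit: the hook is attached once, at construction.
        self.feature_vector_fn = get_feature_vector

    def forward_explain(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        feats = self.backbone(x)
        # Before this commit, feature_vector_fn was (re)assigned here on every call.
        return {"features": feats, "feature_vector": self.feature_vector_fn(feats)}


out = ExplainableToyDet().forward_explain(torch.randn(1, 3, 32, 32))
```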
[Diffs for the remaining 12 changed files are not shown.]
