Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor checkpoint logic #3302

Merged
merged 10 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 9 additions & 18 deletions src/otx/algo/detection/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,6 @@ def __init__(
)
self.image_size = (1, 3, 864, 864)
self.tile_image_size = self.image_size
self._register_load_state_dict_pre_hook(self._set_anchors_hook)

def _create_model(self) -> nn.Module:
from mmdet.models.data_preprocessors import (
Expand Down Expand Up @@ -410,6 +409,10 @@ def setup(self, stage: str) -> None:
anchor_generator.widths = new_anchors[0]
anchor_generator.heights = new_anchors[1]
anchor_generator.gen_base_anchors()
self.hparams["ssd_anchors"] = {
"heights": anchor_generator.heights,
"widths": anchor_generator.widths,
}

def _get_new_anchors(self, dataset: OTXDataset, anchor_generator: SSDAnchorGeneratorClustered) -> tuple | None:
"""Get new anchors for SSD from OTXDataset."""
Expand Down Expand Up @@ -521,19 +524,6 @@ def get_classification_layers(
classification_layers[prefix + key] = {"use_bg": use_bg, "num_anchors": num_anchors}
return classification_layers

def state_dict(self, *args, **kwargs) -> dict[str, Any]:
    """Build the state dictionary and attach the SSD anchor sizes to it.

    The parent state dictionary is produced first; the anchor generator's
    ``heights`` and ``widths`` are then stored together under the
    ``"model.model.anchors"`` key.

    Returns:
        The parent state dictionary extended with the anchor entry.

    """
    state = super().state_dict(*args, **kwargs)
    generator = self.model.bbox_head.anchor_generator
    state["model.model.anchors"] = {
        "heights": generator.heights,
        "widths": generator.widths,
    }
    return state

def load_state_dict_pre_hook(self, state_dict: dict[str, torch.Tensor], prefix: str, *args, **kwargs) -> None:
"""Modify input state_dict according to class name matching before weight loading."""
model2ckpt = self.map_class_names(self.model_classes, self.ckpt_classes)
Expand Down Expand Up @@ -588,15 +578,16 @@ def _exporter(self) -> OTXModelExporter:
output_names=["feature_vector", "saliency_map"] if self.explain_mode else None,
)

def _set_anchors_hook(self, state_dict: dict[str, Any], *args, **kwargs) -> None:
"""Pre hook for pop anchor statistics from checkpoint state_dict."""
anchors = state_dict.pop("model.model.anchors", None)
if anchors is not None:
def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
"""Callback on load checkpoint."""
if (hparams := checkpoint.get("hyper_parameters")) and (anchors := hparams.get("ssd_anchors", None)):
anchor_generator = self.model.bbox_head.anchor_generator
anchor_generator.widths = anchors["widths"]
anchor_generator.heights = anchors["heights"]
anchor_generator.gen_base_anchors()

return super().on_load_checkpoint(checkpoint)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
return OTXv1Helper.load_ssd_ckpt(state_dict, add_prefix)
2 changes: 2 additions & 0 deletions src/otx/core/data/dataset/anomaly/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from otx.core.data.entity.base import ImageInfo
from otx.core.data.mem_cache import NULL_MEM_CACHE_HANDLER, MemCacheHandlerBase
from otx.core.types.image import ImageColorChannel
from otx.core.types.label import LabelInfo
from otx.core.types.task import OTXTaskType


Expand Down Expand Up @@ -53,6 +54,7 @@ def __init__(
image_color_channel,
stack_images,
)
self.label_info = LabelInfo(label_names=["Normal", "Anomaly"], label_groups=[["Normal", "Anomaly"]])

def _get_item_impl(
self,
Expand Down
5 changes: 5 additions & 0 deletions src/otx/core/data/dataset/visual_prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
ZeroShotVisualPromptingBatchDataEntity,
ZeroShotVisualPromptingDataEntity,
)
from otx.core.types.label import NullLabelInfo
from otx.core.utils.mask_util import polygon_to_bitmap

from .base import OTXDataset, Transforms
Expand Down Expand Up @@ -61,6 +62,8 @@ def __init__(
# if using only point prompt
self.prob = 0.0

self.label_info = NullLabelInfo()

def _get_item_impl(self, index: int) -> VisualPromptingDataEntity | None:
item = self.dm_subset.get(id=self.ids[index], subset=self.dm_subset.name)
img = item.media_as(dmImage)
Expand Down Expand Up @@ -189,6 +192,8 @@ def __init__(
# if using only point prompt
self.prob = 0.0

self.label_info = NullLabelInfo()

def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None:
item = self.dm_subset.get(id=self.ids[index], subset=self.dm_subset.name)
img = item.media_as(dmImage)
Expand Down
2 changes: 1 addition & 1 deletion src/otx/core/metrics/fmeasure.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ def __init__(
self._f_measure_per_nms: dict | None = None
self._best_confidence_threshold: float | None = None
self._best_nms_threshold: float | None = None
self._f_measure = 0.0
self._f_measure = float("-inf")

self.reset()

Expand Down
38 changes: 17 additions & 21 deletions src/otx/core/model/anomaly.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@
from otx.core.types.task import OTXTaskType

if TYPE_CHECKING:
from collections import OrderedDict

from anomalib.metrics import AnomalibMetricCollection
from anomalib.metrics.threshold import BaseThreshold
from lightning.pytorch import Trainer
Expand Down Expand Up @@ -159,6 +157,22 @@ def __init__(self) -> None:
self.image_metrics: AnomalibMetricCollection
self.pixel_metrics: AnomalibMetricCollection

def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Add the anomaly-specific attributes to the checkpoint being saved.

    The parent hook runs first. A fixed set of attributes is then read from
    this instance (``None`` is used for any attribute that does not exist)
    and stored as a dictionary under the ``"anomaly"`` key of ``checkpoint``.

    Args:
        checkpoint: Checkpoint dictionary, modified in place.
    """
    super().on_save_checkpoint(checkpoint)  # type: ignore[misc]

    names = (
        "_task_type",
        "_input_size",
        "mean_values",
        "scale_values",
        "image_threshold",
        "pixel_threshold",
    )
    saved: dict[str, Any] = {}
    for name in names:
        saved[name] = getattr(self, name, None)

    checkpoint["anomaly"] = saved

def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Restore the anomaly-specific attributes from a loaded checkpoint.

    The parent hook runs first. If ``checkpoint`` carries a non-empty
    ``"anomaly"`` entry, every key/value pair in it is set as an attribute
    on this instance.

    Args:
        checkpoint: Checkpoint dictionary being loaded.
    """
    super().on_load_checkpoint(checkpoint)  # type: ignore[misc]

    restored = checkpoint.get("anomaly")
    if not restored:
        # Nothing stored (missing or empty entry): leave the instance untouched.
        return

    for name in restored:
        setattr(self, name, restored[name])

@property
def input_size(self) -> tuple[int, int]:
"""Returns the input size of the model.
Expand Down Expand Up @@ -238,7 +252,7 @@ def trainable_model(self) -> str | None:
def setup(self, stage: str | None = None) -> None:
"""Setup the model."""
super().setup(stage) # type: ignore[misc]
if hasattr(self.trainer, "datamodule") and hasattr(self.trainer.datamodule, "config"):
if stage == "fit" and hasattr(self.trainer, "datamodule") and hasattr(self.trainer.datamodule, "config"):
if hasattr(self.trainer.datamodule.config, "test_subset"):
self._extract_mean_scale_from_transforms(self.trainer.datamodule.config.test_subset.transforms)
elif hasattr(self.trainer.datamodule.config, "val_subset"):
Expand Down Expand Up @@ -327,24 +341,6 @@ def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[torch.
return optimizer(params=params)
return super().configure_optimizers() # type: ignore[misc]

def state_dict(self) -> dict[str, Any]:
    """Return the parent state dictionary with the label info attached.

    Returns:
        The parent state dictionary with ``self.label_info`` stored under
        the ``"label_info"`` key.

    """
    state = super().state_dict()  # type: ignore[misc]
    # This is defined in OTXModel
    label_info = self.label_info  # type: ignore[attr-defined]
    state["label_info"] = label_info
    return state

def load_state_dict(self, ckpt: OrderedDict[str, Any], *args, **kwargs) -> None:
    """Hand the checkpoint weights to the anomaly model.

    If ``ckpt`` has a ``"state_dict"`` entry, that entry is used; otherwise
    ``ckpt`` itself is treated as the weights. Any ``"label_info"`` entry is
    removed before delegating to the parent implementation.
    """
    weights = ckpt.get("state_dict", ckpt)
    # [TODO](ashwinvaidya17): Revisit this method when OTXModel is the lightning model
    weights.pop("label_info", None)
    return super().load_state_dict(weights, *args, **kwargs)  # type: ignore[misc]

def forward(
self,
inputs: AnomalyModelInputs,
Expand Down
89 changes: 64 additions & 25 deletions src/otx/core/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from torch.optim.sgd import SGD
from torchmetrics import Metric, MetricCollection

from otx import __version__
from otx.core.config.data import TileConfig
from otx.core.data.entity.base import (
OTXBatchLossEntity,
T_OTXBatchDataEntity,
Expand Down Expand Up @@ -113,6 +115,8 @@ def __init__(
self.torch_compile = torch_compile
self._explain_mode = False

self._tile_config: TileConfig | None = None

# this line allows to access init params with 'self.hparams' attribute
# also ensures init params will be stored in ckpt
self.save_hyperparameters(logger=False, ignore=["model", "optimizer", "scheduler", "metric"])
Expand Down Expand Up @@ -336,16 +340,54 @@ def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwa

self.log(log_metric_name, value, sync_dist=True, prog_bar=True)

def state_dict(self) -> dict[str, Any]:
"""Return state dictionary of model entity with meta information.
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
"""Callback on saving checkpoint."""
super().on_save_checkpoint(checkpoint)

Returns:
A dictionary containing datamodule state.
checkpoint["label_info"] = self.label_info
checkpoint["otx_version"] = __version__

"""
state_dict = super().state_dict()
state_dict["label_info"] = self.label_info
return state_dict
if self._tile_config:
checkpoint["tile_config"] = self._tile_config

def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Restore label info and tiling configuration from a loaded checkpoint.

    The parent hook runs first. ``_label_info`` and ``_tile_config`` are
    each overwritten only when the checkpoint holds a truthy value under
    ``"label_info"`` / ``"tile_config"`` respectively.

    Args:
        checkpoint: Checkpoint dictionary being loaded.
    """
    super().on_load_checkpoint(checkpoint)

    label_info = checkpoint.get("label_info")
    if label_info:
        self._label_info = label_info

    tile_config = checkpoint.get("tile_config")
    if tile_config:
        self._tile_config = tile_config

def load_state_dict_incrementally(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
    """Load model weights from a checkpoint whose label info may differ from the model's.

    Args:
        ckpt: Checkpoint dictionary. It must contain a ``label_info`` entry and
            a ``state_dict`` entry holding the model weights.
        *args: Forwarded unchanged to ``load_state_dict``.
        **kwargs: Forwarded unchanged to ``load_state_dict``.

    Raises:
        ValueError: If ``ckpt`` has no ``label_info`` or no ``state_dict``.
    """
    ckpt_label_info: LabelInfo | None = ckpt.get("label_info", None)

    if ckpt_label_info is None:
        msg = "Checkpoint should have `label_info`."
        raise ValueError(msg, ckpt_label_info)

    # Model weights. Validate them before touching the model so that a
    # malformed checkpoint does not leave a pre-hook registered behind.
    state_dict: dict[str, Any] | None = ckpt.get("state_dict", None)

    if state_dict is None:
        msg = "Checkpoint should have `state_dict`."
        raise ValueError(msg)

    if ckpt_label_info != self.label_info:
        msg = (
            "Load model state dictionary incrementally: "
            f"Label info from checkpoint: {ckpt_label_info} -> "
            f"Label info from training data: {self.label_info}"
        )
        logger.info(msg)
        # Hand both label-name lists (current model first, checkpoint second)
        # to the pre-hook registration before the weights are loaded.
        self.register_load_state_dict_pre_hook(
            self.label_info.label_names,
            ckpt_label_info.label_names,
        )

    self.load_state_dict(state_dict, *args, **kwargs)

def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
"""Load state dictionary from checkpoint state dictionary.
Expand All @@ -364,23 +406,6 @@ def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
else:
state_dict = ckpt

ckpt_label_info = state_dict.pop("label_info", None)

if ckpt_label_info and self.label_info is None:
msg = (
"`state_dict` to load has `label_info`, but the current model has no `label_info`. "
"It is recommended to set proper `label_info` for the incremental learning case."
)
warnings.warn(msg, stacklevel=2)
if ckpt_label_info and self.label_info and ckpt_label_info != self.label_info:
logger.warning(
f"Data classes from checkpoint: {ckpt_label_info.label_names} -> "
f"Data classes from training data: {self.label_info.label_names}",
)
self.register_load_state_dict_pre_hook(
self.label_info.label_names,
ckpt_label_info.label_names,
)
return super().load_state_dict(state_dict, *args, **kwargs)

def load_from_otx_v1_ckpt(self, ckpt: dict[str, Any]) -> dict:
Expand Down Expand Up @@ -698,6 +723,20 @@ def patch_optimizer_and_scheduler_for_hpo(self) -> None:
if not isinstance(self.scheduler_callable, PicklableLRSchedulerCallable):
self.scheduler_callable = PicklableLRSchedulerCallable(self.scheduler_callable)

@property
def tile_config(self) -> TileConfig:
    """Return the tiling configuration.

    Raises:
        RuntimeError: If no tiling configuration has been set (``_tile_config`` is ``None``).
    """
    current = self._tile_config
    if current is not None:
        return current

    msg = "This task type does not support tiling."
    raise RuntimeError(msg)

@tile_config.setter
def tile_config(self, tile_config: TileConfig) -> None:
    """Replace the stored tiling configuration with ``tile_config``."""
    self._tile_config = tile_config


class OVModel(OTXModel, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity]):
"""Base class for the OpenVINO model.
Expand Down
31 changes: 23 additions & 8 deletions src/otx/core/model/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
from openvino.model_api.tilers import DetectionTiler
from torchvision import tv_tensors

from otx.algo.explain.explain_algo import get_feature_vector
from otx.core.config.data import TileConfig
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.detection import DetBatchDataEntity, DetBatchPredEntity
from otx.core.data.entity.tile import TileBatchDetDataEntity
from otx.core.metrics import MetricInput
from otx.core.metrics import MetricCallable, MetricInput
from otx.core.metrics.mean_ap import MeanAPCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel
from otx.core.schedulers import LRSchedulerListCallable
Expand All @@ -37,8 +36,6 @@
from torch import nn
from torchmetrics import Metric

from otx.core.metrics import MetricCallable


class OTXDetectionModel(OTXModel[DetBatchDataEntity, DetBatchPredEntity, TileBatchDetDataEntity]):
"""Base class for the detection models used in OTX."""
Expand All @@ -58,7 +55,7 @@ def __init__(
metric=metric,
torch_compile=torch_compile,
)
self.tile_config = TileConfig()
self._tile_config = TileConfig()

def forward_tiles(self, inputs: TileBatchDetDataEntity) -> DetBatchPredEntity:
"""Unpack detection tiles.
Expand Down Expand Up @@ -170,14 +167,32 @@ def _log_metrics(self, meter: Metric, key: Literal["val", "test"], **compute_kwa
class ExplainableOTXDetModel(OTXDetectionModel):
"""OTX detection model which can attach a XAI (Explainable AI) branch."""

def __init__(
    self,
    num_classes: int,
    optimizer: OptimizerCallable = DefaultOptimizerCallable,
    scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
    metric: MetricCallable = MeanAPCallable,
    torch_compile: bool = False,
) -> None:
    """Initialize the detection model and attach the explain functions to it.

    All arguments are forwarded unchanged to the parent constructor. After
    that, ``feature_vector_fn`` and ``explain_fn`` are set on ``self.model``.
    """
    super().__init__(
        num_classes=num_classes,
        optimizer=optimizer,
        scheduler=scheduler,
        metric=metric,
        torch_compile=torch_compile,
    )

    # Imported locally rather than at module level, as in the original code.
    from otx.algo.explain.explain_algo import get_feature_vector

    inner_model = self.model
    inner_model.feature_vector_fn = get_feature_vector
    inner_model.explain_fn = self.get_explain_fn()

def forward_explain(
self,
inputs: DetBatchDataEntity,
) -> DetBatchPredEntity:
"""Model forward function."""
self.model.feature_vector_fn = get_feature_vector
self.model.explain_fn = self.get_explain_fn()

# If customize_inputs is overridden
outputs = (
self._forward_explain_detection(self.model, **self._customize_inputs(inputs))
Expand Down
Loading
Loading