Skip to content

Commit

Permalink
🧹 Refactor Anomaly Models (#4102)
Browse files Browse the repository at this point in the history
* refactor

Signed-off-by: Ashwin Vaidya <[email protected]>

* Fix tests

Signed-off-by: Ashwin Vaidya <[email protected]>

---------

Signed-off-by: Ashwin Vaidya <[email protected]>
  • Loading branch information
ashwinvaidya17 authored Nov 19, 2024
1 parent 0bb3f2e commit fc221e8
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 228 deletions.
105 changes: 5 additions & 100 deletions src/otx/algo/anomaly/padim.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,17 @@

from typing import TYPE_CHECKING, Literal

from anomalib.callbacks.normalization.min_max_normalization import _MinMaxNormalizationCallback
from anomalib.callbacks.post_processor import _PostProcessorCallback
from anomalib.models.image import Padim as AnomalibPadim

from otx.core.model.anomaly import OTXAnomaly
from otx.core.model.anomaly import AnomalyMixin, OTXAnomaly
from otx.core.types.label import AnomalyLabelInfo
from otx.core.types.task import OTXTaskType

if TYPE_CHECKING:
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch.optim.optimizer import Optimizer

from otx.core.model.anomaly import AnomalyModelInputs, AnomalyModelOutputs
from otx.core.types.label import LabelInfoTypes


class Padim(OTXAnomaly, AnomalibPadim):
class Padim(AnomalyMixin, AnomalibPadim, OTXAnomaly):
"""OTX Padim model.
Args:
Expand Down Expand Up @@ -55,100 +49,11 @@ def __init__(
] = OTXTaskType.ANOMALY_CLASSIFICATION,
input_size: tuple[int, int] = (256, 256),
) -> None:
OTXAnomaly.__init__(self, label_info=label_info, input_size=input_size)
AnomalibPadim.__init__(
self,
self.input_size = input_size
self.task = OTXTaskType(task)
super().__init__(
backbone=backbone,
layers=layers,
pre_trained=pre_trained,
n_features=n_features,
)
self.task = task
self.input_size = input_size

def configure_optimizers(self) -> tuple[list[Optimizer], list[Optimizer]] | None:
"""PADIM doesn't require optimization, therefore returns no optimizers."""
return

def configure_metric(self) -> None:
"""This does not follow OTX metric configuration."""
return

def on_train_epoch_end(self) -> None:
"""Callback triggered when the training epoch ends."""
return AnomalibPadim.on_train_epoch_end(self)

def on_validation_start(self) -> None:
"""Callback triggered when the validation starts."""
return AnomalibPadim.on_validation_start(self)

def on_validation_epoch_start(self) -> None:
"""Callback triggered when the validation epoch starts."""
AnomalibPadim.on_validation_epoch_start(self)

def on_test_epoch_start(self) -> None:
"""Callback triggered when the test epoch starts."""
AnomalibPadim.on_test_epoch_start(self)

def on_validation_epoch_end(self) -> None:
"""Callback triggered when the validation epoch ends."""
AnomalibPadim.on_validation_epoch_end(self)

def on_test_epoch_end(self) -> None:
"""Callback triggered when the test epoch ends."""
AnomalibPadim.on_test_epoch_end(self)

def training_step(
self,
inputs: AnomalyModelInputs,
batch_idx: int = 0,
) -> STEP_OUTPUT:
"""Call training step of the anomalib model."""
if not isinstance(inputs, dict):
inputs = self._customize_inputs(inputs)
return AnomalibPadim.training_step(self, inputs, batch_idx) # type: ignore[misc]

def validation_step(
self,
inputs: AnomalyModelInputs,
batch_idx: int = 0,
) -> STEP_OUTPUT:
"""Call validation step of the anomalib model."""
if not isinstance(inputs, dict):
inputs = self._customize_inputs(inputs)
return AnomalibPadim.validation_step(self, inputs, batch_idx) # type: ignore[misc]

def test_step(
self,
inputs: AnomalyModelInputs,
batch_idx: int = 0,
**kwargs,
) -> STEP_OUTPUT:
"""Call test step of the anomalib model."""
if not isinstance(inputs, dict):
inputs = self._customize_inputs(inputs)
return AnomalibPadim.test_step(self, inputs, batch_idx, **kwargs) # type: ignore[misc]

def predict_step(
self,
inputs: AnomalyModelInputs,
batch_idx: int = 0,
**kwargs,
) -> STEP_OUTPUT:
"""Call test step of the anomalib model."""
if not isinstance(inputs, dict):
inputs = self._customize_inputs(inputs)
return AnomalibPadim.predict_step(self, inputs, batch_idx, **kwargs) # type: ignore[misc]

def forward(
self,
inputs: AnomalyModelInputs,
) -> AnomalyModelOutputs:
"""Wrap forward method of the Anomalib model."""
outputs = self.validation_step(inputs)
# TODO(Ashwin): update forward implementation to comply with other OTX models
_PostProcessorCallback._post_process(outputs) # noqa: SLF001
_PostProcessorCallback._compute_scores_and_labels(self, outputs) # noqa: SLF001
_MinMaxNormalizationCallback._normalize_batch(outputs, self) # noqa: SLF001

return self._customize_outputs(outputs=outputs, inputs=inputs)
102 changes: 5 additions & 97 deletions src/otx/algo/anomaly/stfpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,17 @@

from typing import TYPE_CHECKING, Literal, Sequence

from anomalib.callbacks.normalization.min_max_normalization import _MinMaxNormalizationCallback
from anomalib.callbacks.post_processor import _PostProcessorCallback
from anomalib.models.image.stfpm import Stfpm as AnomalibStfpm

from otx.core.model.anomaly import OTXAnomaly
from otx.core.model.anomaly import AnomalyMixin, OTXAnomaly
from otx.core.types.label import AnomalyLabelInfo
from otx.core.types.task import OTXTaskType

if TYPE_CHECKING:
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch.optim.optimizer import Optimizer

from otx.core.model.anomaly import AnomalyModelInputs, AnomalyModelOutputs
from otx.core.types.label import LabelInfoTypes


class Stfpm(OTXAnomaly, AnomalibStfpm):
class Stfpm(AnomalyMixin, AnomalibStfpm, OTXAnomaly):
"""OTX STFPM model.
Args:
Expand All @@ -52,95 +46,9 @@ def __init__(
input_size: tuple[int, int] = (256, 256),
**kwargs,
) -> None:
OTXAnomaly.__init__(self, label_info=label_info, input_size=input_size)
AnomalibStfpm.__init__(
self,
self.input_size = input_size
self.task = OTXTaskType(task)
super().__init__(
backbone=backbone,
layers=layers,
)
self.task = task
self.input_size = input_size

@property
def trainable_model(self) -> str:
"""Used by configure optimizer."""
return "student_model"

def configure_metric(self) -> None:
"""This does not follow OTX metric configuration."""
return

def configure_optimizers(self) -> tuple[list[Optimizer], list[Optimizer]] | None:
"""STFPM does not follow OTX optimizer configuration."""
return AnomalibStfpm.configure_optimizers(self)

def on_validation_epoch_start(self) -> None:
"""Callback triggered when the validation epoch starts."""
AnomalibStfpm.on_validation_epoch_start(self)

def on_test_epoch_start(self) -> None:
"""Callback triggered when the test epoch starts."""
AnomalibStfpm.on_test_epoch_start(self)

def on_validation_epoch_end(self) -> None:
"""Callback triggered when the validation epoch ends."""
AnomalibStfpm.on_validation_epoch_end(self)

def on_test_epoch_end(self) -> None:
"""Callback triggered when the test epoch ends."""
AnomalibStfpm.on_test_epoch_end(self)

def training_step(
self,
inputs: AnomalyModelInputs,
batch_idx: int = 0,
) -> STEP_OUTPUT:
"""Call training step of the anomalib model."""
if not isinstance(inputs, dict):
inputs = self._customize_inputs(inputs)
return AnomalibStfpm.training_step(self, inputs, batch_idx) # type: ignore[misc]

def validation_step(
self,
inputs: AnomalyModelInputs,
batch_idx: int = 0,
) -> STEP_OUTPUT:
"""Call validation step of the anomalib model."""
if not isinstance(inputs, dict):
inputs = self._customize_inputs(inputs)
return AnomalibStfpm.validation_step(self, inputs, batch_idx) # type: ignore[misc]

def test_step(
self,
inputs: AnomalyModelInputs,
batch_idx: int = 0,
**kwargs,
) -> STEP_OUTPUT:
"""Call test step of the anomalib model."""
if not isinstance(inputs, dict):
inputs = self._customize_inputs(inputs)
return AnomalibStfpm.test_step(self, inputs, batch_idx, **kwargs) # type: ignore[misc]

def predict_step(
self,
inputs: AnomalyModelInputs,
batch_idx: int = 0,
**kwargs,
) -> STEP_OUTPUT:
"""Call test step of the anomalib model."""
if not isinstance(inputs, dict):
inputs = self._customize_inputs(inputs)
return AnomalibStfpm.predict_step(self, inputs, batch_idx, **kwargs) # type: ignore[misc]

def forward(
self,
inputs: AnomalyModelInputs,
) -> AnomalyModelOutputs:
"""Wrap forward method of the Anomalib model."""
outputs = self.validation_step(inputs)
# TODO(Ashwin): update forward implementation to comply with other OTX models
_PostProcessorCallback._post_process(outputs) # noqa: SLF001
_PostProcessorCallback._compute_scores_and_labels(self, outputs) # noqa: SLF001
_MinMaxNormalizationCallback._normalize_batch(outputs, self) # noqa: SLF001

return self._customize_outputs(outputs=outputs, inputs=inputs)
Loading

0 comments on commit fc221e8

Please sign in to comment.