🧹 Refactor Anomaly Models #4102

Merged
105 changes: 5 additions & 100 deletions src/otx/algo/anomaly/padim.py
@@ -9,23 +9,17 @@
 
 from typing import TYPE_CHECKING, Literal
 
-from anomalib.callbacks.normalization.min_max_normalization import _MinMaxNormalizationCallback
-from anomalib.callbacks.post_processor import _PostProcessorCallback
 from anomalib.models.image import Padim as AnomalibPadim
 
-from otx.core.model.anomaly import OTXAnomaly
+from otx.core.model.anomaly import AnomalyMixin, OTXAnomaly
 from otx.core.types.label import AnomalyLabelInfo
 from otx.core.types.task import OTXTaskType
 
 if TYPE_CHECKING:
-    from lightning.pytorch.utilities.types import STEP_OUTPUT
-    from torch.optim.optimizer import Optimizer
-
-    from otx.core.model.anomaly import AnomalyModelInputs, AnomalyModelOutputs
     from otx.core.types.label import LabelInfoTypes
 
 
-class Padim(OTXAnomaly, AnomalibPadim):
+class Padim(AnomalyMixin, AnomalibPadim, OTXAnomaly):
     """OTX Padim model.
 
     Args:
@@ -55,100 +49,11 @@ def __init__(
         ] = OTXTaskType.ANOMALY_CLASSIFICATION,
         input_size: tuple[int, int] = (256, 256),
     ) -> None:
-        OTXAnomaly.__init__(self, label_info=label_info, input_size=input_size)
-        AnomalibPadim.__init__(
-            self,
+        self.input_size = input_size
+        self.task = OTXTaskType(task)
+        super().__init__(
             backbone=backbone,
             layers=layers,
             pre_trained=pre_trained,
             n_features=n_features,
         )
-        self.task = task
-        self.input_size = input_size
-
-    def configure_optimizers(self) -> tuple[list[Optimizer], list[Optimizer]] | None:
-        """PADIM doesn't require optimization, therefore returns no optimizers."""
-        return
-
-    def configure_metric(self) -> None:
-        """This does not follow OTX metric configuration."""
-        return
-
-    def on_train_epoch_end(self) -> None:
-        """Callback triggered when the training epoch ends."""
-        return AnomalibPadim.on_train_epoch_end(self)
-
-    def on_validation_start(self) -> None:
-        """Callback triggered when the validation starts."""
-        return AnomalibPadim.on_validation_start(self)
-
-    def on_validation_epoch_start(self) -> None:
-        """Callback triggered when the validation epoch starts."""
-        AnomalibPadim.on_validation_epoch_start(self)
-
-    def on_test_epoch_start(self) -> None:
-        """Callback triggered when the test epoch starts."""
-        AnomalibPadim.on_test_epoch_start(self)
-
-    def on_validation_epoch_end(self) -> None:
-        """Callback triggered when the validation epoch ends."""
-        AnomalibPadim.on_validation_epoch_end(self)
-
-    def on_test_epoch_end(self) -> None:
-        """Callback triggered when the test epoch ends."""
-        AnomalibPadim.on_test_epoch_end(self)
-
-    def training_step(
-        self,
-        inputs: AnomalyModelInputs,
-        batch_idx: int = 0,
-    ) -> STEP_OUTPUT:
-        """Call training step of the anomalib model."""
-        if not isinstance(inputs, dict):
-            inputs = self._customize_inputs(inputs)
-        return AnomalibPadim.training_step(self, inputs, batch_idx)  # type: ignore[misc]
-
-    def validation_step(
-        self,
-        inputs: AnomalyModelInputs,
-        batch_idx: int = 0,
-    ) -> STEP_OUTPUT:
-        """Call validation step of the anomalib model."""
-        if not isinstance(inputs, dict):
-            inputs = self._customize_inputs(inputs)
-        return AnomalibPadim.validation_step(self, inputs, batch_idx)  # type: ignore[misc]
-
-    def test_step(
-        self,
-        inputs: AnomalyModelInputs,
-        batch_idx: int = 0,
-        **kwargs,
-    ) -> STEP_OUTPUT:
-        """Call test step of the anomalib model."""
-        if not isinstance(inputs, dict):
-            inputs = self._customize_inputs(inputs)
-        return AnomalibPadim.test_step(self, inputs, batch_idx, **kwargs)  # type: ignore[misc]
-
-    def predict_step(
-        self,
-        inputs: AnomalyModelInputs,
-        batch_idx: int = 0,
-        **kwargs,
-    ) -> STEP_OUTPUT:
-        """Call test step of the anomalib model."""
-        if not isinstance(inputs, dict):
-            inputs = self._customize_inputs(inputs)
-        return AnomalibPadim.predict_step(self, inputs, batch_idx, **kwargs)  # type: ignore[misc]
-
-    def forward(
-        self,
-        inputs: AnomalyModelInputs,
-    ) -> AnomalyModelOutputs:
-        """Wrap forward method of the Anomalib model."""
-        outputs = self.validation_step(inputs)
-        # TODO(Ashwin): update forward implementation to comply with other OTX models
-        _PostProcessorCallback._post_process(outputs)  # noqa: SLF001
-        _PostProcessorCallback._compute_scores_and_labels(self, outputs)  # noqa: SLF001
-        _MinMaxNormalizationCallback._normalize_batch(outputs, self)  # noqa: SLF001
-
-        return self._customize_outputs(outputs=outputs, inputs=inputs)
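
A note on the new class layout: with the bases ordered `(AnomalyMixin, AnomalibPadim, OTXAnomaly)`, the single `super().__init__(...)` call in the rewritten constructor resolves to the Anomalib model, and the input-conversion boilerplate that every overridden step method used to repeat can live once on the mixin. The snippet below is a minimal sketch of that mechanism only, using hypothetical stand-in classes (`Mixin`, `AnomalibModel`, `OTXBase`), not the actual `AnomalyMixin` implementation from `otx.core.model.anomaly`:

```python
# Illustrative sketch of the mixin pattern this PR introduces.
# Mixin, AnomalibModel and OTXBase are hypothetical stand-ins for
# AnomalyMixin, AnomalibPadim/AnomalibStfpm and OTXAnomaly.


class AnomalibModel:
    """Stand-in for the upstream anomalib model."""

    def __init__(self, backbone: str = "resnet18") -> None:
        self.backbone = backbone

    def training_step(self, batch: dict, batch_idx: int = 0) -> dict:
        return {"loss": 0.0, "inputs": batch}


class OTXBase:
    """Stand-in for OTXAnomaly: knows how to convert OTX batch entities."""

    def _customize_inputs(self, inputs: object) -> dict:
        return {"image": inputs}


class Mixin:
    """Stand-in for AnomalyMixin: converts inputs once, then delegates."""

    def training_step(self, inputs: object, batch_idx: int = 0) -> dict:
        if not isinstance(inputs, dict):
            inputs = self._customize_inputs(inputs)  # provided by OTXBase
        # super() here is the next class after Mixin in the MRO,
        # i.e. the anomalib model.
        return super().training_step(inputs, batch_idx)


class Padim(Mixin, AnomalibModel, OTXBase):
    def __init__(self, backbone: str = "resnet18") -> None:
        # Mirrors the new __init__: a single cooperative call that
        # resolves to AnomalibModel.__init__.
        super().__init__(backbone=backbone)


model = Padim()
print(model.training_step("raw OTX batch entity"))
```

Putting the mixin first in the base list is what lets its step methods intercept OTX batch entities before the Anomalib implementation sees them; the same reasoning applies to the Stfpm diff below.
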
102 changes: 5 additions & 97 deletions src/otx/algo/anomaly/stfpm.py
@@ -9,23 +9,17 @@
 
 from typing import TYPE_CHECKING, Literal, Sequence
 
-from anomalib.callbacks.normalization.min_max_normalization import _MinMaxNormalizationCallback
-from anomalib.callbacks.post_processor import _PostProcessorCallback
 from anomalib.models.image.stfpm import Stfpm as AnomalibStfpm
 
-from otx.core.model.anomaly import OTXAnomaly
+from otx.core.model.anomaly import AnomalyMixin, OTXAnomaly
 from otx.core.types.label import AnomalyLabelInfo
 from otx.core.types.task import OTXTaskType
 
 if TYPE_CHECKING:
-    from lightning.pytorch.utilities.types import STEP_OUTPUT
-    from torch.optim.optimizer import Optimizer
-
-    from otx.core.model.anomaly import AnomalyModelInputs, AnomalyModelOutputs
     from otx.core.types.label import LabelInfoTypes
 
 
-class Stfpm(OTXAnomaly, AnomalibStfpm):
+class Stfpm(AnomalyMixin, AnomalibStfpm, OTXAnomaly):
     """OTX STFPM model.
 
     Args:
@@ -52,95 +46,9 @@ def __init__(
         input_size: tuple[int, int] = (256, 256),
         **kwargs,
     ) -> None:
-        OTXAnomaly.__init__(self, label_info=label_info, input_size=input_size)
-        AnomalibStfpm.__init__(
-            self,
+        self.input_size = input_size
+        self.task = OTXTaskType(task)
+        super().__init__(
             backbone=backbone,
             layers=layers,
         )
-        self.task = task
-        self.input_size = input_size
-
-    @property
-    def trainable_model(self) -> str:
-        """Used by configure optimizer."""
-        return "student_model"
-
-    def configure_metric(self) -> None:
-        """This does not follow OTX metric configuration."""
-        return
-
-    def configure_optimizers(self) -> tuple[list[Optimizer], list[Optimizer]] | None:
-        """STFPM does not follow OTX optimizer configuration."""
-        return AnomalibStfpm.configure_optimizers(self)
-
-    def on_validation_epoch_start(self) -> None:
-        """Callback triggered when the validation epoch starts."""
-        AnomalibStfpm.on_validation_epoch_start(self)
-
-    def on_test_epoch_start(self) -> None:
-        """Callback triggered when the test epoch starts."""
-        AnomalibStfpm.on_test_epoch_start(self)
-
-    def on_validation_epoch_end(self) -> None:
-        """Callback triggered when the validation epoch ends."""
-        AnomalibStfpm.on_validation_epoch_end(self)
-
-    def on_test_epoch_end(self) -> None:
-        """Callback triggered when the test epoch ends."""
-        AnomalibStfpm.on_test_epoch_end(self)
-
-    def training_step(
-        self,
-        inputs: AnomalyModelInputs,
-        batch_idx: int = 0,
-    ) -> STEP_OUTPUT:
-        """Call training step of the anomalib model."""
-        if not isinstance(inputs, dict):
-            inputs = self._customize_inputs(inputs)
-        return AnomalibStfpm.training_step(self, inputs, batch_idx)  # type: ignore[misc]
-
-    def validation_step(
-        self,
-        inputs: AnomalyModelInputs,
-        batch_idx: int = 0,
-    ) -> STEP_OUTPUT:
-        """Call validation step of the anomalib model."""
-        if not isinstance(inputs, dict):
-            inputs = self._customize_inputs(inputs)
-        return AnomalibStfpm.validation_step(self, inputs, batch_idx)  # type: ignore[misc]
-
-    def test_step(
-        self,
-        inputs: AnomalyModelInputs,
-        batch_idx: int = 0,
-        **kwargs,
-    ) -> STEP_OUTPUT:
-        """Call test step of the anomalib model."""
-        if not isinstance(inputs, dict):
-            inputs = self._customize_inputs(inputs)
-        return AnomalibStfpm.test_step(self, inputs, batch_idx, **kwargs)  # type: ignore[misc]
-
-    def predict_step(
-        self,
-        inputs: AnomalyModelInputs,
-        batch_idx: int = 0,
-        **kwargs,
-    ) -> STEP_OUTPUT:
-        """Call test step of the anomalib model."""
-        if not isinstance(inputs, dict):
-            inputs = self._customize_inputs(inputs)
-        return AnomalibStfpm.predict_step(self, inputs, batch_idx, **kwargs)  # type: ignore[misc]
-
-    def forward(
-        self,
-        inputs: AnomalyModelInputs,
-    ) -> AnomalyModelOutputs:
-        """Wrap forward method of the Anomalib model."""
-        outputs = self.validation_step(inputs)
-        # TODO(Ashwin): update forward implementation to comply with other OTX models
-        _PostProcessorCallback._post_process(outputs)  # noqa: SLF001
-        _PostProcessorCallback._compute_scores_and_labels(self, outputs)  # noqa: SLF001
-        _MinMaxNormalizationCallback._normalize_batch(outputs, self)  # noqa: SLF001
-
-        return self._customize_outputs(outputs=outputs, inputs=inputs)
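
One behavioural detail worth calling out: the old constructors stored the argument as-is (`self.task = task`), while the new ones coerce it with `self.task = OTXTaskType(task)`, so a plain string and an enum member normalize to the same `OTXTaskType` member. A toy stand-in enum (the real members and values live in `otx.core.types.task` and are assumed here) illustrates the coercion:

```python
from enum import Enum


class TaskType(str, Enum):
    """Toy stand-in for otx.core.types.task.OTXTaskType (values assumed)."""

    ANOMALY_CLASSIFICATION = "ANOMALY_CLASSIFICATION"
    ANOMALY_SEGMENTATION = "ANOMALY_SEGMENTATION"


# Old behaviour: `self.task = task` kept whatever the caller passed in.
# New behaviour: `self.task = OTXTaskType(task)` normalizes it, so a raw
# string and an enum member compare (and serialize) identically.
assert TaskType("ANOMALY_SEGMENTATION") is TaskType.ANOMALY_SEGMENTATION
assert TaskType(TaskType.ANOMALY_SEGMENTATION) is TaskType.ANOMALY_SEGMENTATION
```
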