if metric is loss, change hpo mode to min
eunwoosh committed Apr 12, 2024
1 parent 8f23823 commit edb4389
Showing 4 changed files with 75 additions and 35 deletions.
70 changes: 46 additions & 24 deletions src/otx/engine/hpo/hpo_api.py
@@ -22,9 +22,10 @@
 from otx.utils.utils import get_decimal_point, get_using_dot_delimited_key, remove_matched_files
 
 from .hpo_trial import run_hpo_trial
-from .utils import find_trial_file, get_best_hpo_weight, get_callable_args_name, get_hpo_weight_dir
+from .utils import find_trial_file, get_best_hpo_weight, get_callable_args_name, get_hpo_weight_dir, get_metric
 
 if TYPE_CHECKING:
+    from lightning import Callback
     from lightning.pytorch.cli import OptimizerCallable
 
     from otx.engine.engine import Engine
@@ -42,18 +43,20 @@
 def execute_hpo(
     engine: Engine,
     max_epochs: int,
-    hpo_config: HpoConfig | None = None,
+    hpo_config: HpoConfig,
     progress_update_callback: Callable[[int | float], None] | None = None,
+    callbacks: list[Callback] | Callback | None = None,
     **train_args,
 ) -> tuple[dict[str, Any] | None, Path | None]:
     """Execute HPO.
 
     Args:
         engine (Engine): engine instance.
         max_epochs (int): max epochs to train.
-        hpo_config (HpoConfig | None, optional): Configuration for HPO.
+        hpo_config (HpoConfig): Configuration for HPO.
         progress_update_callback (Callable[[int | float], None] | None, optional):
             callback to update progress. If it's given, it's called with progress every second. Defaults to None.
+        callbacks (list[Callback] | Callback | None, optional): callbacks used during training. Defaults to None.
 
     Returns:
         tuple[dict[str, Any] | None, Path | None]:
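
For callers, the visible change is that hpo_config is now required and Lightning callbacks can be forwarded into the HPO trials. A hypothetical caller sketch follows; the Engine arguments and import paths are assumptions for illustration, not part of this commit:

from otx.core.config.hpo import HpoConfig  # assumed import path
from otx.engine import Engine
from otx.engine.hpo import execute_hpo  # assumed import path

engine = Engine(model="efficientnet_b0", data_root="path/to/data")  # illustrative arguments
best_hp, best_ckpt_path = execute_hpo(
    engine=engine,
    max_epochs=10,
    hpo_config=HpoConfig(metric_name="val/loss"),  # a "loss" metric forces mode="min"; see the setter below
    callbacks=None,  # without callbacks, metric_name must be set explicitly
)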
@@ -72,10 +75,11 @@ def execute_hpo(
     hpo_workdir = Path(engine.work_dir) / "hpo"
     hpo_workdir.mkdir(exist_ok=True)
     hpo_configurator = HPOConfigurator(
-        engine,
-        max_epochs,
-        hpo_workdir,
-        hpo_config,
+        engine=engine,
+        max_epochs=max_epochs,
+        hpo_config=hpo_config,
+        hpo_workdir=hpo_workdir,
+        callbacks=callbacks,
     )
     if (hpo_algo := hpo_configurator.get_hpo_algo()) is None:
         logger.warning("HPO is skipped.")
@@ -91,7 +95,8 @@ def execute_hpo(
             hpo_workdir=hpo_workdir,
             engine=engine,
             max_epochs=max_epochs,
-            metric_name=None if hpo_config is None else hpo_config.metric_name,
+            callbacks=callbacks,
+            metric_name=hpo_config.metric_name,
             **_adjust_train_args(train_args),
         ),
         "gpu" if torch.cuda.is_available() else "cpu",
@@ -118,21 +123,24 @@ class HPOConfigurator:
 
     Args:
         engine (Engine): engine instance.
-        max_epoch (int): max epochs to train.
+        max_epochs (int): max epochs to train.
+        hpo_config (HpoConfig): Configuration for HPO.
         hpo_workdir (Path | None, optional): HPO work directory. Defaults to None.
-        hpo_config (HpoConfig | None, optional): Configuration for HPO.
+        callbacks (list[Callback] | Callback | None, optional): callbacks used during training. Defaults to None.
     """
 
     def __init__(
         self,
         engine: Engine,
-        max_epoch: int,
+        max_epochs: int,
+        hpo_config: HpoConfig,
         hpo_workdir: Path | None = None,
-        hpo_config: HpoConfig | None = None,
+        callbacks: list[Callback] | Callback | None = None,
     ) -> None:
         self._engine = engine
-        self._max_epoch = max_epoch
+        self._max_epochs = max_epochs
         self._hpo_workdir = hpo_workdir if hpo_workdir is not None else Path(engine.work_dir) / "hpo"
+        self._callbacks = callbacks
         self.hpo_config: dict[str, Any] = hpo_config  # type: ignore[assignment]
 
     @property
@@ -141,26 +149,40 @@ def hpo_config(self) -> dict[str, Any]:
         return self._hpo_config
 
     @hpo_config.setter
-    def hpo_config(self, hpo_config: HpoConfig | None) -> None:
+    def hpo_config(self, hpo_config: HpoConfig) -> None:
         train_dataset_size = len(
             self._engine.datamodule.subsets[self._engine.datamodule.config.train_subset.subset_name],
         )
 
+        if hpo_config.metric_name is None:
+            if self._callbacks is None:
+                msg = (
+                    "HPOConfigurator can't find the metric because callback doesn't exist. "
+                    "Please set hpo_config.metric_name."
+                )
+                raise RuntimeError(msg)
+            hpo_config.metric_name = get_metric(self._callbacks)
+
+        if "loss" in hpo_config.metric_name and hpo_config.mode == "max":
+            logger.warning(
+                f"Because metric for HPO is {hpo_config.metric_name}, hpo_config.mode is changed from max to min.",
+            )
+            hpo_config.mode = "min"
+
         self._hpo_config: dict[str, Any] = {  # default setting
             "save_path": str(self._hpo_workdir),
-            "num_full_iterations": self._max_epoch,
+            "num_full_iterations": self._max_epochs,
             "full_dataset_size": train_dataset_size,
         }
 
-        if hpo_config is not None:
-            hb_arg_names = get_callable_args_name(HyperBand)
-            self._hpo_config.update(
-                {
-                    key: val
-                    for key, val in dataclasses.asdict(hpo_config).items()
-                    if val is not None and key in hb_arg_names
-                },
-            )
+        hb_arg_names = get_callable_args_name(HyperBand)
+        self._hpo_config.update(
+            {
+                key: val
+                for key, val in dataclasses.asdict(hpo_config).items()
+                if val is not None and key in hb_arg_names
+            },
+        )
 
         if "search_space" not in self._hpo_config:
             self._hpo_config["search_space"] = self._get_default_search_space()
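The setter above is the heart of the commit: a missing metric name is recovered from a ModelCheckpoint callback, and a "loss" metric combined with mode="max" is corrected to "min", since a loss must be minimized. A minimal self-contained sketch of that correction rule, where FakeHpoConfig is a hypothetical stand-in for the real HpoConfig dataclass:

from __future__ import annotations

import dataclasses
import logging

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class FakeHpoConfig:  # hypothetical stand-in, not the real otx HpoConfig
    metric_name: str | None = None
    mode: str = "max"


def correct_mode(cfg: FakeHpoConfig) -> FakeHpoConfig:
    # Mirrors the new setter logic: a loss-like metric is always minimized.
    if cfg.metric_name is not None and "loss" in cfg.metric_name and cfg.mode == "max":
        logger.warning("metric is %s; changing mode from max to min", cfg.metric_name)
        cfg.mode = "min"
    return cfg


assert correct_mode(FakeHpoConfig(metric_name="train/loss")).mode == "min"
assert correct_mode(FakeHpoConfig(metric_name="val/accuracy")).mode == "max"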
12 changes: 2 additions & 10 deletions src/otx/engine/hpo/hpo_trial.py
@@ -16,7 +16,7 @@
 from otx.hpo import TrialStatus
 from otx.utils.utils import find_file_recursively, remove_matched_files, set_using_dot_delimited_key
 
-from .utils import find_trial_file, get_best_hpo_weight, get_hpo_weight_dir
+from .utils import find_trial_file, get_best_hpo_weight, get_hpo_weight_dir, get_metric
 
 if TYPE_CHECKING:
     from lightning import LightningModule, Trainer
@@ -105,18 +105,10 @@ def _register_hpo_callback(
         callbacks = [callbacks]
     elif callbacks is None:
         callbacks = []
-    callbacks.append(HPOCallback(report_func, _get_metric(callbacks) if metric_name is None else metric_name))
+    callbacks.append(HPOCallback(report_func, get_metric(callbacks) if metric_name is None else metric_name))
     return callbacks
 
 
-def _get_metric(callbacks: list[Callback]) -> str:
-    for callback in callbacks:
-        if isinstance(callback, ModelCheckpoint):
-            return callback.monitor
-    error_msg = "Failed to find a metric. There is no ModelCheckpoint in callback list."
-    raise RuntimeError(error_msg)
-
-
 def _set_to_validate_every_epoch(callbacks: list[Callback], train_args: dict[str, Any]) -> None:
     for callback in callbacks:
         if isinstance(callback, AdaptiveTrainScheduling):
26 changes: 26 additions & 0 deletions src/otx/engine/hpo/utils.py
@@ -9,11 +9,15 @@
 import json
 from typing import TYPE_CHECKING, Callable
 
+from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
+
 from otx.utils.utils import find_file_recursively
 
 if TYPE_CHECKING:
     from pathlib import Path
 
+    from lightning import Callback
+
 
 def find_trial_file(hpo_workdir: Path, trial_id: str) -> Path | None:
     """Find a trial file which stores a trial record.
@@ -91,3 +95,25 @@ def get_callable_args_name(module: Callable) -> list[str]:
         list[str]: arguments name list.
     """
     return list(inspect.signature(module).parameters)
+
+
+def get_metric(callbacks: list[Callback] | Callback) -> str:
+    """Find a metric name from ModelCheckpoint callback.
+
+    Args:
+        callbacks (list[Callback] | Callback): Callback list.
+
+    Raises:
+        RuntimeError: If ModelCheckpoint doesn't exist, the error is raised.
+
+    Returns:
+        str: metric name.
+    """
+    if not isinstance(callbacks, list):
+        callbacks = [callbacks]
+
+    for callback in callbacks:
+        if isinstance(callback, ModelCheckpoint):
+            return callback.monitor
+    msg = "Failed to find a metric. There is no ModelCheckpoint in callback list."
+    raise RuntimeError(msg)
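
A short usage sketch of the relocated helper (the monitor name is illustrative; lightning must be installed):

from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint

ckpt = ModelCheckpoint(monitor="val/accuracy", mode="max")  # illustrative monitor
assert get_metric([ckpt]) == "val/accuracy"  # read from ModelCheckpoint.monitor
assert get_metric(ckpt) == "val/accuracy"    # a bare callback is wrapped into a list

try:
    get_metric([])  # no ModelCheckpoint present
except RuntimeError as err:
    print(err)  # "Failed to find a metric. There is no ModelCheckpoint in callback list."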
2 changes: 1 addition & 1 deletion tests/e2e/cli/test_cli.py
@@ -7,8 +7,8 @@
 import numpy as np
 import pytest
 import yaml
-from otx.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK
 from otx.core.types.task import OTXTaskType
+from otx.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK
 
 from tests.e2e.cli.utils import run_main