Skip to content

Commit

Permalink
Detect num_classes information for the H-label classification (#3064)
Browse files Browse the repository at this point in the history
* Initial commit

* Add the logic for auto hlabel information

* Fix precommit
  • Loading branch information
sungmanc authored Mar 13, 2024
1 parent 6b55251 commit 7b7de3c
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 14 deletions.
8 changes: 8 additions & 0 deletions src/otx/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,14 @@ def instantiate_model(self, model_config: Namespace) -> tuple:
warn(warning_msg, stacklevel=0)
model_config.init_args.num_classes = num_classes

# Hlabel classification
from otx.core.data.dataset.classification import HLabelInfo

if isinstance(self.datamodule.label_info, HLabelInfo):
hlabel_info = self.datamodule.label_info
model_config.init_args.num_multiclass_heads = hlabel_info.num_multiclass_heads
model_config.init_args.num_multilabel_classes = hlabel_info.num_multilabel_classes

# Parses the OTXModel separately to update num_classes.
model_parser = ArgumentParser()
model_parser.add_subclass_arguments(OTXModel, "model", required=False, fail_untyped=False)
Expand Down
8 changes: 8 additions & 0 deletions src/otx/engine/utils/auto_configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,14 @@ def get_model(self, model_name: str | None = None, label_info: LabelInfo | None
if label_info is not None:
num_classes = label_info.num_classes
self.config["model"]["init_args"]["num_classes"] = num_classes

from otx.core.data.dataset.classification import HLabelInfo

if isinstance(label_info, HLabelInfo):
init_args = self.config["model"]["init_args"]
init_args["num_multiclass_heads"] = label_info.num_multiclass_heads
init_args["num_multilabel_classes"] = label_info.num_multilabel_classes

logger.warning(f"Set Default Model: {self.config['model']}")
return instantiate_class(args=(), init=self.config["model"])

Expand Down
8 changes: 0 additions & 8 deletions tests/integration/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,6 @@ def test_otx_e2e(
"1" if task in ("zero_shot_visual_prompting") else "2",
*fxt_cli_override_command_per_task[task],
]
# H-Label-CLS need to add --metric
if task in ("h_label_cls"):
command_cfg.extend(["--metric.num_multiclass_heads", "2"])
command_cfg.extend(["--metric.num_multilabel_classes", "3"])

run_main(command_cfg=command_cfg, open_subprocess=fxt_open_subprocess)

Expand Down Expand Up @@ -105,10 +101,6 @@ def test_otx_e2e(
"--checkpoint",
str(ckpt_files[-1]),
]
# H-Label-CLS need to add --metric
if task in ("h_label_cls"):
command_cfg.extend(["--metric.num_multiclass_heads", "2"])
command_cfg.extend(["--metric.num_multilabel_classes", "3"])

run_main(command_cfg=command_cfg, open_subprocess=fxt_open_subprocess)

Expand Down
7 changes: 1 addition & 6 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,7 @@ def fxt_cli_override_command_per_task() -> dict:
return {
"multi_class_cls": [],
"multi_label_cls": [],
"h_label_cls": [
"--model.num_multiclass_heads",
"2",
"--model.num_multilabel_classes",
"3",
],
"h_label_cls": [],
"detection": [],
"rotated_detection": [],
"instance_segmentation": [],
Expand Down

0 comments on commit 7b7de3c

Please sign in to comment.