Update Label Info handling (#4127)
* Update h-cls info

* Revert h-cls head to linear one

* Cosmetic changes

* Add arrow-specific labels management logic for cls

* Update export logic

* Update label info usage

* Update unit tests

* Fix linter

* Fix unit tests

* Fix linter

* Consider multilabel scenario in h-cls

* Update dataset docstring

* Add unit tests

* Don't preprocess h-cls dataset for arrow

* Fix missing labels in multilabel training

* Revert h-cls head for effnet b0

* Update converter to pick up cls task
sovrasov authored Dec 4, 2024
1 parent cf035f6 commit c6e2952
Showing 29 changed files with 278 additions and 82 deletions.
9 changes: 3 additions & 6 deletions src/otx/algo/classification/efficientnet.py
@@ -14,7 +14,7 @@
from otx.algo.classification.backbones.efficientnet import EFFICIENTNET_VERSION, OTXEfficientNet
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
HierarchicalCBAMClsHead,
HierarchicalLinearClsHead,
LinearClsHead,
MultiLabelLinearClsHead,
SemiSLLinearClsHead,
@@ -272,11 +272,8 @@ def _build_model(self, head_config: dict) -> nn.Module:

return HLabelClassifier(
backbone=backbone,
neck=nn.Identity(),
head=HierarchicalCBAMClsHead(
in_channels=backbone.num_features,
**copied_head_config,
),
neck=GlobalAveragePooling(dim=2),
head=HierarchicalLinearClsHead(**copied_head_config, in_channels=backbone.num_features),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
)
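The same revert — swap the CBAM hierarchical head for the linear one and restore a global-average-pooling neck — is repeated for the MobileNetV3, timm, torchvision, and ViT models below. A minimal sketch of the resulting assembly, assuming the OTX classes keep the signatures visible in this diff and that the backbone exposes a `num_features` attribute:

from torch import nn
from otx.algo.classification.classifier import HLabelClassifier
from otx.algo.classification.heads import HierarchicalLinearClsHead
from otx.algo.classification.losses.asymmetric_angular_loss_with_ignore import AsymmetricAngularLossWithIgnore
from otx.algo.classification.necks.gap import GlobalAveragePooling

def build_hlabel_classifier(backbone: nn.Module, head_config: dict) -> HLabelClassifier:
    # The GAP neck replaces the nn.Identity() the CBAM head relied on;
    # the linear head then consumes the pooled feature vector directly.
    return HLabelClassifier(
        backbone=backbone,
        neck=GlobalAveragePooling(dim=2),
        head=HierarchicalLinearClsHead(**head_config, in_channels=backbone.num_features),
        multiclass_loss=nn.CrossEntropyLoss(),
        multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
    )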
10 changes: 4 additions & 6 deletions src/otx/algo/classification/mobilenet_v3.py
@@ -15,7 +15,7 @@
from otx.algo.classification.backbones import OTXMobileNetV3
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
HierarchicalCBAMClsHead,
HierarchicalLinearClsHead,
LinearClsHead,
MultiLabelNonLinearClsHead,
SemiSLLinearClsHead,
@@ -313,14 +313,12 @@ def _build_model(self, head_config: dict) -> nn.Module:

copied_head_config = copy(head_config)
copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32))
in_channels = 960 if self.mode == "large" else 576

return HLabelClassifier(
backbone=OTXMobileNetV3(mode=self.mode, input_size=self.input_size),
neck=nn.Identity(),
head=HierarchicalCBAMClsHead(
in_channels=960,
**copied_head_config,
),
neck=GlobalAveragePooling(dim=2),
head=HierarchicalLinearClsHead(**copied_head_config, in_channels=in_channels),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
)
9 changes: 3 additions & 6 deletions src/otx/algo/classification/timm_model.py
@@ -15,12 +15,12 @@
from otx.algo.classification.backbones.timm import TimmBackbone, TimmModelType
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
HierarchicalCBAMClsHead,
LinearClsHead,
MultiLabelLinearClsHead,
SemiSLLinearClsHead,
)
from otx.algo.classification.losses.asymmetric_angular_loss_with_ignore import AsymmetricAngularLossWithIgnore
from otx.algo.classification.mobilenet_v3 import HierarchicalLinearClsHead
from otx.algo.classification.necks.gap import GlobalAveragePooling
from otx.algo.classification.utils import get_classification_layers
from otx.algo.utils.support_otx_v1 import OTXv1Helper
@@ -272,11 +272,8 @@ def _build_model(self, head_config: dict) -> nn.Module:
copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32))
return HLabelClassifier(
backbone=backbone,
neck=nn.Identity(),
head=HierarchicalCBAMClsHead(
in_channels=backbone.num_features,
**copied_head_config,
),
neck=GlobalAveragePooling(dim=2),
head=HierarchicalLinearClsHead(**copied_head_config, in_channels=backbone.num_features),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
)
9 changes: 3 additions & 6 deletions src/otx/algo/classification/torchvision_model.py
@@ -14,12 +14,12 @@
from otx.algo.classification.backbones.torchvision import TorchvisionBackbone, TVModelType
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
HierarchicalCBAMClsHead,
LinearClsHead,
MultiLabelLinearClsHead,
SemiSLLinearClsHead,
)
from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore
from otx.algo.classification.mobilenet_v3 import HierarchicalLinearClsHead
from otx.algo.classification.necks.gap import GlobalAveragePooling
from otx.algo.classification.utils import get_classification_layers
from otx.core.data.entity.classification import (
@@ -315,11 +315,8 @@ def _build_model(self, head_config: dict) -> nn.Module:
backbone = TorchvisionBackbone(backbone=self.backbone, pretrained=self.pretrained)
return HLabelClassifier(
backbone=backbone,
neck=nn.Identity(),
head=HierarchicalCBAMClsHead(
in_channels=backbone.in_features,
**head_config,
),
neck=GlobalAveragePooling(dim=2),
head=HierarchicalLinearClsHead(**head_config, in_channels=backbone.in_features),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
)
8 changes: 2 additions & 6 deletions src/otx/algo/classification/vit.py
@@ -19,12 +19,12 @@
from otx.algo.classification.backbones.vision_transformer import VIT_ARCH_TYPE, VisionTransformer
from otx.algo.classification.classifier import HLabelClassifier, ImageClassifier, SemiSLClassifier
from otx.algo.classification.heads import (
HierarchicalCBAMClsHead,
MultiLabelLinearClsHead,
SemiSLVisionTransformerClsHead,
VisionTransformerClsHead,
)
from otx.algo.classification.losses import AsymmetricAngularLossWithIgnore
from otx.algo.classification.mobilenet_v3 import HierarchicalLinearClsHead
from otx.algo.classification.utils import get_classification_layers
from otx.algo.explain.explain_algo import ViTReciproCAM, feature_vector_fn
from otx.algo.utils.support_otx_v1 import OTXv1Helper
@@ -466,11 +466,7 @@ def _build_model(self, head_config: dict) -> nn.Module:
return HLabelClassifier(
backbone=vit_backbone,
neck=None,
head=HierarchicalCBAMClsHead(
in_channels=vit_backbone.embed_dim,
step_size=1,
**head_config,
),
head=HierarchicalLinearClsHead(**head_config, in_channels=vit_backbone.embed_dim),
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
init_cfg=init_cfg,
1 change: 1 addition & 0 deletions src/otx/core/data/dataset/action_classification.py
@@ -37,6 +37,7 @@ def __init__(
image_color_channel: ImageColorChannel = ImageColorChannel.BGR,
stack_images: bool = True,
to_tv_image: bool = True,
data_format: str = "",
) -> None:
super().__init__(
dm_subset,
1 change: 1 addition & 0 deletions src/otx/core/data/dataset/anomaly.py
@@ -57,6 +57,7 @@ def __init__(
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
stack_images: bool = True,
to_tv_image: bool = True,
data_format: str = "",
) -> None:
self.task_type = task_type
super().__init__(
7 changes: 6 additions & 1 deletion src/otx/core/data/dataset/base.py
@@ -70,6 +70,7 @@ class OTXDataset(Dataset, Generic[T_OTXDataEntity]):
max_refetch: Maximum number of images to fetch in cache
image_color_channel: Color channel of images
stack_images: Whether or not to stack images in collate function in OTXBatchData entity.
data_format: Source data format, which was originally passed to datumaro (could be arrow for instance).
"""

@@ -83,6 +84,7 @@ def __init__(
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
stack_images: bool = True,
to_tv_image: bool = True,
data_format: str = "",
) -> None:
self.dm_subset = dm_subset
self.transforms = transforms
@@ -92,8 +94,11 @@ def __init__(
self.image_color_channel = image_color_channel
self.stack_images = stack_images
self.to_tv_image = to_tv_image
self.data_format = data_format

if self.dm_subset.categories():
if self.dm_subset.categories() and data_format == "arrow":
self.label_info = LabelInfo.from_dm_label_groups_arrow(self.dm_subset.categories()[AnnotationType.label])
elif self.dm_subset.categories():
self.label_info = LabelInfo.from_dm_label_groups(self.dm_subset.categories()[AnnotationType.label])
else:
self.label_info = NullLabelInfo()
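Condensed, the new selection logic looks as follows — a sketch only, assuming `LabelInfo` and `NullLabelInfo` live in `otx.core.types.label` and that `from_dm_label_groups_arrow` is the arrow-aware constructor this commit introduces:

from datumaro.components.annotation import AnnotationType
from otx.core.types.label import LabelInfo, NullLabelInfo

def resolve_label_info(dm_subset, data_format: str) -> LabelInfo:
    categories = dm_subset.categories()
    if not categories:
        return NullLabelInfo()
    label_categories = categories[AnnotationType.label]
    if data_format == "arrow":
        # Arrow exports store label IDs where names normally go, so they need a dedicated parser.
        return LabelInfo.from_dm_label_groups_arrow(label_categories)
    return LabelInfo.from_dm_label_groups(label_categories)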
64 changes: 42 additions & 22 deletions src/otx/core/data/dataset/classification.py
@@ -80,17 +80,21 @@ def _get_item_impl(self, index: int) -> MultilabelClsDataEntity | None:
ignored_labels: list[int] = [] # This should be assigned form item
img_data, img_shape, _ = self._get_img_data_and_shape(img)

label_anns = []
label_ids = set()
for ann in item.annotations:
# multilabel information stored in 'multi_label_ids' attribute when the source format is arrow
if "multi_label_ids" in ann.attributes:
for lbl_idx in ann.attributes["multi_label_ids"]:
label_ids.add(lbl_idx)

if isinstance(ann, Label):
label_anns.append(ann)
label_ids.add(ann.label)
else:
# If the annotation is not Label, it should be converted to Label.
# For Chained Task: Detection (Bbox) -> Classification (Label)
label = Label(label=ann.label)
if label not in label_anns:
label_anns.append(label)
labels = torch.as_tensor([ann.label for ann in label_anns])
label_ids.add(label.label)
labels = torch.as_tensor(list(label_ids))

entity = MultilabelClsDataEntity(
image=img_data,
@@ -128,13 +132,22 @@ def __init__(self, **kwargs) -> None:
self.dm_categories = self.dm_subset.categories()[AnnotationType.label]

# Hlabel classification used HLabelInfo to insert the HLabelData.
self.label_info = HLabelInfo.from_dm_label_groups(self.dm_categories)
if self.data_format == "arrow":
# arrow format stores label IDs as names, have to deal with that here
self.label_info = HLabelInfo.from_dm_label_groups_arrow(self.dm_categories)
else:
self.label_info = HLabelInfo.from_dm_label_groups(self.dm_categories)

self.id_to_name_mapping = dict(zip(self.label_info.label_ids, self.label_info.label_names))
self.id_to_name_mapping[""] = ""

if self.label_info.num_multiclass_heads == 0:
msg = "The number of multiclass heads should be larger than 0."
raise ValueError(msg)

for dm_item in self.dm_subset:
self._add_ancestors(dm_item.annotations)
if self.data_format != "arrow":
for dm_item in self.dm_subset:
self._add_ancestors(dm_item.annotations)

def _add_ancestors(self, label_anns: list[Label]) -> None:
"""Add ancestors recursively if some label miss the ancestor information.
@@ -149,14 +162,16 @@ def _add_ancestors(self, label_anns: list[Label]) -> None:
"""

def _label_idx_to_name(idx: int) -> str:
return self.label_info.label_names[idx]
return self.dm_categories[idx].name

def _label_name_to_idx(name: str) -> int:
indices = [idx for idx, val in enumerate(self.label_info.label_names) if val == name]
return indices[0]

def _get_label_group_idx(label_name: str) -> int:
if isinstance(self.label_info, HLabelInfo):
if self.data_format == "arrow":
return self.label_info.class_to_group_idx[self.id_to_name_mapping[label_name]][0]
return self.label_info.class_to_group_idx[label_name][0]
msg = f"self.label_info should have HLabelInfo type, got {type(self.label_info)}"
raise ValueError(msg)
@@ -197,17 +212,22 @@ def _get_item_impl(self, index: int) -> HlabelClsDataEntity | None:
ignored_labels: list[int] = [] # This should be assigned form item
img_data, img_shape, _ = self._get_img_data_and_shape(img)

label_anns = []
label_ids = set()
for ann in item.annotations:
# in h-cls scenario multilabel information stored in 'multi_label_ids' attribute
if "multi_label_ids" in ann.attributes:
for lbl_idx in ann.attributes["multi_label_ids"]:
label_ids.add(lbl_idx)

if isinstance(ann, Label):
label_anns.append(ann)
label_ids.add(ann.label)
else:
# If the annotation is not Label, it should be converted to Label.
# For Chained Task: Detection (Bbox) -> Classification (Label)
label = Label(label=ann.label)
if label not in label_anns:
label_anns.append(label)
hlabel_labels = self._convert_label_to_hlabel_format(label_anns, ignored_labels)
label_ids.add(label.label)

hlabel_labels = self._convert_label_to_hlabel_format([Label(label=idx) for idx in label_ids], ignored_labels)

entity = HlabelClsDataEntity(
image=img_data,
@@ -256,18 +276,18 @@ def _convert_label_to_hlabel_format(self, label_anns: list[Label], ignored_label
class_indices[i] = -1

for ann in label_anns:
ann_name = self.dm_categories.items[ann.label].name
ann_parent = self.dm_categories.items[ann.label].parent
if self.data_format == "arrow":
# skips unknown labels for instance, the empty one
if self.dm_categories.items[ann.label].name not in self.id_to_name_mapping:
continue
ann_name = self.id_to_name_mapping[self.dm_categories.items[ann.label].name]
else:
ann_name = self.dm_categories.items[ann.label].name
group_idx, in_group_idx = self.label_info.class_to_group_idx[ann_name]
(parent_group_idx, parent_in_group_idx) = (
self.label_info.class_to_group_idx[ann_parent] if ann_parent else (None, None)
)

if group_idx < num_multiclass_heads:
class_indices[group_idx] = in_group_idx
if parent_group_idx is not None and parent_in_group_idx is not None:
class_indices[parent_group_idx] = parent_in_group_idx
elif not ignored_labels or ann.label not in ignored_labels:
elif ann.label not in ignored_labels:
class_indices[num_multiclass_heads + in_group_idx] = 1
else:
class_indices[num_multiclass_heads + in_group_idx] = -1
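For orientation, here is a simplified, self-contained sketch of the target layout `_convert_label_to_hlabel_format` produces: one in-group index (or -1) per multiclass head, followed by a 0/1/-1 slot per multilabel class. The real method additionally resolves names through the datumaro categories and, for arrow, through `id_to_name_mapping`; the plain dict lookup below is an illustrative stand-in:

def to_hlabel_format(
    label_names: list[str],
    class_to_group_idx: dict[str, tuple[int, int]],
    num_multiclass_heads: int,
    num_multilabel_classes: int,
    ignored_labels: set[str] | None = None,
) -> list[int]:
    ignored_labels = ignored_labels or set()
    # -1 means "no label for this multiclass head"; multilabel slots start at 0.
    class_indices = [-1] * num_multiclass_heads + [0] * num_multilabel_classes
    for name in label_names:
        group_idx, in_group_idx = class_to_group_idx[name]
        if group_idx < num_multiclass_heads:
            class_indices[group_idx] = in_group_idx                   # multiclass head target
        elif name not in ignored_labels:
            class_indices[num_multiclass_heads + in_group_idx] = 1    # positive multilabel class
        else:
            class_indices[num_multiclass_heads + in_group_idx] = -1   # ignored multilabel class
    return class_indices

With two multiclass heads and two multilabel classes, for example, a single label mapping to (group 1, index 2) yields [-1, 2, 0, 0].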
4 changes: 3 additions & 1 deletion src/otx/core/data/dataset/keypoint_detection.py
@@ -54,9 +54,11 @@ def __init__(
self.dm_subset = self._get_single_bbox_dataset(dm_subset)

if self.dm_subset.categories():
kp_labels = self.dm_subset.categories()[AnnotationType.points][0].labels
self.label_info = LabelInfo(
label_names=self.dm_subset.categories()[AnnotationType.points][0].labels,
label_names=kp_labels,
label_groups=[],
label_ids=[str(i) for i in range(len(kp_labels))],
)
else:
self.label_info = NullLabelInfo()
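A minimal illustration of the keypoint label info built here (the keypoint names are hypothetical; `LabelInfo` is assumed to accept these keyword arguments, as the diff implies):

from otx.core.types.label import LabelInfo

kp_labels = ["nose", "left_eye", "right_eye"]  # hypothetical keypoint names
label_info = LabelInfo(
    label_names=kp_labels,
    label_groups=[],
    label_ids=[str(i) for i in range(len(kp_labels))],  # new in this commit: string IDs, one per keypoint
)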
2 changes: 2 additions & 0 deletions src/otx/core/data/dataset/segmentation.py
@@ -167,6 +167,7 @@ def __init__(
stack_images: bool = True,
to_tv_image: bool = True,
ignore_index: int = 255,
data_format: str = "",
) -> None:
super().__init__(
dm_subset,
@@ -187,6 +188,7 @@ def __init__(
label_names=self.label_info.label_names,
label_groups=self.label_info.label_groups,
ignore_index=ignore_index,
label_ids=self.label_info.label_ids,
)
self.ignore_index = ignore_index

2 changes: 2 additions & 0 deletions src/otx/core/data/factory.py
@@ -73,6 +73,7 @@ def create( # noqa: PLR0911
dm_subset: DmDataset,
cfg_subset: SubsetConfig,
mem_cache_handler: MemCacheHandlerBase,
data_format: str,
mem_cache_img_max_size: tuple[int, int] | None = None,
image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
stack_images: bool = True,
@@ -85,6 +86,7 @@
common_kwargs = {
"dm_subset": dm_subset,
"transforms": transforms,
"data_format": data_format,
"mem_cache_handler": mem_cache_handler,
"mem_cache_img_max_size": mem_cache_img_max_size,
"image_color_channel": image_color_channel,
10 changes: 3 additions & 7 deletions src/otx/core/data/module.py
@@ -107,13 +107,6 @@ def __init__( # noqa: PLR0913
self.subsets: dict[str, OTXDataset] = {}
self.save_hyperparameters(ignore=["input_size"])

# TODO (Jaeguk): This is workaround for a bug in Datumaro.
# These lines should be removed after next datumaro release.
# https://github.com/openvinotoolkit/datumaro/pull/1223/files
from datumaro.plugins.data_formats.video import VIDEO_EXTENSIONS

VIDEO_EXTENSIONS.append(".mp4")

dataset = DmDataset.import_from(self.data_root, format=self.data_format)
if self.task != "H_LABEL_CLS":
dataset = pre_filtering(
@@ -193,6 +186,7 @@ def __init__( # noqa: PLR0913
dm_subset=dm_subset.as_dataset(),
cfg_subset=config_mapping[name],
mem_cache_handler=mem_cache_handler,
data_format=self.data_format,
mem_cache_img_max_size=mem_cache_img_max_size,
image_color_channel=image_color_channel,
stack_images=stack_images,
@@ -237,6 +231,7 @@ def __init__( # noqa: PLR0913
include_polygons=include_polygons,
ignore_index=ignore_index,
vpm_config=vpm_config,
data_format=self.data_format,
)
self.subsets[transform_key] = unlabeled_dataset
else:
@@ -251,6 +246,7 @@
include_polygons=include_polygons,
ignore_index=ignore_index,
vpm_config=vpm_config,
data_format=self.data_format,
)
self.subsets[name] = unlabeled_dataset

2 changes: 1 addition & 1 deletion src/otx/core/data/pre_filtering.py
@@ -84,7 +84,7 @@ def remove_unused_labels(dataset: DmDataset, data_format: str, ignore_index: int
used_labels = [0, *used_labels]
if data_format == "common_semantic_segmentation_with_subset_dirs" and len(original_categories) < len(used_labels):
msg = (
"There are labeles mismatch in dataset categories and actuall categories comes from semantic masks."
"There are labels mismatch in dataset categories and actual categories comes from semantic masks."
"Please, check `dataset_meta.json` file."
)
raise ValueError(msg)