diff --git a/src/otx/algorithms/classification/utils/cls_utils.py b/src/otx/algorithms/classification/utils/cls_utils.py index 665b9c15922..6ba33c0996b 100644 --- a/src/otx/algorithms/classification/utils/cls_utils.py +++ b/src/otx/algorithms/classification/utils/cls_utils.py @@ -26,12 +26,15 @@ from otx.api.utils.labels_utils import get_normalized_label_name -def get_multihead_class_info(label_schema: LabelSchemaEntity): # pylint: disable=too-many-locals +def get_multihead_class_info(label_schema: LabelSchemaEntity, normalize_labels=False): # pylint: disable=too-many-locals """Get multihead info by label schema.""" all_groups = label_schema.get_groups(include_empty=False) all_groups_str = [] for g in all_groups: - group_labels_str = [get_normalized_label_name(lbl) for lbl in g.labels] + if normalize_labels: + group_labels_str = [get_normalized_label_name(lbl) for lbl in g.labels] + else: + group_labels_str = [lbl.name for lbl in g.labels] all_groups_str.append(group_labels_str) single_label_groups = [g for g in all_groups_str if len(g) == 1] @@ -77,7 +80,7 @@ def get_cls_inferencer_configuration(label_schema: LabelSchemaEntity): hierarchical = not multilabel and len(label_schema.get_groups(False)) > 1 multihead_class_info = {} if hierarchical: - multihead_class_info = get_multihead_class_info(label_schema) + multihead_class_info = get_multihead_class_info(label_schema, normalize_labels=True) return { "multilabel": multilabel, "hierarchical": hierarchical, @@ -120,7 +123,7 @@ def get_cls_model_api_configuration(label_schema: LabelSchemaEntity, inference_c mapi_config[("model_info", "label_ids")] = all_label_ids.strip() hierarchical_config = {} - hierarchical_config["cls_heads_info"] = get_multihead_class_info(label_schema) + hierarchical_config["cls_heads_info"] = get_multihead_class_info(label_schema, normalize_labels=True) hierarchical_config["label_tree_edges"] = [] for edge in label_schema.label_tree.edges: # (child, parent) hierarchical_config["label_tree_edges"].append(