From 42f53d5209bbe418328ee751db8d46e37f4b4bd1 Mon Sep 17 00:00:00 2001 From: brandon Date: Wed, 4 Dec 2024 00:57:36 +0000 Subject: [PATCH] Adds sdk label support for multiclass detectors --- src/groundlight/binary_labels.py | 22 ---------- src/groundlight/client.py | 41 +++++++++++++------ src/groundlight/experimental_api.py | 62 +---------------------------- test/unit/test_labels.py | 56 ++++++++++++++++++++++++++ 4 files changed, 86 insertions(+), 95 deletions(-) create mode 100644 test/unit/test_labels.py diff --git a/src/groundlight/binary_labels.py b/src/groundlight/binary_labels.py index c1d20470..557a4245 100644 --- a/src/groundlight/binary_labels.py +++ b/src/groundlight/binary_labels.py @@ -53,25 +53,3 @@ def convert_internal_label_to_display( logger.warning(f"Unrecognized internal label {label} - leaving it alone as a string.") return label - - -def convert_display_label_to_internal( - context: Union[ImageQuery, Detector, str], # pylint: disable=unused-argument - label: Union[Label, str], -) -> str: - """Convert a label that comes from the user into the label string that we send to the server. We - are strict here, and only allow YES/NO. - - NOTE: We accept case-insensitive label strings from the user, but we send UPPERCASE labels to - the server. E.g., user inputs "yes" -> the label is returned as "YES". - """ - # NOTE: In the future we should validate against actually supported labels for the detector - if not isinstance(label, str): - raise ValueError(f"Expected a string label, but got {label} of type {type(label)}") - upper = label.upper() - if upper == Label.YES: - return DeprecatedLabel.PASS.value - if upper == Label.NO: - return DeprecatedLabel.FAIL.value - - raise ValueError(f"Invalid label string '{label}'. Must be one of '{Label.YES.value}','{Label.NO.value}'.") diff --git a/src/groundlight/client.py b/src/groundlight/client.py index 14ab15af..14bb9a2d 100644 --- a/src/groundlight/client.py +++ b/src/groundlight/client.py @@ -13,9 +13,11 @@ from groundlight_openapi_client.api.labels_api import LabelsApi from groundlight_openapi_client.api.user_api import UserApi from groundlight_openapi_client.exceptions import NotFoundException, UnauthorizedException +from groundlight_openapi_client.model.b_box_geometry_request import BBoxGeometryRequest from groundlight_openapi_client.model.detector_creation_input_request import DetectorCreationInputRequest from groundlight_openapi_client.model.label_value_request import LabelValueRequest from groundlight_openapi_client.model.patched_detector_request import PatchedDetectorRequest +from groundlight_openapi_client.model.roi_request import ROIRequest from model import ( ROI, BinaryClassificationResult, @@ -26,7 +28,7 @@ ) from urllib3.exceptions import InsecureRequestWarning -from groundlight.binary_labels import Label, convert_display_label_to_internal, convert_internal_label_to_display +from groundlight.binary_labels import Label, convert_internal_label_to_display from groundlight.config import API_TOKEN_MISSING_HELP_MESSAGE, API_TOKEN_VARIABLE_NAME, DISABLE_TLS_VARIABLE_NAME from groundlight.encodings import url_encode_dict from groundlight.images import ByteStreamWrapper, parse_supported_image_types @@ -1066,8 +1068,9 @@ def _wait_for_result( image_query = self._fixup_image_query(image_query) return image_query + # pylint: disable=duplicate-code def add_label( - self, image_query: Union[ImageQuery, str], label: Union[Label, str], rois: Union[List[ROI], str, None] = None + self, image_query: Union[ImageQuery, str], label: Union[Label, int, str], rois: Union[List[ROI], str, None] = None ): """ Provide a new label (annotation) for an image query. This is used to provide ground-truth labels @@ -1075,7 +1078,7 @@ def add_label( **Example usage**:: - gl = Groundlight() + gl = ExperimentalApi() # Using an ImageQuery object image_query = gl.ask_ml(detector_id, image_data) @@ -1088,27 +1091,41 @@ def add_label( rois = [ROI(x=100, y=100, width=50, height=50)] gl.add_label(image_query, "YES", rois=rois) - :param image_query: Either an ImageQuery object (returned from methods like :meth:`ask_ml`) or an image query ID - string starting with "iq_". - :param label: The label value to assign, typically "YES" or "NO" for binary classification detectors. - For multi-class detectors, use one of the defined class names. - :param rois: Optional list of ROI objects defining regions of interest in the image. - Each ROI specifies a bounding box with x, y coordinates and width, height. + :param image_query: Either an ImageQuery object (returned from methods like + `ask_ml`) or an image query ID string starting with "iq_". + + :param label: The label value to assign, typically "YES" or "NO" for binary + classification detectors. For multi-class detectors, use one of + the defined class names. + + :param rois: Optional list of ROI objects defining regions of interest in the + image. Each ROI specifies a bounding box with x, y coordinates + and width, height. :return: None """ if isinstance(rois, str): raise TypeError("rois must be a list of ROI objects. CLI support is not implemented") + if isinstance(label, int): + label = str(label) if isinstance(image_query, ImageQuery): image_query_id = image_query.id else: image_query_id = str(image_query) # Some old imagequery id's started with "chk_" + # TODO: handle iqe_ for image_queries returned from edge endpoints if not image_query_id.startswith(("chk_", "iq_")): raise ValueError(f"Invalid image query id {image_query_id}") - api_label = convert_display_label_to_internal(image_query_id, label) - rois_json = [roi.dict() for roi in rois] if rois else None - request_params = LabelValueRequest(label=api_label, image_query_id=image_query_id, rois=rois_json) + geometry_requests = [BBoxGeometryRequest(**roi.geometry.dict()) for roi in rois] if rois else None + roi_requests = ( + [ + ROIRequest(label=roi.label, score=roi.score, geometry=geometry) + for roi, geometry in zip(rois, geometry_requests) + ] + if rois and geometry_requests + else None + ) + request_params = LabelValueRequest(label=label, image_query_id=image_query_id, rois=roi_requests) self.labels_api.create_label(request_params) def start_inspection(self) -> str: diff --git a/src/groundlight/experimental_api.py b/src/groundlight/experimental_api.py index aac00fa5..1a8c34cd 100644 --- a/src/groundlight/experimental_api.py +++ b/src/groundlight/experimental_api.py @@ -32,7 +32,7 @@ from groundlight_openapi_client.model.verb_enum import VerbEnum from model import ROI, BBoxGeometry, Detector, DetectorGroup, ImageQuery, ModeEnum, PaginatedRuleList, Rule -from groundlight.binary_labels import Label, convert_display_label_to_internal +from groundlight.binary_labels import Label from groundlight.images import parse_supported_image_types from groundlight.optional_imports import Image, np @@ -499,66 +499,6 @@ def create_roi(self, label: str, top_left: Tuple[float, float], bottom_right: Tu ), ) - # TODO: remove duplicate method on subclass - # pylint: disable=duplicate-code - def add_label( - self, image_query: Union[ImageQuery, str], label: Union[Label, str], rois: Union[List[ROI], str, None] = None - ): - """ - Provide a new label (annotation) for an image query. This is used to provide ground-truth labels - for training detectors, or to correct the results of detectors. - - **Example usage**:: - - gl = ExperimentalApi() - - # Using an ImageQuery object - image_query = gl.ask_ml(detector_id, image_data) - gl.add_label(image_query, "YES") - - # Using an image query ID string directly - gl.add_label("iq_abc123", "NO") - - # With regions of interest (ROIs) - rois = [ROI(x=100, y=100, width=50, height=50)] - gl.add_label(image_query, "YES", rois=rois) - - :param image_query: Either an ImageQuery object (returned from methods like - `ask_ml`) or an image query ID string starting with "iq_". - - :param label: The label value to assign, typically "YES" or "NO" for binary - classification detectors. For multi-class detectors, use one of - the defined class names. - - :param rois: Optional list of ROI objects defining regions of interest in the - image. Each ROI specifies a bounding box with x, y coordinates - and width, height. - - :return: None - """ - if isinstance(rois, str): - raise TypeError("rois must be a list of ROI objects. CLI support is not implemented") - if isinstance(image_query, ImageQuery): - image_query_id = image_query.id - else: - image_query_id = str(image_query) - # Some old imagequery id's started with "chk_" - # TODO: handle iqe_ for image_queries returned from edge endpoints - if not image_query_id.startswith(("chk_", "iq_")): - raise ValueError(f"Invalid image query id {image_query_id}") - api_label = convert_display_label_to_internal(image_query_id, label) - geometry_requests = [BBoxGeometryRequest(**roi.geometry.dict()) for roi in rois] if rois else None - roi_requests = ( - [ - ROIRequest(label=roi.label, score=roi.score, geometry=geometry) - for roi, geometry in zip(rois, geometry_requests) - ] - if rois and geometry_requests - else None - ) - request_params = LabelValueRequest(label=api_label, image_query_id=image_query_id, rois=roi_requests) - self.labels_api.create_label(request_params) - def reset_detector(self, detector: Union[str, Detector]) -> None: """ Removes all image queries and training data for the given detector. This effectively resets diff --git a/test/unit/test_labels.py b/test/unit/test_labels.py new file mode 100644 index 00000000..474c6ae0 --- /dev/null +++ b/test/unit/test_labels.py @@ -0,0 +1,56 @@ +from datetime import datetime + +import pytest +from groundlight import ExperimentalApi, ApiException + + +def test_binary_labels(gl_experimental: ExperimentalApi): + name = f"Test binary labels{datetime.utcnow()}" + det = gl_experimental.create_detector(name, "test_query") + iq1 = gl_experimental.submit_image_query(det, "test/assets/cat.jpeg") + gl_experimental.add_label(iq1, "YES") + iq1 = gl_experimental.get_image_query(iq1.id) + assert iq1.result.label == "YES" + gl_experimental.add_label(iq1, "NO") + iq1 = gl_experimental.get_image_query(iq1.id) + assert iq1.result.label == "NO" + gl_experimental.add_label(iq1, "UNCLEAR") + iq1 = gl_experimental.get_image_query(iq1.id) + assert iq1.result.label == "UNCLEAR" + with pytest.raises(ApiException) as _: + gl_experimental.add_label(iq1, "MAYBE") + +def test_counting_labels(gl_experimental: ExperimentalApi): + name = f"Test binary labels{datetime.utcnow()}" + det = gl_experimental.create_counting_detector(name, "test_query") + iq1 = gl_experimental.submit_image_query(det, "test/assets/cat.jpeg") + gl_experimental.add_label(iq1, 0) + iq1 = gl_experimental.get_image_query(iq1.id) + assert iq1.result.count == 0 + gl_experimental.add_label(iq1, 5) + iq1 = gl_experimental.get_image_query(iq1.id) + assert iq1.result.count == 5 + with pytest.raises(ApiException) as _: + gl_experimental.add_label(iq1, "MAYBE") + with pytest.raises(ApiException) as _: + gl_experimental.add_label(iq1, -999) + +def test_multiclass_labels(gl_experimental: ExperimentalApi): + name = f"Test binary labels{datetime.utcnow()}" + det = gl_experimental.create_multiclass_detector(name, "test_query", class_names=["apple", "banana", "cherry"]) + iq1 = gl_experimental.submit_image_query(det, "test/assets/cat.jpeg") + gl_experimental.add_label(iq1, "apple") + iq1 = gl_experimental.get_image_query(iq1.id) + assert iq1.result.label == "apple" + gl_experimental.add_label(iq1, "banana") + iq1 = gl_experimental.get_image_query(iq1.id) + assert iq1.result.label == "banana" + gl_experimental.add_label(iq1, "cherry") + iq1 = gl_experimental.get_image_query(iq1.id) + assert iq1.result.label == "cherry" + # You can submit the index of the class as well + gl_experimental.add_label(iq1, 2) + iq1 = gl_experimental.get_image_query(iq1.id) + assert iq1.result.label == "cherry" + with pytest.raises(ApiException) as _: + gl_experimental.add_label(iq1, "MAYBE")