Skip to content

Commit

Permalink
Adds sdk label support for multiclass detectors
Browse files Browse the repository at this point in the history
  • Loading branch information
brandon-groundlight committed Dec 4, 2024
1 parent 5e9da70 commit 42f53d5
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 95 deletions.
22 changes: 0 additions & 22 deletions src/groundlight/binary_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,25 +53,3 @@ def convert_internal_label_to_display(

logger.warning(f"Unrecognized internal label {label} - leaving it alone as a string.")
return label


def convert_display_label_to_internal(
context: Union[ImageQuery, Detector, str], # pylint: disable=unused-argument
label: Union[Label, str],
) -> str:
"""Convert a label that comes from the user into the label string that we send to the server. We
are strict here, and only allow YES/NO.
NOTE: We accept case-insensitive label strings from the user, but we send UPPERCASE labels to
the server. E.g., user inputs "yes" -> the label is returned as "YES".
"""
# NOTE: In the future we should validate against actually supported labels for the detector
if not isinstance(label, str):
raise ValueError(f"Expected a string label, but got {label} of type {type(label)}")
upper = label.upper()
if upper == Label.YES:
return DeprecatedLabel.PASS.value
if upper == Label.NO:
return DeprecatedLabel.FAIL.value

raise ValueError(f"Invalid label string '{label}'. Must be one of '{Label.YES.value}','{Label.NO.value}'.")
41 changes: 29 additions & 12 deletions src/groundlight/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
from groundlight_openapi_client.api.labels_api import LabelsApi
from groundlight_openapi_client.api.user_api import UserApi
from groundlight_openapi_client.exceptions import NotFoundException, UnauthorizedException
from groundlight_openapi_client.model.b_box_geometry_request import BBoxGeometryRequest
from groundlight_openapi_client.model.detector_creation_input_request import DetectorCreationInputRequest
from groundlight_openapi_client.model.label_value_request import LabelValueRequest
from groundlight_openapi_client.model.patched_detector_request import PatchedDetectorRequest
from groundlight_openapi_client.model.roi_request import ROIRequest
from model import (
ROI,
BinaryClassificationResult,
Expand All @@ -26,7 +28,7 @@
)
from urllib3.exceptions import InsecureRequestWarning

from groundlight.binary_labels import Label, convert_display_label_to_internal, convert_internal_label_to_display
from groundlight.binary_labels import Label, convert_internal_label_to_display
from groundlight.config import API_TOKEN_MISSING_HELP_MESSAGE, API_TOKEN_VARIABLE_NAME, DISABLE_TLS_VARIABLE_NAME
from groundlight.encodings import url_encode_dict
from groundlight.images import ByteStreamWrapper, parse_supported_image_types
Expand Down Expand Up @@ -1066,16 +1068,17 @@ def _wait_for_result(
image_query = self._fixup_image_query(image_query)
return image_query

# pylint: disable=duplicate-code
def add_label(
self, image_query: Union[ImageQuery, str], label: Union[Label, str], rois: Union[List[ROI], str, None] = None
self, image_query: Union[ImageQuery, str], label: Union[Label, int, str], rois: Union[List[ROI], str, None] = None
):
"""
Provide a new label (annotation) for an image query. This is used to provide ground-truth labels
for training detectors, or to correct the results of detectors.
**Example usage**::
gl = Groundlight()
gl = ExperimentalApi()
# Using an ImageQuery object
image_query = gl.ask_ml(detector_id, image_data)
Expand All @@ -1088,27 +1091,41 @@ def add_label(
rois = [ROI(x=100, y=100, width=50, height=50)]
gl.add_label(image_query, "YES", rois=rois)
:param image_query: Either an ImageQuery object (returned from methods like :meth:`ask_ml`) or an image query ID
string starting with "iq_".
:param label: The label value to assign, typically "YES" or "NO" for binary classification detectors.
For multi-class detectors, use one of the defined class names.
:param rois: Optional list of ROI objects defining regions of interest in the image.
Each ROI specifies a bounding box with x, y coordinates and width, height.
:param image_query: Either an ImageQuery object (returned from methods like
`ask_ml`) or an image query ID string starting with "iq_".
:param label: The label value to assign, typically "YES" or "NO" for binary
classification detectors. For multi-class detectors, use one of
the defined class names.
:param rois: Optional list of ROI objects defining regions of interest in the
image. Each ROI specifies a bounding box with x, y coordinates
and width, height.
:return: None
"""
if isinstance(rois, str):
raise TypeError("rois must be a list of ROI objects. CLI support is not implemented")
if isinstance(label, int):
label = str(label)
if isinstance(image_query, ImageQuery):
image_query_id = image_query.id
else:
image_query_id = str(image_query)
# Some old imagequery id's started with "chk_"
# TODO: handle iqe_ for image_queries returned from edge endpoints
if not image_query_id.startswith(("chk_", "iq_")):
raise ValueError(f"Invalid image query id {image_query_id}")
api_label = convert_display_label_to_internal(image_query_id, label)
rois_json = [roi.dict() for roi in rois] if rois else None
request_params = LabelValueRequest(label=api_label, image_query_id=image_query_id, rois=rois_json)
geometry_requests = [BBoxGeometryRequest(**roi.geometry.dict()) for roi in rois] if rois else None
roi_requests = (
[
ROIRequest(label=roi.label, score=roi.score, geometry=geometry)
for roi, geometry in zip(rois, geometry_requests)
]
if rois and geometry_requests
else None
)
request_params = LabelValueRequest(label=label, image_query_id=image_query_id, rois=roi_requests)
self.labels_api.create_label(request_params)

def start_inspection(self) -> str:
Expand Down
62 changes: 1 addition & 61 deletions src/groundlight/experimental_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from groundlight_openapi_client.model.verb_enum import VerbEnum
from model import ROI, BBoxGeometry, Detector, DetectorGroup, ImageQuery, ModeEnum, PaginatedRuleList, Rule

from groundlight.binary_labels import Label, convert_display_label_to_internal
from groundlight.binary_labels import Label
from groundlight.images import parse_supported_image_types
from groundlight.optional_imports import Image, np

Expand Down Expand Up @@ -499,66 +499,6 @@ def create_roi(self, label: str, top_left: Tuple[float, float], bottom_right: Tu
),
)

# TODO: remove duplicate method on subclass
# pylint: disable=duplicate-code
def add_label(
self, image_query: Union[ImageQuery, str], label: Union[Label, str], rois: Union[List[ROI], str, None] = None
):
"""
Provide a new label (annotation) for an image query. This is used to provide ground-truth labels
for training detectors, or to correct the results of detectors.
**Example usage**::
gl = ExperimentalApi()
# Using an ImageQuery object
image_query = gl.ask_ml(detector_id, image_data)
gl.add_label(image_query, "YES")
# Using an image query ID string directly
gl.add_label("iq_abc123", "NO")
# With regions of interest (ROIs)
rois = [ROI(x=100, y=100, width=50, height=50)]
gl.add_label(image_query, "YES", rois=rois)
:param image_query: Either an ImageQuery object (returned from methods like
`ask_ml`) or an image query ID string starting with "iq_".
:param label: The label value to assign, typically "YES" or "NO" for binary
classification detectors. For multi-class detectors, use one of
the defined class names.
:param rois: Optional list of ROI objects defining regions of interest in the
image. Each ROI specifies a bounding box with x, y coordinates
and width, height.
:return: None
"""
if isinstance(rois, str):
raise TypeError("rois must be a list of ROI objects. CLI support is not implemented")
if isinstance(image_query, ImageQuery):
image_query_id = image_query.id
else:
image_query_id = str(image_query)
# Some old imagequery id's started with "chk_"
# TODO: handle iqe_ for image_queries returned from edge endpoints
if not image_query_id.startswith(("chk_", "iq_")):
raise ValueError(f"Invalid image query id {image_query_id}")
api_label = convert_display_label_to_internal(image_query_id, label)
geometry_requests = [BBoxGeometryRequest(**roi.geometry.dict()) for roi in rois] if rois else None
roi_requests = (
[
ROIRequest(label=roi.label, score=roi.score, geometry=geometry)
for roi, geometry in zip(rois, geometry_requests)
]
if rois and geometry_requests
else None
)
request_params = LabelValueRequest(label=api_label, image_query_id=image_query_id, rois=roi_requests)
self.labels_api.create_label(request_params)

def reset_detector(self, detector: Union[str, Detector]) -> None:
"""
Removes all image queries and training data for the given detector. This effectively resets
Expand Down
56 changes: 56 additions & 0 deletions test/unit/test_labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from datetime import datetime

import pytest
from groundlight import ExperimentalApi, ApiException


def test_binary_labels(gl_experimental: ExperimentalApi):
name = f"Test binary labels{datetime.utcnow()}"
det = gl_experimental.create_detector(name, "test_query")
iq1 = gl_experimental.submit_image_query(det, "test/assets/cat.jpeg")
gl_experimental.add_label(iq1, "YES")
iq1 = gl_experimental.get_image_query(iq1.id)
assert iq1.result.label == "YES"
gl_experimental.add_label(iq1, "NO")
iq1 = gl_experimental.get_image_query(iq1.id)
assert iq1.result.label == "NO"
gl_experimental.add_label(iq1, "UNCLEAR")
iq1 = gl_experimental.get_image_query(iq1.id)
assert iq1.result.label == "UNCLEAR"
with pytest.raises(ApiException) as _:
gl_experimental.add_label(iq1, "MAYBE")

def test_counting_labels(gl_experimental: ExperimentalApi):
name = f"Test binary labels{datetime.utcnow()}"
det = gl_experimental.create_counting_detector(name, "test_query")
iq1 = gl_experimental.submit_image_query(det, "test/assets/cat.jpeg")
gl_experimental.add_label(iq1, 0)
iq1 = gl_experimental.get_image_query(iq1.id)
assert iq1.result.count == 0
gl_experimental.add_label(iq1, 5)
iq1 = gl_experimental.get_image_query(iq1.id)
assert iq1.result.count == 5
with pytest.raises(ApiException) as _:
gl_experimental.add_label(iq1, "MAYBE")
with pytest.raises(ApiException) as _:
gl_experimental.add_label(iq1, -999)

def test_multiclass_labels(gl_experimental: ExperimentalApi):
name = f"Test binary labels{datetime.utcnow()}"
det = gl_experimental.create_multiclass_detector(name, "test_query", class_names=["apple", "banana", "cherry"])
iq1 = gl_experimental.submit_image_query(det, "test/assets/cat.jpeg")
gl_experimental.add_label(iq1, "apple")
iq1 = gl_experimental.get_image_query(iq1.id)
assert iq1.result.label == "apple"
gl_experimental.add_label(iq1, "banana")
iq1 = gl_experimental.get_image_query(iq1.id)
assert iq1.result.label == "banana"
gl_experimental.add_label(iq1, "cherry")
iq1 = gl_experimental.get_image_query(iq1.id)
assert iq1.result.label == "cherry"
# You can submit the index of the class as well
gl_experimental.add_label(iq1, 2)
iq1 = gl_experimental.get_image_query(iq1.id)
assert iq1.result.label == "cherry"
with pytest.raises(ApiException) as _:
gl_experimental.add_label(iq1, "MAYBE")

0 comments on commit 42f53d5

Please sign in to comment.