From 9eecd107604a4bc13e766bb98953520be6e62e13 Mon Sep 17 00:00:00 2001 From: Brandon <132288221+brandon-groundlight@users.noreply.github.com> Date: Mon, 16 Oct 2023 18:31:27 -0700 Subject: [PATCH] Add ask_confident and ask_ml (#99) * Adding ask_confident and ask_fast * Automatically reformatting code * Fixing ask_ml behavior * Adding to test * Automatically reformatting code * set default wait for ask_ml * Unhide wait functions, merging logic, fixed iq_is_answered logic * Automatically reformatting code * Rewriting doc strings in Sphinx style * ask_fast to ask_ml in the tests * fixed sphinx docstring return types * Cleaning the lint trap * Last bits of lint * Making iq submission with inspection work with newly optional patience time * single char typo * Reorder functions to trick Git's LCS alg to be correct * Automatically reformatting code --------- Co-authored-by: Auto-format Bot --- src/groundlight/client.py | 230 ++++++++++++++++++++++----- src/groundlight/internalapi.py | 16 +- test/integration/test_groundlight.py | 51 +++++- 3 files changed, 253 insertions(+), 44 deletions(-) diff --git a/src/groundlight/client.py b/src/groundlight/client.py index 023468e3..9ced39dd 100644 --- a/src/groundlight/client.py +++ b/src/groundlight/client.py @@ -2,7 +2,7 @@ import os import time from io import BufferedReader, BytesIO -from typing import Optional, Union +from typing import Callable, Optional, Union from model import Detector, ImageQuery, PaginatedDetectorList, PaginatedImageQueryList from openapi_client import Configuration @@ -13,7 +13,13 @@ from groundlight.binary_labels import Label, convert_display_label_to_internal, convert_internal_label_to_display from groundlight.config import API_TOKEN_HELP_MESSAGE, API_TOKEN_VARIABLE_NAME from groundlight.images import ByteStreamWrapper, parse_supported_image_types -from groundlight.internalapi import GroundlightApiClient, NotFoundError, iq_is_confident, sanitize_endpoint_url +from groundlight.internalapi import ( + 
GroundlightApiClient, + NotFoundError, + iq_is_answered, + iq_is_confident, + sanitize_endpoint_url, +) from groundlight.optional_imports import Image, np logger = logging.getLogger("groundlight.sdk") @@ -24,7 +30,8 @@ class ApiTokenError(Exception): class Groundlight: - """Client for accessing the Groundlight cloud service. + """ + Client for accessing the Groundlight cloud service. The API token (auth) is specified through the **GROUNDLIGHT_API_TOKEN** environment variable by default. @@ -70,8 +77,8 @@ def __init__(self, endpoint: Optional[str] = None, api_token: Optional[str] = No If unset, fallback to the environment variable "GROUNDLIGHT_API_TOKEN". :type api_token: str - :return Groundlight client - :rtype Groundlight + :return: Groundlight client + :rtype: Groundlight """ # Specify the endpoint self.endpoint = sanitize_endpoint_url(endpoint) @@ -109,8 +116,8 @@ def get_detector(self, id: Union[str, Detector]) -> Detector: # pylint: disable :param id: the detector id :type id: str or Detector - :return Detector - :rtype Detector + :return: Detector + :rtype: Detector """ if isinstance(id, Detector): @@ -126,8 +133,8 @@ def get_detector_by_name(self, name: str) -> Detector: :param name: the detector name :type name: str - :return Detector - :rtype Detector + :return: Detector + :rtype: Detector """ return self.api_client._get_detector_by_name(name) # pylint: disable=protected-access @@ -141,8 +148,8 @@ def list_detectors(self, page: int = 1, page_size: int = 10) -> PaginatedDetecto :param page_size: the page size :type page_size: int - :return PaginatedDetectorList - :rtype PaginatedDetectorList + :return: PaginatedDetectorList + :rtype: PaginatedDetectorList """ obj = self.detectors_api.list_detectors(page=page, page_size=page_size) return PaginatedDetectorList.parse_obj(obj.to_dict()) @@ -170,8 +177,8 @@ def create_detector( :param pipeline_config: the pipeline config :type pipeline_config: str - :return Detector - :rtype Detector + :return: Detector + 
:rtype: Detector """ detector_creation_input = DetectorCreationInput(name=name, query=query) if confidence_threshold is not None: @@ -206,8 +213,8 @@ def get_or_create_detector( :param pipeline_config: the pipeline config :type pipeline_config: str - :return Detector - :rtype Detector + :return: Detector + :rtype: Detector """ try: existing_detector = self.get_detector_by_name(name) @@ -245,8 +252,8 @@ def get_image_query(self, id: str) -> ImageQuery: # pylint: disable=redefined-b :param id: the image query id :type id: str - :return ImageQuery - :rtype ImageQuery + :return: ImageQuery + :rtype: ImageQuery """ obj = self.image_queries_api.get_image_query(id=id) iq = ImageQuery.parse_obj(obj.to_dict()) @@ -262,8 +269,8 @@ def list_image_queries(self, page: int = 1, page_size: int = 10) -> PaginatedIma :param page_size: the page size :type page_size: int - :return PaginatedImageQueryList - :rtype PaginatedImageQueryList + :return: PaginatedImageQueryList + :rtype: PaginatedImageQueryList """ obj = self.image_queries_api.list_image_queries(page=page, page_size=page_size) image_queries = PaginatedImageQueryList.parse_obj(obj.to_dict()) @@ -276,6 +283,8 @@ def submit_image_query( # noqa: PLR0913 # pylint: disable=too-many-arguments detector: Union[Detector, str], image: Union[str, bytes, Image.Image, BytesIO, BufferedReader, np.ndarray], wait: Optional[float] = None, + patience_time: Optional[float] = None, + confidence_threshold: Optional[float] = None, human_review: Optional[str] = None, want_async: bool = False, inspection_id: Optional[str] = None, @@ -287,14 +296,12 @@ def submit_image_query( # noqa: PLR0913 # pylint: disable=too-many-arguments :type detector: Detector or str :param image: The image, in several possible formats: - - filename (string) of a jpeg file - byte array or BytesIO or BufferedReader with jpeg bytes - numpy array with values 0-255 and dimensions (H,W,3) in BGR order (Note OpenCV uses BGR not RGB. 
`img[:, :, ::-1]` will reverse the channels) - PIL Image: Any binary format must be JPEG-encoded already. Any pixel format will get converted to JPEG at high quality before sending to service. - :type image: str or bytes or Image.Image or BytesIO or BufferedReader or np.ndarray :param wait: How long to wait (in seconds) for a confident answer. @@ -315,8 +322,8 @@ def submit_image_query( # noqa: PLR0913 # pylint: disable=too-many-arguments this is the ID of the inspection to associate with the image query. :type inspection_id: str - :return ImageQuery - :rtype ImageQuery + :return: ImageQuery + :rtype: ImageQuery """ if wait is None: wait = self.DEFAULT_WAIT @@ -326,10 +333,8 @@ def submit_image_query( # noqa: PLR0913 # pylint: disable=too-many-arguments image_bytesio: ByteStreamWrapper = parse_supported_image_types(image) params = {"detector_id": detector_id, "body": image_bytesio} - if wait == 0: - params["patience_time"] = self.DEFAULT_WAIT - else: - params["patience_time"] = wait + if patience_time is not None: + params["patience_time"] = patience_time if human_review is not None: params["human_review"] = human_review @@ -355,11 +360,89 @@ def submit_image_query( # noqa: PLR0913 # pylint: disable=too-many-arguments image_query = self.get_image_query(iq_id) if wait > 0: - threshold = self.get_detector(detector).confidence_threshold + if confidence_threshold is None: + threshold = self.get_detector(detector).confidence_threshold + else: + threshold = confidence_threshold image_query = self.wait_for_confident_result(image_query, confidence_threshold=threshold, timeout_sec=wait) return self._fixup_image_query(image_query) + def ask_confident( + self, + detector: Union[Detector, str], + image: Union[str, bytes, Image.Image, BytesIO, BufferedReader, np.ndarray], + confidence_threshold: Optional[float] = None, + wait: Optional[float] = None, + ) -> ImageQuery: + """Evaluates an image with Groundlight waiting until an answer above the confidence threshold + of the 
detector is reached or the wait period has passed. + :param detector: the Detector object, or string id of a detector like `det_12345` + :type detector: Detector or str + + :param image: The image, in several possible formats: + - filename (string) of a jpeg file + - byte array or BytesIO or BufferedReader with jpeg bytes + - numpy array with values 0-255 and dimensions (H,W,3) in BGR order + (Note OpenCV uses BGR not RGB. `img[:, :, ::-1]` will reverse the channels) + - PIL Image + Any binary format must be JPEG-encoded already. Any pixel format will get + converted to JPEG at high quality before sending to service. + :type image: str or bytes or Image.Image or BytesIO or BufferedReader or np.ndarray + + :param confidence_threshold: The confidence threshold to wait for. + If not set, use the detector's confidence threshold. + :type confidence_threshold: float + + :param wait: How long to wait (in seconds) for a confident answer. + :type wait: float + + :return: ImageQuery + :rtype: ImageQuery + """ + return self.submit_image_query( + detector, + image, + confidence_threshold=confidence_threshold, + wait=wait, + ) + + def ask_ml( + self, + detector: Union[Detector, str], + image: Union[str, bytes, Image.Image, BytesIO, BufferedReader, np.ndarray], + wait: Optional[float] = None, + ) -> ImageQuery: + """Evaluates an image with Groundlight, getting the first answer Groundlight can provide. + :param detector: the Detector object, or string id of a detector like `det_12345` + :type detector: Detector or str + + :param image: The image, in several possible formats: + - filename (string) of a jpeg file + - byte array or BytesIO or BufferedReader with jpeg bytes + - numpy array with values 0-255 and dimensions (H,W,3) in BGR order + (Note OpenCV uses BGR not RGB. `img[:, :, ::-1]` will reverse the channels) + - PIL Image + Any binary format must be JPEG-encoded already. Any pixel format will get + converted to JPEG at high quality before sending to service. 
+ :type image: str or bytes or Image.Image or BytesIO or BufferedReader or np.ndarray + + :param wait: How long to wait (in seconds) for any answer. + :type wait: float + + :return: ImageQuery + :rtype: ImageQuery + """ + iq = self.submit_image_query( + detector, + image, + wait=0, + ) + if iq_is_answered(iq): + return iq + wait = self.DEFAULT_WAIT if wait is None else wait + return self.wait_for_ml_result(iq, timeout_sec=wait) + def ask_async( self, detector: Union[Detector, str], @@ -423,10 +506,51 @@ def wait_for_confident_result( :param timeout_sec: The maximum number of seconds to wait. :type timeout_sec: float - :return ImageQuery - :rtype ImageQuery + :return: ImageQuery + :rtype: ImageQuery + """ + + def confidence_above_thresh(iq): + return iq_is_confident(iq, confidence_threshold=confidence_threshold) + + return self._wait_for_result(image_query, condition=confidence_above_thresh, timeout_sec=timeout_sec) + + def wait_for_ml_result(self, image_query: Union[ImageQuery, str], timeout_sec: float = 30.0) -> ImageQuery: + """Waits for the first ml result to be returned. + Currently this is done by polling with an exponential back-off. + + :param image_query: An ImageQuery object to poll + :type image_query: ImageQuery or str + + :param confidence_threshold: The minimum confidence level required to return before the timeout. + :type confidence_threshold: float + + :param timeout_sec: The maximum number of seconds to wait. + :type timeout_sec: float + + :return: ImageQuery + :rtype: ImageQuery + """ + return self._wait_for_result(image_query, condition=iq_is_answered, timeout_sec=timeout_sec) + + def _wait_for_result( + self, image_query: Union[ImageQuery, str], condition: Callable, timeout_sec: float = 30.0 + ) -> ImageQuery: + """Performs polling with exponential back-off until the condition is met for the image query. 
+ + :param image_query: An ImageQuery object to poll + :type image_query: ImageQuery or str + + :param condition: A callable that takes an ImageQuery and returns True once the result + is acceptable (stop polling), or False to keep waiting for a better result. + :type condition: Callable + + :param timeout_sec: The maximum number of seconds to wait. + :type timeout_sec: float + + :return: ImageQuery + :rtype: ImageQuery """ - # Convert from image_query_id to ImageQuery if needed. if isinstance(image_query, str): image_query = self.get_image_query(image_query) @@ -436,18 +560,15 @@ def wait_for_confident_result( image_query = self._fixup_image_query(image_query) while True: patience_so_far = time.time() - start_time - if iq_is_confident(image_query, confidence_threshold): - logger.debug(f"Confident answer for {image_query} after {patience_so_far:.1f}s") + if condition(image_query): + logger.debug(f"Answer for {image_query} after {patience_so_far:.1f}s") break if patience_so_far >= timeout_sec: logger.debug(f"Timeout after {timeout_sec:.0f}s waiting for {image_query}") break target_delay = min(patience_so_far + next_delay, timeout_sec) sleep_time = max(target_delay - patience_so_far, 0) - logger.debug( - f"Polling ({target_delay:.1f}/{timeout_sec:.0f}s) {image_query} until" - f" confidence>={confidence_threshold:.3f}" - ) + logger.debug(f"Polling ({target_delay:.1f}/{timeout_sec:.0f}s) {image_query} until result is available") time.sleep(sleep_time) next_delay *= self.POLLING_EXPONENTIAL_BACKOFF image_query = self.get_image_query(image_query.id) @@ -465,8 +586,8 @@ def add_label(self, image_query: Union[ImageQuery, str], label: Union[Label, str :param label: The string "YES" or the string "NO" in answer to the query. 
:type label: Label or str - :return None - :rtype None + :return: None + :rtype: None """ if isinstance(image_query, ImageQuery): image_query_id = image_query.id @@ -482,12 +603,27 @@ def add_label(self, image_query: Union[ImageQuery, str], label: Union[Label, str def start_inspection(self) -> str: """For users with Inspection Reports enabled only. Starts an inspection report and returns the id of the inspection. + + :return: The unique identifier of the inspection. + :rtype: str """ return self.api_client.start_inspection() def update_inspection_metadata(self, inspection_id: str, user_provided_key: str, user_provided_value: str) -> None: """For users with Inspection Reports enabled only. Add/update inspection metadata with the user_provided_key and user_provided_value. + + :param inspection_id: The unique identifier of the inspection. + :type inspection_id: str + + :param user_provided_key: the key in the key/value pair for the inspection metadata. + :type user_provided_key: str + + :param user_provided_value: the value in the key/value pair for the inspection metadata. + :type user_provided_value: str + + :return: None + :rtype: None """ self.api_client.update_inspection_metadata(inspection_id, user_provided_key, user_provided_value) @@ -495,10 +631,22 @@ def stop_inspection(self, inspection_id: str) -> str: """For users with Inspection Reports enabled only. Stops an inspection and raises an exception if the response from the server indicates that the inspection was not successfully stopped. - Returns a str with result of the inspection (either PASS or FAIL). + + :param inspection_id: The unique identifier of the inspection. + :type inspection_id: str + + :return: "PASS" or "FAIL" depending on the result of the inspection. 
+ :rtype: str """ return self.api_client.stop_inspection(inspection_id) def update_detector_confidence_threshold(self, detector_id: str, confidence_threshold: float) -> None: - """Updates the confidence threshold of a detector given a detector_id.""" + """Updates the confidence threshold of a detector given a detector_id. + + :param detector_id: The unique identifier of the detector. + :type detector_id: str + + :return: None + :rtype: None + """ self.api_client.update_detector_confidence_threshold(detector_id, confidence_threshold) diff --git a/src/groundlight/internalapi.py b/src/groundlight/internalapi.py index d317b30b..3d85eb37 100644 --- a/src/groundlight/internalapi.py +++ b/src/groundlight/internalapi.py @@ -71,6 +71,17 @@ def iq_is_confident(iq: ImageQuery, confidence_threshold: float) -> bool: return iq.result.confidence >= confidence_threshold +def iq_is_answered(iq: ImageQuery) -> bool: + """Returns True if the image query has a ML or human label. + Placeholder and special labels (out of domain) have confidences exactly 0.5 + """ + if iq.result.confidence is None: + # Human label + return True + placeholder_confidence = 0.5 + return iq.result.confidence > placeholder_confidence + + class InternalApiError(ApiException, RuntimeError): # TODO: We should really avoid this double inheritance since # both `ApiException` and `RuntimeError` are subclasses of @@ -232,9 +243,9 @@ def _get_detector_by_name(self, name: str) -> Detector: def submit_image_query_with_inspection( # noqa: PLR0913 # pylint: disable=too-many-arguments self, detector_id: str, - patience_time: float, body: ByteStreamWrapper, inspection_id: str, + patience_time: Optional[float] = None, human_review: str = "DEFAULT", ) -> str: """Submits an image query to the API and returns the ID of the image query. 
@@ -246,8 +257,9 @@ def submit_image_query_with_inspection( # noqa: PLR0913 # pylint: disable=too-m params: Dict[str, Union[str, float, bool]] = { "inspection_id": inspection_id, "predictor_id": detector_id, - "patience_time": patience_time, } + if patience_time is not None: + params["patience_time"] = float(patience_time) # In the API, 'send_notification' is used to control human_review escalation. This will eventually # be deprecated, but for now we need to support it in the following manner: diff --git a/test/integration/test_groundlight.py b/test/integration/test_groundlight.py index 7697939b..71e9da1e 100644 --- a/test/integration/test_groundlight.py +++ b/test/integration/test_groundlight.py @@ -9,12 +9,13 @@ import pytest from groundlight import Groundlight from groundlight.binary_labels import VALID_DISPLAY_LABELS, DeprecatedLabel, Label, convert_internal_label_to_display -from groundlight.internalapi import InternalApiError, NotFoundError +from groundlight.internalapi import InternalApiError, NotFoundError, iq_is_answered from groundlight.optional_imports import * from groundlight.status_codes import is_user_error from model import ClassificationResult, Detector, ImageQuery, PaginatedDetectorList, PaginatedImageQueryList DEFAULT_CONFIDENCE_THRESHOLD = 0.9 +IQ_IMPROVEMENT_THRESHOLD = 0.75 def is_valid_display_result(result: Any) -> bool: @@ -163,6 +164,41 @@ def test_get_detector_by_name(gl: Groundlight, detector: Detector): gl.get_detector_by_name(name="not a real name") +def test_ask_confident(gl: Groundlight, detector: Detector): + _image_query = gl.ask_confident(detector=detector.id, image="test/assets/dog.jpeg", wait=10) + assert str(_image_query) + assert isinstance(_image_query, ImageQuery) + assert is_valid_display_result(_image_query.result) + + +def test_ask_ml(gl: Groundlight, detector: Detector): + _image_query = gl.ask_ml(detector=detector.id, image="test/assets/dog.jpeg", wait=10) + assert str(_image_query) + assert isinstance(_image_query, 
ImageQuery) + assert is_valid_display_result(_image_query.result) + + +def test_submit_image_query(gl: Groundlight, detector: Detector): + def validate_image_query(_image_query: ImageQuery): + assert str(_image_query) + assert isinstance(_image_query, ImageQuery) + assert is_valid_display_result(_image_query.result) + + _image_query = gl.submit_image_query(detector=detector.id, image="test/assets/dog.jpeg", wait=10) + validate_image_query(_image_query) + _image_query = gl.submit_image_query(detector=detector.id, image="test/assets/dog.jpeg", wait=3) + validate_image_query(_image_query) + _image_query = gl.submit_image_query(detector=detector.id, image="test/assets/dog.jpeg", wait=10, patience_time=20) + validate_image_query(_image_query) + _image_query = gl.submit_image_query(detector=detector.id, image="test/assets/dog.jpeg", human_review="NEVER") + validate_image_query(_image_query) + _image_query = gl.submit_image_query( + detector=detector.id, image="test/assets/dog.jpeg", wait=180, confidence_threshold=0.75 + ) + validate_image_query(_image_query) + assert _image_query.result.confidence >= IQ_IMPROVEMENT_THRESHOLD + + def test_submit_image_query_blocking(gl: Groundlight, detector: Detector): _image_query = gl.submit_image_query(detector=detector.id, image="test/assets/dog.jpeg", wait=10) assert str(_image_query) @@ -489,6 +525,19 @@ def submit_noisy_image(image, label=None): ), f"The detector {detector} quality has not improved after two minutes q.v. {new_dog_query}, {new_cat_query}" +def test_ask_method_quality(gl: Groundlight, detector: Detector): + # asks for some level of quality on how fast ask_ml is and that we will get a confident result from ask_confident + fast_always_yes_iq = gl.ask_ml(detector=detector.id, image="test/assets/dog.jpeg", wait=0) + assert iq_is_answered(fast_always_yes_iq) + name = f"Test {datetime.utcnow()}" # Need a unique name + query = "Is there a dog?" 
+ detector = gl.create_detector(name=name, query=query, confidence_threshold=0.8) + fast_iq = gl.ask_ml(detector=detector.id, image="test/assets/dog.jpeg", wait=0) + assert iq_is_answered(fast_iq) + confident_iq = gl.ask_confident(detector=detector.id, image="test/assets/dog.jpeg", wait=180) + assert confident_iq.result.confidence > IQ_IMPROVEMENT_THRESHOLD + + def test_start_inspection(gl: Groundlight): inspection_id = gl.start_inspection()