diff --git a/edenai_apis/apis/winstonai/winstonai_api.py b/edenai_apis/apis/winstonai/winstonai_api.py index d76fb2b2..20cd33fe 100644 --- a/edenai_apis/apis/winstonai/winstonai_api.py +++ b/edenai_apis/apis/winstonai/winstonai_api.py @@ -1,10 +1,13 @@ import json +from http import HTTPStatus from typing import Dict, Sequence, Any, Optional from uuid import uuid4 import requests from edenai_apis.apis.winstonai.config import WINSTON_AI_API_URL from edenai_apis.features import ProviderInterface, TextInterface, ImageInterface -from edenai_apis.features.image.ai_detection.ai_detection_dataclass import AiDetectionDataClass as ImageAiDetectionDataclass +from edenai_apis.features.image.ai_detection.ai_detection_dataclass import ( + AiDetectionDataClass as ImageAiDetectionDataclass, +) from edenai_apis.features.text.ai_detection.ai_detection_dataclass import ( AiDetectionDataClass, AiDetectionItem, @@ -36,16 +39,19 @@ def __init__(self, api_keys: Optional[Dict[str, Any]] = None): "Authorization": f'Bearer {self.api_settings["api_key"]}', } - def image__ai_detection(self, file: Optional[str] = None, file_url: Optional[str] = None) -> ResponseType[ImageAiDetectionDataclass]: + def image__ai_detection( + self, file: Optional[str] = None, file_url: Optional[str] = None + ) -> ResponseType[ImageAiDetectionDataclass]: if not file_url and not file: raise ProviderException("file or file_url required") - payload = json.dumps({ - "url": file_url or upload_file_to_s3(file, file) - }) + payload = json.dumps({"url": file_url or upload_file_to_s3(file, file)}) response = requests.request( - "POST", f"{self.api_url}/image-detection", headers=self.headers, data=payload + "POST", + f"{self.api_url}/image-detection", + headers=self.headers, + data=payload, ) if response.status_code != 200: @@ -58,7 +64,6 @@ def image__ai_detection(self, file: Optional[str] = None, file_url: Optional[st if score is None: raise ProviderException(response.json()) - standardized_response = ImageAiDetectionDataclass( ai_score=score, prediction=prediction, @@ -87,6 +92,9 @@ def text__ai_detection( "POST", f"{self.api_url}/predict", headers=self.headers, data=payload ) + if response.status_code >= HTTPStatus.INTERNAL_SERVER_ERROR: + raise ProviderException("Internal Server Error") + if response.status_code != 200: raise ProviderException(response.json(), code=response.status_code) @@ -100,7 +108,7 @@ def text__ai_detection( items: Sequence[AiDetectionItem] = [ AiDetectionItem( text=sentence["text"], - ai_score=1-(sentence["score"] / 100), + ai_score=1 - (sentence["score"] / 100), prediction=AiDetectionItem.set_label_based_on_score( 1 - (sentence["score"] / 100) ), @@ -108,8 +116,7 @@ def text__ai_detection( for sentence in sentences ] - standardized_response = AiDetectionDataClass( - ai_score=1-score, items=items) + standardized_response = AiDetectionDataClass(ai_score=1 - score, items=items) return ResponseType[AiDetectionDataClass]( original_response=original_response,