-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add image-to-text pipeline * Tidy up for review * Regen golang bindings * Add worker changes * make fields optional * Add ImageToText response * Use new error handling * refactor(runner): apply some small code improvements This commit applies some small code improvements to make the code more consistent with the rest of the pipelines. * feat(worker): ensure all runner errors are forwarded This commit ensures that all errors are propagated from the runner to the worker and then to the orchestrator. --------- Co-authored-by: Rick Staa <[email protected]>
- Loading branch information
Showing
12 changed files
with
730 additions
and
59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import logging | ||
import os | ||
|
||
import torch | ||
from app.pipelines.base import Pipeline | ||
from app.pipelines.utils import get_model_dir, get_torch_device | ||
from huggingface_hub import file_download | ||
from transformers import BlipProcessor, BlipForConditionalGeneration | ||
from PIL import Image | ||
|
||
from app.utils.errors import InferenceError | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class ImageToTextPipeline(Pipeline): | ||
def __init__(self, model_id: str): | ||
self.model_id = model_id | ||
kwargs = {} | ||
|
||
self.torch_device = get_torch_device() | ||
folder_name = file_download.repo_folder_name( | ||
repo_id=model_id, repo_type="model" | ||
) | ||
folder_path = os.path.join(get_model_dir(), folder_name) | ||
# Load fp16 variant if fp16 safetensors files are found in cache | ||
has_fp16_variant = any( | ||
".fp16.safetensors" in fname | ||
for _, _, files in os.walk(folder_path) | ||
for fname in files | ||
) | ||
if self.torch_device != "cpu" and has_fp16_variant: | ||
logger.info("ImageToTextPipeline loading fp16 variant for %s", model_id) | ||
|
||
kwargs["torch_dtype"] = torch.float16 | ||
kwargs["variant"] = "fp16" | ||
|
||
if os.environ.get("BFLOAT16"): | ||
logger.info("ImageToTextPipeline using bfloat16 precision for %s", model_id) | ||
kwargs["torch_dtype"] = torch.bfloat16 | ||
|
||
self.tm = BlipForConditionalGeneration.from_pretrained( | ||
model_id, | ||
low_cpu_mem_usage=True, | ||
use_safetensors=True, | ||
cache_dir=get_model_dir(), | ||
**kwargs, | ||
).to(self.torch_device) | ||
|
||
self.processor = BlipProcessor.from_pretrained( | ||
model_id, cache_dir=get_model_dir() | ||
) | ||
|
||
def __call__(self, prompt: str, image: Image, **kwargs) -> str: | ||
inputs = self.processor(image, prompt, return_tensors="pt").to( | ||
self.torch_device | ||
) | ||
out = self.tm.generate(**inputs) | ||
|
||
try: | ||
return self.processor.decode(out[0], skip_special_tokens=True) | ||
except Exception as e: | ||
raise InferenceError(original_exception=e) | ||
|
||
def __str__(self) -> str: | ||
return f"ImageToTextPipeline model_id={self.model_id}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import logging | ||
import os | ||
from typing import Annotated, Dict, Tuple, Union | ||
|
||
import torch | ||
|
||
from app.dependencies import get_pipeline | ||
from app.pipelines.base import Pipeline | ||
from app.routes.utils import ( | ||
HTTPError, | ||
ImageToTextResponse, | ||
file_exceeds_max_size, | ||
handle_pipeline_exception, | ||
http_error, | ||
) | ||
from fastapi import APIRouter, Depends, File, Form, UploadFile, status | ||
from fastapi.responses import JSONResponse | ||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer | ||
from PIL import Image | ||
|
||
router = APIRouter() | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
# Pipeline specific error handling configuration. | ||
PIPELINE_ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = { | ||
# Specific error types. | ||
"OutOfMemoryError": ( | ||
"Out of memory error. Try reducing input image resolution.", | ||
status.HTTP_500_INTERNAL_SERVER_ERROR, | ||
) | ||
} | ||
|
||
RESPONSES = { | ||
status.HTTP_200_OK: { | ||
"content": { | ||
"application/json": { | ||
"schema": { | ||
"x-speakeasy-name-override": "data", | ||
} | ||
} | ||
}, | ||
}, | ||
status.HTTP_400_BAD_REQUEST: {"model": HTTPError}, | ||
status.HTTP_401_UNAUTHORIZED: {"model": HTTPError}, | ||
status.HTTP_413_REQUEST_ENTITY_TOO_LARGE: {"model": HTTPError}, | ||
status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError}, | ||
} | ||
|
||
|
||
@router.post( | ||
"/image-to-text", | ||
response_model=ImageToTextResponse, | ||
responses=RESPONSES, | ||
description="Transform image files to text.", | ||
operation_id="genImageToText", | ||
summary="Image To Text", | ||
tags=["generate"], | ||
openapi_extra={"x-speakeasy-name-override": "imageToText"}, | ||
) | ||
@router.post( | ||
"/image-to-text/", | ||
response_model=ImageToTextResponse, | ||
responses=RESPONSES, | ||
include_in_schema=False, | ||
) | ||
async def image_to_text( | ||
image: Annotated[ | ||
UploadFile, File(description="Uploaded image to transform with the pipeline.") | ||
], | ||
prompt: Annotated[ | ||
str, | ||
Form(description="Text prompt(s) to guide transformation."), | ||
] = "", | ||
model_id: Annotated[ | ||
str, | ||
Form(description="Hugging Face model ID used for transformation."), | ||
] = "", | ||
pipeline: Pipeline = Depends(get_pipeline), | ||
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), | ||
): | ||
auth_token = os.environ.get("AUTH_TOKEN") | ||
if auth_token: | ||
if not token or token.credentials != auth_token: | ||
return JSONResponse( | ||
status_code=status.HTTP_401_UNAUTHORIZED, | ||
headers={"WWW-Authenticate": "Bearer"}, | ||
content=http_error("Invalid bearer token"), | ||
) | ||
|
||
if model_id != "" and model_id != pipeline.model_id: | ||
return JSONResponse( | ||
status_code=status.HTTP_400_BAD_REQUEST, | ||
content=http_error( | ||
f"pipeline configured with {pipeline.model_id} but called with " | ||
f"{model_id}" | ||
), | ||
) | ||
|
||
if file_exceeds_max_size(image, 50 * 1024 * 1024): | ||
return JSONResponse( | ||
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, | ||
content=http_error("File size exceeds limit"), | ||
) | ||
|
||
image = Image.open(image.file).convert("RGB") | ||
try: | ||
return ImageToTextResponse(text=pipeline(prompt=prompt, image=image)) | ||
except Exception as e: | ||
if isinstance(e, torch.cuda.OutOfMemoryError): | ||
torch.cuda.empty_cache() | ||
logger.error(f"ImageToTextPipeline error: {e}") | ||
return handle_pipeline_exception( | ||
e, | ||
default_error_message="Image-to-text pipeline error.", | ||
custom_error_config=PIPELINE_ERROR_CONFIG, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.