Skip to content

Commit

Permalink
Add image-to-text pipeline (#223)
Browse files Browse the repository at this point in the history
* Add image-to-text pipeline

* Tidy up for review

* Regen golang bindings

* Add worker changes

* make fields optional

* Add ImageToText response

* Use new error handling

* refactor(runner): apply some small code improvements

This commit applies some small code improvements to make the code more
consistent with the rest of the pipelines.

* feat(worker): ensure all runner errors are forwarded

This commit ensures that all errors are propagated from the runner to the
worker and then to the orchestrator.

---------

Co-authored-by: Rick Staa <[email protected]>
  • Loading branch information
mjh1 and rickstaa authored Oct 21, 2024
1 parent 40fa0c2 commit 7727d8b
Show file tree
Hide file tree
Showing 12 changed files with 730 additions and 59 deletions.
7 changes: 7 additions & 0 deletions runner/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
case "llm":
from app.pipelines.llm import LLMPipeline
return LLMPipeline(model_id)
case "image-to-text":
from app.pipelines.image_to_text import ImageToTextPipeline

return ImageToTextPipeline(model_id)
case _:
raise EnvironmentError(
f"{pipeline} is not a valid pipeline for model {model_id}"
Expand Down Expand Up @@ -94,6 +98,9 @@ def load_route(pipeline: str) -> any:
case "llm":
from app.routes import llm
return llm.router
case "image-to-text":
from app.routes import image_to_text
return image_to_text.router
case _:
raise EnvironmentError(f"{pipeline} is not a valid pipeline")

Expand Down
66 changes: 66 additions & 0 deletions runner/app/pipelines/image_to_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import logging
import os

import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir, get_torch_device
from huggingface_hub import file_download
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image

from app.utils.errors import InferenceError

logger = logging.getLogger(__name__)


class ImageToTextPipeline(Pipeline):
def __init__(self, model_id: str):
self.model_id = model_id
kwargs = {}

self.torch_device = get_torch_device()
folder_name = file_download.repo_folder_name(
repo_id=model_id, repo_type="model"
)
folder_path = os.path.join(get_model_dir(), folder_name)
# Load fp16 variant if fp16 safetensors files are found in cache
has_fp16_variant = any(
".fp16.safetensors" in fname
for _, _, files in os.walk(folder_path)
for fname in files
)
if self.torch_device != "cpu" and has_fp16_variant:
logger.info("ImageToTextPipeline loading fp16 variant for %s", model_id)

kwargs["torch_dtype"] = torch.float16
kwargs["variant"] = "fp16"

if os.environ.get("BFLOAT16"):
logger.info("ImageToTextPipeline using bfloat16 precision for %s", model_id)
kwargs["torch_dtype"] = torch.bfloat16

self.tm = BlipForConditionalGeneration.from_pretrained(
model_id,
low_cpu_mem_usage=True,
use_safetensors=True,
cache_dir=get_model_dir(),
**kwargs,
).to(self.torch_device)

self.processor = BlipProcessor.from_pretrained(
model_id, cache_dir=get_model_dir()
)

def __call__(self, prompt: str, image: Image, **kwargs) -> str:
inputs = self.processor(image, prompt, return_tensors="pt").to(
self.torch_device
)
out = self.tm.generate(**inputs)

try:
return self.processor.decode(out[0], skip_special_tokens=True)
except Exception as e:
raise InferenceError(original_exception=e)

def __str__(self) -> str:
return f"ImageToTextPipeline model_id={self.model_id}"
117 changes: 117 additions & 0 deletions runner/app/routes/image_to_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import logging
import os
from typing import Annotated, Dict, Tuple, Union

import torch

from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.utils import (
HTTPError,
ImageToTextResponse,
file_exceeds_max_size,
handle_pipeline_exception,
http_error,
)
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from PIL import Image

router = APIRouter()

logger = logging.getLogger(__name__)

# Pipeline specific error handling configuration.
PIPELINE_ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = {
# Specific error types.
"OutOfMemoryError": (
"Out of memory error. Try reducing input image resolution.",
status.HTTP_500_INTERNAL_SERVER_ERROR,
)
}

RESPONSES = {
status.HTTP_200_OK: {
"content": {
"application/json": {
"schema": {
"x-speakeasy-name-override": "data",
}
}
},
},
status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
status.HTTP_413_REQUEST_ENTITY_TOO_LARGE: {"model": HTTPError},
status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
}


@router.post(
"/image-to-text",
response_model=ImageToTextResponse,
responses=RESPONSES,
description="Transform image files to text.",
operation_id="genImageToText",
summary="Image To Text",
tags=["generate"],
openapi_extra={"x-speakeasy-name-override": "imageToText"},
)
@router.post(
"/image-to-text/",
response_model=ImageToTextResponse,
responses=RESPONSES,
include_in_schema=False,
)
async def image_to_text(
image: Annotated[
UploadFile, File(description="Uploaded image to transform with the pipeline.")
],
prompt: Annotated[
str,
Form(description="Text prompt(s) to guide transformation."),
] = "",
model_id: Annotated[
str,
Form(description="Hugging Face model ID used for transformation."),
] = "",
pipeline: Pipeline = Depends(get_pipeline),
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
auth_token = os.environ.get("AUTH_TOKEN")
if auth_token:
if not token or token.credentials != auth_token:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
headers={"WWW-Authenticate": "Bearer"},
content=http_error("Invalid bearer token"),
)

if model_id != "" and model_id != pipeline.model_id:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(
f"pipeline configured with {pipeline.model_id} but called with "
f"{model_id}"
),
)

if file_exceeds_max_size(image, 50 * 1024 * 1024):
return JSONResponse(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
content=http_error("File size exceeds limit"),
)

image = Image.open(image.file).convert("RGB")
try:
return ImageToTextResponse(text=pipeline(prompt=prompt, image=image))
except Exception as e:
if isinstance(e, torch.cuda.OutOfMemoryError):
torch.cuda.empty_cache()
logger.error(f"ImageToTextPipeline error: {e}")
return handle_pipeline_exception(
e,
default_error_message="Image-to-text pipeline error.",
custom_error_config=PIPELINE_ERROR_CONFIG,
)
6 changes: 6 additions & 0 deletions runner/app/routes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ class LLMResponse(BaseModel):
tokens_used: int


class ImageToTextResponse(BaseModel):
"""Response model for text generation."""

text: str = Field(..., description="The generated text.")


class APIError(BaseModel):
"""API error response model."""

Expand Down
3 changes: 3 additions & 0 deletions runner/dl_checkpoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ function download_all_models() {
# Download image-to-video models.
huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models

# Download image-to-text models.
huggingface-cli download Salesforce/blip-image-captioning-large --include "*.safetensors" "*.json" --cache-dir models

# Custom pipeline models.
huggingface-cli download facebook/sam2-hiera-large --include "*.pt" "*.yaml" --cache-dir models
}
Expand Down
87 changes: 87 additions & 0 deletions runner/gateway.openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,60 @@ paths:
security:
- HTTPBearer: []
x-speakeasy-name-override: llm
/image-to-text:
post:
tags:
- generate
summary: Image To Text
description: Transform image files to text.
operationId: genImageToText
requestBody:
content:
multipart/form-data:
schema:
$ref: '#/components/schemas/Body_genImageToText'
required: true
responses:
'200':
description: Successful Response
content:
application/json:
schema:
$ref: '#/components/schemas/ImageToTextResponse'
x-speakeasy-name-override: data
'400':
description: Bad Request
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPError'
'401':
description: Unauthorized
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPError'
'413':
description: Request Entity Too Large
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPError'
'500':
description: Internal Server Error
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPError'
'422':
description: Validation Error
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
security:
- HTTPBearer: []
x-speakeasy-name-override: imageToText
components:
schemas:
APIError:
Expand Down Expand Up @@ -460,6 +514,28 @@ components:
- image
- model_id
title: Body_genImageToImage
Body_genImageToText:
properties:
image:
type: string
format: binary
title: Image
description: Uploaded image to transform with the pipeline.
prompt:
type: string
title: Prompt
description: Text prompt(s) to guide transformation.
default: ''
model_id:
type: string
title: Model Id
description: Hugging Face model ID used for transformation.
default: ''
type: object
required:
- image
- model_id
title: Body_genImageToText
Body_genImageToVideo:
properties:
image:
Expand Down Expand Up @@ -680,6 +756,17 @@ components:
- images
title: ImageResponse
description: Response model for image generation.
ImageToTextResponse:
properties:
text:
type: string
title: Text
description: The generated text.
type: object
required:
- text
title: ImageToTextResponse
description: Response model for text generation.
LLMResponse:
properties:
response:
Expand Down
4 changes: 3 additions & 1 deletion runner/gen_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
segment_anything_2,
text_to_image,
upscale,
llm
llm,
image_to_text,
)
from fastapi.openapi.utils import get_openapi

Expand Down Expand Up @@ -125,6 +126,7 @@ def write_openapi(fname: str, entrypoint: str = "runner", version: str = "0.0.0"
app.include_router(audio_to_text.router)
app.include_router(segment_anything_2.router)
app.include_router(llm.router)
app.include_router(image_to_text.router)

logger.info(f"Generating OpenAPI schema for '{entrypoint}' entrypoint...")
openapi = get_openapi(
Expand Down
Loading

0 comments on commit 7727d8b

Please sign in to comment.