Skip to content

Commit

Permalink
fix: update type hints for Python 3.9 compatibility
Browse files: browse the repository at this point in the history
  • Loading branch information
devin-ai-integration[bot] and jxnl committed Dec 15, 2024
1 parent ea5a035 commit 34728a7
Showing 1 changed file with 33 additions and 36 deletions.
69 changes: 33 additions & 36 deletions instructor/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
from collections.abc import Mapping
from functools import lru_cache, cache
from pathlib import Path
from typing import (
Any, Callable, Literal, Optional, TypeVar, TypedDict, Union,
cast, ClassVar
)
from typing import Any, Callable, Literal, Optional, TypeVar, TypedDict, ClassVar
from urllib.parse import urlparse

import requests
Expand All @@ -32,13 +29,7 @@
GuessTypeResult = tuple[Optional[str], Optional[str]]
StrSplitResult = list[str]
StrSplitMethod = Callable[[str, Optional[int]], StrSplitResult]
str.split = cast(StrSplitMethod, str.split) # type: ignore

# Add type hints with ignore comments for built-in functions
mimetypes.guess_type = cast(Callable[[str], GuessTypeResult], mimetypes.guess_type) # type: ignore
imghdr.what = cast(Callable[[Optional[str], bytes], Optional[str]], imghdr.what) # type: ignore
base64.b64decode = cast(Callable[[Union[str, bytes]], bytes], base64.b64decode) # type: ignore
re.match = cast(Callable[[str, str], Optional[re.Match[str]]], re.match) # type: ignore

class ImageParamsBase(TypedDict):
type: Literal["image"]
Expand All @@ -48,16 +39,19 @@ class ImageParamsBase(TypedDict):
class ImageParams(ImageParamsBase, total=False):
    """Image parameters extended with an optional cache-control entry.

    Declared with ``total=False``, so the key added in this class may be
    omitted. Keys inherited from ``ImageParamsBase`` keep whatever
    requiredness that base class declares.
    """

    # Optional key (because of total=False); its shape is defined by
    # CacheControlType, which is declared elsewhere in the module.
    cache_control: CacheControlType


class Image(BaseModel):
VALID_MIME_TYPES: ClassVar[list[str]] = [
"image/jpeg",
"image/png",
"image/gif",
"image/webp"
"image/webp",
]
source: Union[str, Path] = Field(description="URL, file path, or base64 data of the image")
source: str | Path = Field(
description="URL, file path, or base64 data of the image"
)
media_type: str = Field(description="MIME type of the image")
data: Optional[str] = Field(None, description="Base64 encoded image data", repr=False)
data: str | None = Field(None, description="Base64 encoded image data", repr=False)

@classmethod
def autodetect(cls, source: str | Path) -> Image | None:
Expand Down Expand Up @@ -89,9 +83,7 @@ def autodetect(cls, source: str | Path) -> Image | None:
return None

@classmethod
def autodetect_safely(
cls, source: str | Path
) -> Image | str:
def autodetect_safely(cls, source: str | Path) -> Image | str:
"""Safely attempt to autodetect an image from a source string or path.
Args:
Expand Down Expand Up @@ -145,21 +137,22 @@ def from_raw_base64(cls, data: str) -> Image | None:
"""
try:
decoded: bytes = base64.b64decode(data)
img_type: Optional[str] = imghdr.what(None, decoded)
img_type: str | None = imghdr.what(None, decoded)
if img_type:
media_type = mimetypes.guess_type(data)[0]
if media_type in cls.VALID_MIME_TYPES:
return cls(source=data, media_type=media_type, data=data)
except Exception:
pass
return None

@classmethod
@cache  # NOTE(review): functools.cache is lru_cache(maxsize=None), i.e. unbounded; entries are never evicted, so it does not by itself prevent memory growth
def from_url(cls, url: str) -> Image:
if cls.is_base64(url):
return cls.from_base64(url)
parsed_url = urlparse(url)
media_type: Optional[str] = mimetypes.guess_type(parsed_url.path)[0]
media_type: str | None = mimetypes.guess_type(parsed_url.path)[0]

if not media_type:
try:
Expand All @@ -183,12 +176,16 @@ def from_path(cls, path: str | Path) -> Image:
raise ValueError("Image file is empty")

if path.stat().st_size > MAX_MISTRAL_IMAGE_SIZE:
raise ValueError(f"Image file size ({path.stat().st_size / 1024 / 1024:.1f}MB) "
f"exceeds Mistral's limit of {MAX_MISTRAL_IMAGE_SIZE / 1024 / 1024:.1f}MB")
media_type: Optional[str] = mimetypes.guess_type(str(path))[0]
raise ValueError(
f"Image file size ({path.stat().st_size / 1024 / 1024:.1f}MB) "
f"exceeds Mistral's limit of {MAX_MISTRAL_IMAGE_SIZE / 1024 / 1024:.1f}MB"
)
media_type: str | None = mimetypes.guess_type(str(path))[0]
if media_type not in VALID_MISTRAL_MIME_TYPES:
raise ValueError(f"Unsupported image format: {media_type}. "
f"Supported formats are: {', '.join(VALID_MISTRAL_MIME_TYPES)}")
raise ValueError(
f"Unsupported image format: {media_type}. "
f"Supported formats are: {', '.join(VALID_MISTRAL_MIME_TYPES)}"
)

data = base64.b64encode(path.read_bytes()).decode("utf-8")
return cls(source=path, media_type=media_type, data=data)
Expand Down Expand Up @@ -246,16 +243,20 @@ def to_mistral(self) -> dict[str, Any]:
"""
# Validate media type
if self.media_type not in VALID_MISTRAL_MIME_TYPES:
raise ValueError(f"Unsupported image format for Mistral: {self.media_type}. "
f"Supported formats are: {', '.join(VALID_MISTRAL_MIME_TYPES)}")
raise ValueError(
f"Unsupported image format for Mistral: {self.media_type}. "
f"Supported formats are: {', '.join(VALID_MISTRAL_MIME_TYPES)}"
)

# For base64 data, validate size
if self.data:
# Calculate size of decoded base64 data
data_size = len(base64.b64decode(self.data))
if data_size > MAX_MISTRAL_IMAGE_SIZE:
raise ValueError(f"Image size ({data_size / 1024 / 1024:.1f}MB) exceeds "
f"Mistral's limit of {MAX_MISTRAL_IMAGE_SIZE / 1024 / 1024:.1f}MB")
raise ValueError(
f"Image size ({data_size / 1024 / 1024:.1f}MB) exceeds "
f"Mistral's limit of {MAX_MISTRAL_IMAGE_SIZE / 1024 / 1024:.1f}MB"
)

if (
isinstance(self.source, str)
Expand All @@ -267,7 +268,7 @@ def to_mistral(self) -> dict[str, Any]:
data = self.data or str(self.source).split(",", 1)[1]
return {
"type": "image_url",
"data": f"data:{self.media_type};base64,{data}"
"data": f"data:{self.media_type};base64,{data}",
}
else:
raise ValueError("Image data is missing for base64 encoding.")
Expand All @@ -276,12 +277,8 @@ def to_mistral(self) -> dict[str, Any]:
class Audio(BaseModel):
    """Represents an audio that can be loaded from a URL or file path."""

    # Where the audio comes from: either a URL string or a filesystem path.
    source: str | Path = Field(
        description="URL or file path of the audio",
    )

    # Base64 payload of the audio; defaults to None and is kept out of the
    # model's repr (repr=False).
    data: str | None = Field(
        default=None,
        description="Base64 encoded audio data",
        repr=False,
    )


class ImageWithCacheControl(Image):
Expand Down Expand Up @@ -325,11 +322,11 @@ def to_anthropic(self) -> dict[str, Any]:


def convert_contents(
contents: Union[str, Image, dict[str, Any], list[Union[str, Image, dict[str, Any]]]],
contents: str | Image | dict[str, Any] | list[str | Image | dict[str, Any]],
mode: Mode,
*, # Make autodetect_images keyword-only since it's unused
_autodetect_images: bool = True, # Prefix with _ to indicate intentionally unused
) -> Union[str, list[dict[str, Any]]]:
) -> str | list[dict[str, Any]]:
"""Convert contents to the appropriate format for the given mode."""
# Handle single string case
if isinstance(contents, str):
Expand Down

0 comments on commit 34728a7

Please sign in to comment.