
feat: support multimodal (#1045)
jxnl authored Oct 6, 2024
1 parent 5e6ba05 commit d968970
Showing 8 changed files with 389 additions and 8 deletions.
6 changes: 0 additions & 6 deletions .github/workflows/test.yml
@@ -43,12 +43,6 @@ jobs:
           ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
           COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}

-      - name: Run Gemini Tests
-        if: matrix.python-version != '3.9'
-        run: poetry run pytest tests/llm/test_gemini
-        env:
-          GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
-
       - name: Generate coverage report
         if: matrix.python-version == '3.11'
         run: |
39 changes: 39 additions & 0 deletions docs/concepts/multimodal.md
@@ -0,0 +1,39 @@
# Multimodal

Instructor supports multimodal interactions through helper classes that are converted automatically to the correct format for each provider, so you can mix text and images in your prompts. This functionality is implemented in the `multimodal.py` module and works the same way across the supported AI providers.

## `Image`

The core of multimodal support in Instructor is the `Image` class, which represents an image loaded from a URL or file path and knows how to convert itself to the format each AI provider expects.

Anthropic and OpenAI use different request formats for images. The `Image` class abstracts away these differences, so you can work with a single unified interface.
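
To make the difference concrete, here is a sketch of the two payload shapes, taken from the conversion code in this commit (the file name is illustrative):

```python
import instructor

image = instructor.Image.from_path("photo.jpg")  # hypothetical local JPEG

# OpenAI's Vision API expects an image_url part; local files become data URIs:
print(image.to_openai())
# {'type': 'image_url', 'image_url': {'url': 'data:image/jpeg;base64,...'}}

# Anthropic expects a nested source object with an explicit media_type:
print(image.to_anthropic())
# {'type': 'image', 'source': {'type': 'base64', 'media_type': 'image/jpeg', 'data': '...'}}
```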

### Usage

You can create an `Image` instance from a URL or file path with the `from_url` or `from_path` methods. Images loaded from a file path are base64-encoded and embedded in the API request; URLs are passed through directly when the provider supports them (OpenAI) and downloaded and encoded when it does not (Anthropic).

```python
import instructor
import openai
from pydantic import BaseModel

image1 = instructor.Image.from_url("https://example.com/image.jpg")
image2 = instructor.Image.from_path("path/to/image.jpg")


class ImageAnalyzer(BaseModel):
    # Example response model; define whatever fields you want extracted.
    description: str


client = instructor.from_openai(openai.OpenAI())

response = client.chat.completions.create(
    model="gpt-4o-mini",
    response_model=ImageAnalyzer,
    messages=[
        {
            "role": "user",
            "content": ["What is in these two images?", image1, image2],
        }
    ],
)
```

The `Image` class takes care of the necessary conversions and formatting, ensuring that your code remains clean and provider-agnostic. This flexibility is particularly valuable when you're experimenting with different models or when you need to switch providers based on specific project requirements.

By leveraging Instructor's multimodal capabilities, you can focus on building your application logic without worrying about the intricacies of each provider's image handling format. This not only saves development time but also makes your code more maintainable and adaptable to future changes in AI provider APIs.
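
Switching providers is then mostly a matter of swapping the client. A minimal sketch with Anthropic, reusing `image1` and `ImageAnalyzer` from the example above (this assumes the `anthropic` package and an API key; the model name is illustrative):

```python
import anthropic
import instructor

client = instructor.from_anthropic(anthropic.Anthropic())

response = client.chat.completions.create(
    model="claude-3-5-sonnet-20240620",
    max_tokens=1024,
    response_model=ImageAnalyzer,
    messages=[
        {"role": "user", "content": ["Describe this image", image1]},
    ],
)
```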
2 changes: 2 additions & 0 deletions instructor/__init__.py
@@ -3,6 +3,7 @@
 from .mode import Mode
 from .process_response import handle_response_model
 from .distil import FinetuneFormat, Instructions
+from .multimodal import Image
 from .dsl import (
     CitationMixin,
     Maybe,
@@ -24,6 +25,7 @@

 __all__ = [
     "Instructor",
+    "Image",
     "from_openai",
     "from_litellm",
     "AsyncInstructor",
117 changes: 117 additions & 0 deletions instructor/multimodal.py
@@ -0,0 +1,117 @@
from __future__ import annotations
import base64
from typing import Any, Union
from pathlib import Path
from pydantic import BaseModel, Field
from .mode import Mode


class Image(BaseModel):
    """Represents an image that can be loaded from a URL or file path."""

    source: Union[str, Path] = Field(..., description="URL or file path of the image")  # noqa: UP007
    media_type: str = Field(..., description="MIME type of the image")
    data: Union[str, None] = Field(  # noqa: UP007
        None, description="Base64 encoded image data", repr=False
    )

    @classmethod
    def from_url(cls, url: str) -> Image:
        """Create an Image instance from a URL."""
        return cls(source=url, media_type="image/jpeg", data=None)

    @classmethod
    def from_path(cls, path: str | Path) -> Image:
        """Create an Image instance from a file path."""
        path = Path(path)
        if not path.is_file():
            raise FileNotFoundError(f"Image file not found: {path}")

        suffix = path.suffix.lower().lstrip(".")
        if suffix not in ["jpeg", "jpg", "png"]:
            raise ValueError(f"Unsupported image format: {suffix}")

        if path.stat().st_size == 0:
            raise ValueError("Image file is empty")

        media_type = "image/jpeg" if suffix in ["jpeg", "jpg"] else "image/png"
        data = base64.b64encode(path.read_bytes()).decode("utf-8")
        return cls(source=str(path), media_type=media_type, data=data)

    def to_anthropic(self) -> dict[str, Any]:
        """Convert the Image instance to Anthropic's API format."""
        if isinstance(self.source, str) and self.source.startswith(
            ("http://", "https://")
        ):
            import requests

            response = requests.get(self.source)
            response.raise_for_status()
            self.data = base64.b64encode(response.content).decode("utf-8")
            self.media_type = response.headers.get("Content-Type", "image/jpeg")

        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": self.media_type,
                "data": self.data,
            },
        }

    def to_openai(self) -> dict[str, Any]:
        """Convert the Image instance to OpenAI's Vision API format."""
        if isinstance(self.source, str) and self.source.startswith(
            ("http://", "https://")
        ):
            return {"type": "image_url", "image_url": {"url": self.source}}
        elif self.data:
            return {
                "type": "image_url",
                "image_url": {"url": f"data:{self.media_type};base64,{self.data}"},
            }
        else:
            raise ValueError("Image data is missing for base64 encoding.")


def convert_contents(
    contents: Union[list[Union[str, Image]], str, Image],  # noqa: UP007
    mode: Mode,
) -> Union[str, list[dict[str, Any]]]:  # noqa: UP007
    """Convert content items to the appropriate format based on the specified mode."""
    if isinstance(contents, str):
        return contents
    if isinstance(contents, Image):
        contents = [contents]

    converted_contents: list[dict[str, Union[str, Image]]] = []  # noqa: UP007
    for content in contents:
        if isinstance(content, str):
            converted_contents.append({"type": "text", "text": content})
        elif isinstance(content, Image):
            if mode in {Mode.ANTHROPIC_JSON, Mode.ANTHROPIC_TOOLS}:
                converted_contents.append(content.to_anthropic())
            elif mode in {Mode.GEMINI_JSON, Mode.GEMINI_TOOLS}:
                raise NotImplementedError("Gemini is not supported yet")
            else:
                converted_contents.append(content.to_openai())
        else:
            raise ValueError(f"Unsupported content type: {type(content)}")
    return converted_contents


def convert_messages(
    messages: list[dict[str, Union[str, list[Union[str, Image]]]]],  # noqa: UP007
    mode: Mode,
) -> list[dict[str, Any]]:
    """Convert messages to the appropriate format based on the specified mode."""
    converted_messages = []
    for message in messages:
        role = message["role"]
        content = message["content"]
        if isinstance(content, str):
            converted_messages.append({"role": role, "content": content})  # type: ignore
        else:
            converted_content = convert_contents(content, mode)
            converted_messages.append({"role": role, "content": converted_content})  # type: ignore
    return converted_messages  # type: ignore
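
For illustration, a short sketch of how these helpers behave, based on the code above (the image path is hypothetical):

```python
from instructor.mode import Mode
from instructor.multimodal import Image, convert_messages

messages = [
    {
        "role": "user",
        "content": ["What is in this image?", Image.from_path("photo.jpg")],
    }
]

# Plain string content passes through untouched; mixed lists become
# provider-specific parts (OpenAI format for Mode.TOOLS):
converted = convert_messages(messages, Mode.TOOLS)
# [{'role': 'user', 'content': [
#     {'type': 'text', 'text': 'What is in this image?'},
#     {'type': 'image_url', 'image_url': {'url': 'data:image/jpeg;base64,...'}},
# ]}]
```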
5 changes: 4 additions & 1 deletion instructor/process_response.py
@@ -21,7 +21,7 @@
 from instructor.dsl.simple_type import AdapterBase, ModelAdapter, is_simple_type
 from instructor.function_calls import OpenAISchema, openai_schema
 from instructor.utils import merge_consecutive_messages
-
+from instructor.multimodal import convert_messages

 logger = logging.getLogger("instructor")

@@ -664,6 +664,9 @@ def handle_response_model(
     else:
         raise ValueError(f"Invalid patch mode: {mode}")

+    if "messages" in new_kwargs:
+        new_kwargs["messages"] = convert_messages(new_kwargs["messages"], mode)
+
     logger.debug(
         f"Instructor Request: {mode.value=}, {response_model=}, {new_kwargs=}",
         extra={
2 changes: 1 addition & 1 deletion mkdocs.yml
@@ -207,7 +207,7 @@ nav:
       - Advanced Usage of Mistral Large: 'hub/mistral.md'
       - Generating Knowledge Graphs with AI: 'hub/knowledge_graph.md'
       - Extracting Relevant Clips from YouTube Videos: "hub/youtube_clips.md"
-      - Tutorial: Building Knowledge Graphs: 'tutorials/5-knowledge-graphs.ipynb'
+      - Building Knowledge Graphs: 'tutorials/5-knowledge-graphs.ipynb'
   - CLI Reference:
       - "CLI Reference": "cli/index.md"
       - "Finetuning GPT-3.5": "cli/finetune.md"
44 changes: 44 additions & 0 deletions tests/llm/test_openai/test_multimodal.py
@@ -0,0 +1,44 @@
import pytest
from instructor.multimodal import Image
import instructor
from pydantic import Field, BaseModel
from itertools import product
from .util import models, modes


@pytest.mark.parametrize("model, mode", product(models, modes))
def test_multimodal_image_description(model, mode, client):
    client = instructor.patch(client, mode=mode)

    class ImageDescription(BaseModel):
        objects: list[str] = Field(..., description="The objects in the image")
        scene: str = Field(..., description="The scene of the image")
        colors: list[str] = Field(..., description="The colors in the image")

    response = client.chat.completions.create(
        model=model,  # Ensure this is a vision-capable model
        response_model=ImageDescription,
        messages=[
            {
                "role": "system",
                "content": "You are a helpful assistant that can describe images",
            },
            {
                "role": "user",
                "content": [
                    "What is this?",
                    Image.from_url(
                        "https://pbs.twimg.com/profile_images/1816950591857233920/ZBxrWCbX_400x400.jpg"
                    ),
                ],
            },
        ],
    )

    # Assertions to validate the response
    assert isinstance(response, ImageDescription)
    assert len(response.objects) > 0
    assert response.scene != ""
    assert len(response.colors) > 0

    # Additional assertions can be added based on expected content of the sample image