Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DNM] Complex image type handling #1801

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 98 additions & 87 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dsp.adapters.base_template import Field
from dspy.signatures.signature import Signature
from .base import Adapter
from .image_utils import encode_image, Image
from .image_utils import encode_image, Image, is_image

import ast
import json
Expand All @@ -15,7 +15,7 @@
from pydantic import TypeAdapter
from collections.abc import Mapping
from pydantic.fields import FieldInfo
from typing import Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin
from typing import Dict, List, Literal, NamedTuple, get_args, get_origin

from dspy.adapters.base import Adapter
from ..signatures.field import OutputField
Expand Down Expand Up @@ -51,7 +51,6 @@ def format(self, signature: Signature, demos: list[dict[str, Any]], inputs: dict

prepared_instructions = prepare_instructions(signature)
messages.append({"role": "system", "content": prepared_instructions})

for demo in demos:
messages.append(format_turn(signature, demo, role="user", incomplete=demo in incomplete_demos))
messages.append(format_turn(signature, demo, role="assistant", incomplete=demo in incomplete_demos))
Expand Down Expand Up @@ -161,6 +160,16 @@ def _format_field_value(field_info: FieldInfo, value: Any, assume_text=True) ->
The formatted value of the field, represented as a string.
"""
string_value = None

if field_info.annotation == Image and is_image(value):
print("value: ", value)
value = Image(url=encode_image(value))
# print("field info: ", field_info)
# if not isinstance(value, Image):
# print(f"Coerced image: {value}")
# coerced_image = Image(url=encode_image(value))
# print("post coerce: ", coerced_image)
# string_value = json.dumps(_serialize_for_json(coerced_image), ensure_ascii=False)
if isinstance(value, list) and field_info.annotation is str:
# If the field has no special type requirements, format it as a nice numbered list for the LM.
string_value = format_input_list_field_value(value)
Expand All @@ -171,24 +180,27 @@ def _format_field_value(field_info: FieldInfo, value: Any, assume_text=True) ->

if assume_text:
return string_value
elif (isinstance(value, Image) or field_info.annotation == Image):
# This validation should happen somewhere else
# Safe to import PIL here because it's only imported when an image is actually being formatted
try:
import PIL
except ImportError:
raise ImportError("PIL is required to format images; Run `pip install pillow` to install it.")
image_value = value
if not isinstance(image_value, Image):
if isinstance(image_value, dict) and "url" in image_value:
image_value = image_value["url"]
elif isinstance(image_value, str):
image_value = encode_image(image_value)
elif isinstance(image_value, PIL.Image.Image):
image_value = encode_image(image_value)
assert isinstance(image_value, str)
image_value = Image(url=image_value)
return {"type": "image_url", "image_url": image_value.model_dump()}

# What we actually want is that for any image inside of any arbitrary normal python or pudantic object, when we see it
# it will trigger some sort of escape sequence that we then combine at the end in order to make it a cohesive request to send to OAI
# Hooking too deep into the serialization process is a bad idea, but we need an escape hatch somewhere

# elif (isinstance(value, Image) or field_info.annotation == Image):
# # This validation should happen somewhere else
# # Safe to import PIL here because it's only imported when an image is actually being formatted
# try:
# import PIL
# except ImportError:
# raise ImportError("PIL is required to format images; Run `pip install pillow` to install it.")
# image_value = value
# if not isinstance(image_value, Image):
# if isinstance(image_value, dict) and "url" in image_value:
# image_value = image_value["url"]
# elif isinstance(image_value, str) or isinstance(image_value, PIL.Image.Image):
# image_value = encode_image(image_value)
# assert isinstance(image_value, str)
# image_value = Image(url=image_value)
# return {"type": "image_url", "image_url": image_value.model_dump()}
else:
return {"type": "text", "text": string_value}

Expand Down Expand Up @@ -242,8 +254,7 @@ def parse_value(value, annotation):
return TypeAdapter(annotation).validate_python(parsed_value)


def format_turn(signature, values, role, incomplete=False):
fields_to_collapse = []
def format_turn(signature, values, role, incomplete=False):
"""
Constructs a new message ("turn") to append to a chat thread. The message is carefully formatted
so that it can instruct an LLM to generate responses conforming to the specified DSPy signature.
Expand All @@ -259,77 +270,77 @@ def format_turn(signature, values, role, incomplete=False):
A chat message that can be appended to a chat thread. The message contains two string fields:
``role`` ("user" or "assistant") and ``content`` (the message text).
"""
content = []

if role == "user":
fields: Dict[str, FieldInfo] = signature.input_fields
if incomplete:
fields_to_collapse.append({"type": "text", "text": "This is an example of the task, though some input or output fields are not supplied."})
fields = signature.input_fields
message_prefix = "This is an example of the task, though some input or output fields are not supplied." if incomplete else ""
else:
fields: Dict[str, FieldInfo] = signature.output_fields
# Add the built-in field indicating that the chat turn has been completed
fields[BuiltInCompletedOutputFieldInfo.name] = BuiltInCompletedOutputFieldInfo.info
# Add the completed field for the assistant turn
fields = {**signature.output_fields, BuiltInCompletedOutputFieldInfo.name: BuiltInCompletedOutputFieldInfo.info}
values = {**values, BuiltInCompletedOutputFieldInfo.name: ""}
field_names: KeysView = fields.keys()
if not incomplete:
if not set(values).issuperset(set(field_names)):
raise ValueError(f"Expected {field_names} but got {values.keys()}")

fields_to_collapse.extend(format_fields(
fields_with_values={
FieldInfoWithName(name=field_name, info=field_info): values.get(
field_name, "Not supplied for this particular example."
)
for field_name, field_info in fields.items()
},
assume_text=False
))
message_prefix = ""

if role == "user":
output_fields = list(signature.output_fields.keys())
def type_info(v):
return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \
if v.annotation is not str else ""
if output_fields:
fields_to_collapse.append({
"type": "text",
"text": "Respond with the corresponding output fields, starting with the field "
+ ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items())
+ ", and then ending with the marker for `[[ ## completed ## ]]`."
})

# flatmap the list if any items are lists otherwise keep the item
flattened_list = list(chain.from_iterable(
item if isinstance(item, list) else [item] for item in fields_to_collapse
))
if not incomplete and not set(values).issuperset(fields.keys()):
raise ValueError(f"Expected {fields.keys()} but got {values.keys()}")

if all(message.get("type", None) == "text" for message in flattened_list):
content = "\n\n".join(message.get("text") for message in flattened_list)
return {"role": role, "content": content}

# Collapse all consecutive text messages into a single message.
collapsed_messages = []
for item in flattened_list:
# First item is always added
if not collapsed_messages:
collapsed_messages.append(item)
continue

# If current item is image, add to collapsed_messages
if item.get("type") == "image_url":
if collapsed_messages[-1].get("type") == "text":
collapsed_messages[-1]["text"] += "\n"
collapsed_messages.append(item)
# If previous item is text and current item is text, append to previous item
elif collapsed_messages[-1].get("type") == "text":
collapsed_messages[-1]["text"] += "\n\n" + item["text"]
# If previous item is not text(aka image), add current item as a new item
else:
item["text"] = "\n\n" + item["text"]
collapsed_messages.append(item)
messages = []
if message_prefix:
messages.append({"type": "text", "text": message_prefix})

return {"role": role, "content": collapsed_messages}
field_messages = format_fields(
{FieldInfoWithName(name=k, info=v): values.get(k, "Not supplied for this particular example.")
for k, v in fields.items()},
assume_text=False
)
messages.extend(field_messages)

# Add output field instructions for user messages
if role == "user" and signature.output_fields:
type_info = lambda v: f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" if v.annotation is not str else ""
field_instructions = "Respond with the corresponding output fields, starting with the field " + \
", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items()) + \
", and then ending with the marker for `[[ ## completed ## ]]`."
messages.append({"type": "text", "text": field_instructions})

# Process messages to handle image tags and collapse text
processed_messages = process_messages(messages)

if all(msg.get("type") == "text" for msg in processed_messages):
return {"role": role, "content": "\n\n".join(msg["text"] for msg in processed_messages)}
return {"role": role, "content": processed_messages}

def process_messages(messages):
"""Process messages to handle image tags and collapse consecutive text messages."""
processed = []
current_text = []

for msg in flatten_messages(messages):
if msg["type"] == "text":
# Handle image tags in text
parts = re.split(r'(<DSPY_IMAGE_START>.*?<DSPY_IMAGE_END>)', msg["text"])
for part in parts:
if match := re.match(r'<DSPY_IMAGE_START>(.*?)<DSPY_IMAGE_END>', part):
if current_text:
processed.append({"type": "text", "text": "\n\n".join(current_text)})
current_text = []
processed.append({"type": "image_url", "image_url": {"url": match.group(1)}})
elif part.strip():
current_text.append(part)
else:
if current_text:
processed.append({"type": "text", "text": "\n\n".join(current_text)})
current_text = []
processed.append(msg)

if current_text:
processed.append({"type": "text", "text": "\n\n".join(current_text)})

return processed

def flatten_messages(messages):
"""Flatten nested message lists."""
return list(chain.from_iterable(
item if isinstance(item, list) else [item] for item in messages
))

def get_annotation_name(annotation):
origin = get_origin(annotation)
Expand Down
21 changes: 13 additions & 8 deletions dspy/adapters/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import requests
from urllib.parse import urlparse
import pydantic

from pydantic import model_serializer
try:
from PIL import Image as PILImage
PIL_AVAILABLE = True
Expand Down Expand Up @@ -41,8 +41,11 @@ def from_file(cls, file_path: str):

@classmethod
def from_PIL(cls, pil_image):
import PIL
return cls(url=encode_image(PIL.Image.open(pil_image)))
return cls(url=encode_image(pil_image))

@model_serializer()
def serialize_model(self):
return "<DSPY_IMAGE_START>" + self.url + "<DSPY_IMAGE_END>"

def is_url(string: str) -> bool:
"""Check if a string is a valid URL."""
Expand Down Expand Up @@ -85,6 +88,7 @@ def encode_image(image: Union[str, bytes, 'PILImage.Image', dict], download_imag
return image
else:
# Unsupported string format
print(f"Unsupported image string: {image}")
raise ValueError(f"Unsupported image string: {image}")
elif PIL_AVAILABLE and isinstance(image, PILImage.Image):
# PIL Image
Expand All @@ -93,11 +97,12 @@ def encode_image(image: Union[str, bytes, 'PILImage.Image', dict], download_imag
# Raw bytes
if not PIL_AVAILABLE:
raise ImportError("Pillow is required to process image bytes.")
img = Image.open(io.BytesIO(image))
img = PILImage.open(io.BytesIO(image))
return _encode_pil_image(img)
elif isinstance(image, Image):
return image.url
else:
print(f"Unsupported image type: {type(image)}")
raise ValueError(f"Unsupported image type: {type(image)}")

def _encode_image_from_file(file_path: str) -> str:
Expand All @@ -121,7 +126,7 @@ def _encode_image_from_url(image_url: str) -> str:
encoded_image = base64.b64encode(response.content).decode('utf-8')
return f"data:image/{file_extension};base64,{encoded_image}"

def _encode_pil_image(image: 'Image.Image') -> str:
def _encode_pil_image(image: 'PILImage.Image') -> str:
"""Encode a PIL Image object to a base64 data URI."""
buffered = io.BytesIO()
file_extension = (image.format or 'PNG').lower()
Expand All @@ -136,10 +141,10 @@ def _get_file_extension(path_or_url: str) -> str:

def is_image(obj) -> bool:
"""Check if the object is an image or a valid image reference."""
if PIL_AVAILABLE and isinstance(obj, Image.Image):
return True
if isinstance(obj, (bytes, bytearray)):
if PIL_AVAILABLE and isinstance(obj, PILImage.Image):
return True
# if isinstance(obj, (bytes, bytearray)):
# return True
if isinstance(obj, str):
if obj.startswith("data:image/"):
return True
Expand Down
41 changes: 28 additions & 13 deletions dspy/predict/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from dspy.primitives.program import Module
from dspy.signatures.signature import ensure_signature, signature_to_template
from dspy.utils.callback import with_callbacks
from dspy.adapters.image_utils import Image

@lru_cache(maxsize=None)
def warn_once(msg: str):
Expand Down Expand Up @@ -47,11 +46,7 @@ def dump_state(self, save_verbose=None):

for field in demo:
# FIXME: Saving BaseModels as strings in examples doesn't matter because you never re-access as an object
# It does matter for images
if isinstance(demo[field], Image):
demo[field] = demo[field].model_dump()
elif isinstance(demo[field], BaseModel):
demo[field] = demo[field].model_dump_json()
demo[field] = serialize_object(demo[field])

state["demos"].append(demo)

Expand Down Expand Up @@ -89,13 +84,13 @@ def load_state(self, state, use_legacy_loading=False):
setattr(self, name, value)

# FIXME: Images are getting special treatment, but all basemodels initialized from json should be converted back to objects
for demo in self.demos:
for field in demo:
if isinstance(demo[field], dict) and "url" in demo[field]:
url = demo[field]["url"]
if not isinstance(url, str):
raise ValueError(f"Image URL must be a string, got {type(url)}")
demo[field] = Image(url=url)
# for demo in self.demos:
# for field in demo:
# if isinstance(demo[field], dict) and "url" in demo[field]:
# url = demo[field]["url"]
# if not isinstance(url, str):
# raise ValueError(f"Image URL must be a string, got {type(url)}")
# demo[field] = Image(url=url)

self.signature = self.signature.load_state(state["signature"])

Expand Down Expand Up @@ -296,6 +291,26 @@ def v2_5_generate(lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
lm, lm_kwargs=lm_kwargs, signature=signature, demos=demos, inputs=inputs, _parse_values=_parse_values
)

def serialize_object(obj):
"""
Recursively serialize a given object into a JSON-compatible format.
Supports Pydantic models, lists, dicts, and primitive types.
"""
if isinstance(obj, BaseModel):
# Use model_dump to convert the model into a JSON-serializable dict
return obj.model_dump_json()
elif isinstance(obj, list):
# Recursively process each item in the list
return [serialize_object(item) for item in obj]
elif isinstance(obj, tuple):
return tuple(serialize_object(item) for item in obj)
elif isinstance(obj, dict):
# Recursively process each key-value pair in the dict
return {key: serialize_object(value) for key, value in obj.items()}
else:
# Assume the object is already JSON-compatible (e.g., int, str, float)
return obj

# TODO: get some defaults during init from the context window?
# # TODO: FIXME: Hmm, I guess expected behavior is that contexts can
# affect execution. Well, we need to determine whether context dominates, __init__ demoninates, or forward dominates.
Expand Down
2 changes: 2 additions & 0 deletions examples/vlm/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
docs/
.byaldi/
Loading