Skip to content

Commit

Permalink
💬 fix image history for non vision
Browse files Browse the repository at this point in the history
  • Loading branch information
shroominic committed Dec 28, 2023
1 parent 5551b30 commit f2c1323
Showing 1 changed file with 29 additions and 1 deletion.
30 changes: 29 additions & 1 deletion src/funcchain/chain/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def create_chain(
)

# for vision models
images = _handle_images(llm, input_kwargs)
images = _handle_images(llm, memory, input_kwargs)

# create prompts
instruction_prompt = create_instruction_prompt(
Expand Down Expand Up @@ -214,6 +214,7 @@ def _crop_large_inputs(

def _handle_images(
llm: BaseChatModel,
memory: BaseChatMessageHistory,
input_kwargs: dict[str, str],
) -> list[Image.Image]:
"""
Expand All @@ -226,6 +227,9 @@ def _handle_images(
del input_kwargs[k]
elif images:
raise RuntimeError("Images as input are only supported for vision models.")
elif _history_contains_images(memory):
print("Warning: Images in chat history are ignored for non-vision models.")
memory.messages = _clear_images_from_history(memory.messages)

return images

Expand Down Expand Up @@ -291,3 +295,27 @@ def _add_custom_callbacks(
llm.callbacks = callbacks

return llm


def _history_contains_images(history: BaseChatMessageHistory) -> bool:
"""
Check if the chat history contains images.
"""
for message in history.messages:
if isinstance(message.content, list):
for content in message.content:
if isinstance(content, dict) and content.get("type") == "image_url":
return True
return False


def _clear_images_from_history(history: list[BaseMessage]) -> list[BaseMessage]:
"""
Remove images from the chat history.
"""
for message in history:
if isinstance(message.content, list):
for content in message.content:
if isinstance(content, dict) and content.get("type") == "image_url":
message.content.remove(content)
return history

0 comments on commit f2c1323

Please sign in to comment.