feat: add support for multimodal #453

Merged 19 commits on Nov 29, 2024
Changes from 10 commits
94 changes: 86 additions & 8 deletions templates/components/engines/python/agent/tools/query_engine.py
@@ -1,28 +1,106 @@
import os
from typing import Any, List, Optional, Sequence

from llama_index.core.base.base_query_engine import BaseQueryEngine
from llama_index.core.base.response.schema import RESPONSE_TYPE, Response
from llama_index.core.multi_modal_llms import MultiModalLLM
from llama_index.core.prompts.base import BasePromptTemplate
from llama_index.core.prompts.default_prompt_selectors import (
    DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
)
from llama_index.core.query_engine import (
    RetrieverQueryEngine,
)
from llama_index.core.query_engine.multi_modal import _get_image_and_text_nodes
from llama_index.core.response_synthesizers import TreeSummarize
from llama_index.core.response_synthesizers.base import QueryTextType
from llama_index.core.schema import (
    ImageNode,
    NodeWithScore,
)
from llama_index.core.tools.query_engine import QueryEngineTool

from app.settings import get_multi_modal_llm


class MultiModalSynthesizer(TreeSummarize):
    """
    A synthesizer that summarizes text nodes and uses a multi-modal LLM to generate a response.
    """

    def __init__(
        self,
        multimodal_model: Optional[MultiModalLLM] = None,
        text_qa_template: Optional[BasePromptTemplate] = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self._multi_modal_llm = multimodal_model
        self._text_qa_template = text_qa_template or DEFAULT_TREE_SUMMARIZE_PROMPT_SEL

    async def asynthesize(
        self,
        query: QueryTextType,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TYPE:
        if self._multi_modal_llm is None:
            raise ValueError("Multi-modal LLM is not set")

        image_nodes, text_nodes = _get_image_and_text_nodes(nodes)

        # Summarize the text nodes first to avoid exceeding the token limit
        text_response = str(await super().asynthesize(query, nodes))

        fmt_prompt = self._text_qa_template.format(
            context_str=text_response,
            query_str=query.query_str,  # type: ignore
        )

        # Pass the summarized text and the retrieved images to the multi-modal LLM
        llm_response = await self._multi_modal_llm.acomplete(
            prompt=fmt_prompt,
            image_documents=[
                image_node.node
                for image_node in image_nodes
                if isinstance(image_node.node, ImageNode)
            ],
        )

        return Response(
            response=str(llm_response),
            source_nodes=nodes,
            metadata={"text_nodes": text_nodes, "image_nodes": image_nodes},
        )


def create_query_engine(index, **kwargs) -> BaseQueryEngine:
    """
    Create a query engine for the given index.

    Args:
        index: The index to create a query engine for.
        kwargs (optional): Additional parameters for the query engine, e.g. similarity_top_k
    """
    top_k = int(os.getenv("TOP_K", 0))
    if top_k != 0 and kwargs.get("filters") is None:
        kwargs["similarity_top_k"] = top_k
    # If the index is a LlamaCloudIndex, use auto_routed retrieval mode
    # for better query results
    if index.__class__.__name__ == "LlamaCloudIndex":
        retrieval_mode = kwargs.get("retrieval_mode")
        if retrieval_mode is None:
            kwargs["retrieval_mode"] = "auto_routed"
        # If a multi-modal LLM is configured, also retrieve image nodes
        # and synthesize the answer with MultiModalSynthesizer
        multi_modal_llm = get_multi_modal_llm()
        if multi_modal_llm:
            kwargs["retrieve_image_nodes"] = True
            return RetrieverQueryEngine(
                retriever=index.as_retriever(**kwargs),
                response_synthesizer=MultiModalSynthesizer(
                    multimodal_model=multi_modal_llm
                ),
            )
    return index.as_query_engine(**kwargs)


def get_query_engine_tool(
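Editor's note: the sketch below is not part of the PR; it illustrates how the new pieces fit together. MultiModalSynthesizer first summarizes the retrieved text nodes with the regular LLM (to stay within the token limit), then sends that summary together with the retrieved image nodes to the multi-modal LLM. The index name, project name, and question are placeholders, and the import paths assume the template's generated app layout.

# Usage sketch (placeholders as noted above; not part of this PR)
import asyncio

from llama_index.indices.managed.llama_cloud import LlamaCloudIndex

from app.settings import init_settings

# create_query_engine is the function defined in the module shown above


async def main():
    init_settings()  # configures Settings.llm and, if available, the multi-modal LLM
    index = LlamaCloudIndex("my-index", project_name="Default")  # placeholder index
    # With a multi-modal LLM configured, this returns a RetrieverQueryEngine that
    # uses MultiModalSynthesizer; otherwise it falls back to index.as_query_engine()
    engine = create_query_engine(index, similarity_top_k=3)
    response = await engine.aquery("What does the chart in the report show?")
    print(response)


asyncio.run(main())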
20 changes: 18 additions & 2 deletions templates/components/settings/python/settings.py
@@ -1,8 +1,17 @@
import os
from typing import Dict, Optional

from llama_index.core.multi_modal_llms import MultiModalLLM
from llama_index.core.settings import Settings

# `Settings` does not support setting `MultiModalLLM`,
# so we use a global variable to store it
_multi_modal_llm: Optional[MultiModalLLM] = None


def get_multi_modal_llm():
    return _multi_modal_llm


def init_settings():
    model_provider = os.getenv("MODEL_PROVIDER")
@@ -60,14 +69,21 @@ def init_openai():
    from llama_index.core.constants import DEFAULT_TEMPERATURE
    from llama_index.embeddings.openai import OpenAIEmbedding
    from llama_index.llms.openai import OpenAI
    from llama_index.multi_modal_llms.openai import OpenAIMultiModal
    from llama_index.multi_modal_llms.openai.utils import GPT4V_MODELS

    max_tokens = os.getenv("LLM_MAX_TOKENS")
    model_name = os.getenv("MODEL", "gpt-4o-mini")
    Settings.llm = OpenAI(
        model=model_name,
        temperature=float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
        max_tokens=int(max_tokens) if max_tokens is not None else None,
    )

    # Set up a multi-modal LLM when the configured model is vision-capable
    if model_name in GPT4V_MODELS:
        global _multi_modal_llm
        _multi_modal_llm = OpenAIMultiModal(model=model_name)

    dimensions = os.getenv("EMBEDDING_DIM")
    Settings.embed_model = OpenAIEmbedding(
        model=os.getenv("EMBEDDING_MODEL", "text-embedding-3-small"),
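Editor's note: a short sketch (not part of the PR) of the settings flow above. Because `Settings` has no slot for a multi-modal LLM, `init_settings()` stores one in the module-level global whenever the configured OpenAI model is a member of GPT4V_MODELS, and other modules fetch it via `get_multi_modal_llm()`. The model name and the `mm_llm.model` attribute access are illustrative assumptions.

# Sketch: exercising the settings flow (assumes OPENAI_API_KEY is set)
import os

os.environ["MODEL_PROVIDER"] = "openai"
os.environ["MODEL"] = "gpt-4o"  # vision-capable, so it is a member of GPT4V_MODELS

from app.settings import get_multi_modal_llm, init_settings

init_settings()

mm_llm = get_multi_modal_llm()
if mm_llm is not None:
    # create_query_engine passes this model to MultiModalSynthesizer
    print("multi-modal LLM configured:", mm_llm.model)  # assumed attribute
else:
    print("text-only model; queries fall back to index.as_query_engine()")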