feat: add support for multimodal #453

Merged: 19 commits, Nov 29, 2024
Changes from 9 commits
125 changes: 118 additions & 7 deletions templates/components/engines/python/agent/tools/query_engine.py
@@ -1,27 +1,138 @@
import os
from typing import Optional
from typing import Any, List, Optional, Sequence

from llama_index.core.base.base_query_engine import BaseQueryEngine
from llama_index.core.base.response.schema import RESPONSE_TYPE, Response
from llama_index.core.multi_modal_llms import MultiModalLLM
from llama_index.core.prompts.base import BasePromptTemplate
from llama_index.core.prompts.default_prompt_selectors import (
    DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
)
from llama_index.core.query_engine import (
    RetrieverQueryEngine,
)
from llama_index.core.query_engine.multi_modal import _get_image_and_text_nodes
from llama_index.core.response_synthesizers import TreeSummarize
from llama_index.core.response_synthesizers.base import QueryTextType
from llama_index.core.schema import (
    ImageNode,
    NodeWithScore,
)
from llama_index.core.tools.query_engine import QueryEngineTool

from app.settings import multi_modal_llm

def create_query_engine(index, **kwargs):

class MultiModalSynthesizer(TreeSummarize):
"""
A synthesizer that summarizes text nodes and uses a multi-modal LLM to generate a response.
"""

def __init__(
self,
multimodal_model: Optional[MultiModalLLM] = None,
text_qa_template: Optional[BasePromptTemplate] = None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self._multi_modal_llm = multimodal_model
self._text_qa_template = text_qa_template or DEFAULT_TREE_SUMMARIZE_PROMPT_SEL

    def synthesize(
        self,
        query: QueryTextType,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TYPE:
        image_nodes, text_nodes = _get_image_and_text_nodes(nodes)

        # Summarize the text nodes to avoid exceeding the token limit
        text_response = str(super().synthesize(query, nodes))

        fmt_prompt = self._text_qa_template.format(
            context_str=text_response,
            query_str=query.query_str,  # type: ignore
        )

        llm_response = self._multi_modal_llm.complete(
            prompt=fmt_prompt,
            image_documents=[
                image_node.node
                for image_node in image_nodes
                if isinstance(image_node.node, ImageNode)
            ],
        )

        return Response(
            response=str(llm_response),
            source_nodes=nodes,
            metadata={"text_nodes": text_nodes, "image_nodes": image_nodes},
        )

    async def asynthesize(
        self,
        query: QueryTextType,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TYPE:
        image_nodes, text_nodes = _get_image_and_text_nodes(nodes)

        # Summarize the text nodes to avoid exceeding the token limit
        text_response = str(await super().asynthesize(query, nodes))

        fmt_prompt = self._text_qa_template.format(
            context_str=text_response,
            query_str=query.query_str,  # type: ignore
        )

        llm_response = await self._multi_modal_llm.acomplete(
            prompt=fmt_prompt,
            image_documents=[
                image_node.node
                for image_node in image_nodes
                if isinstance(image_node.node, ImageNode)
            ],
        )

        return Response(
            response=str(llm_response),
            source_nodes=nodes,
            metadata={"text_nodes": text_nodes, "image_nodes": image_nodes},
        )
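
For reviewers, a standalone sketch of how the new synthesizer behaves on mixed retrieval results. It is illustrative only and not part of this diff: the nodes and query below are made up, and it assumes `init_settings()` has already registered a multimodal model.

```python
# Illustrative only (not part of this PR): nodes would normally come from a retriever.
from llama_index.core.schema import ImageNode, NodeWithScore, QueryBundle, TextNode

nodes = [
    NodeWithScore(node=TextNode(text="The chart shows Q3 revenue."), score=0.9),
    NodeWithScore(node=ImageNode(image_url="https://example.com/q3-chart.png"), score=0.8),
]
# Assumes a multimodal model was stored in the ContextVar by init_settings().
synthesizer = MultiModalSynthesizer(multimodal_model=multi_modal_llm.get())
response = synthesizer.synthesize(
    QueryBundle(query_str="What does the chart show?"), nodes=nodes
)
print(response.response)  # answer grounded in the summarized text plus the image nodes
```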


def create_query_engine(index, **kwargs) -> BaseQueryEngine:
"""
Create a query engine for the given index.

Args:
index: The index to create a query engine for.
params (optional): Additional parameters for the query engine, e.g: similarity_top_k
"""

top_k = int(os.getenv("TOP_K", 0))
if top_k != 0 and kwargs.get("filters") is None:
kwargs["similarity_top_k"] = top_k
# If index is index is LlamaCloudIndex
# use auto_routed mode for better query results
if (
index.__class__.__name__ == "LlamaCloudIndex"
and kwargs.get("auto_routed") is None
):
kwargs["auto_routed"] = True
if index.__class__.__name__ == "LlamaCloudIndex":
retrieval_mode = kwargs.get("retrieval_mode")
if retrieval_mode is None:
kwargs["retrieval_mode"] = "auto_routed"
mm_model = multi_modal_llm.get()
if mm_model:
kwargs["retrieve_image_nodes"] = True
print("Using multi-modal model")
return RetrieverQueryEngine(
retriever=index.as_retriever(**kwargs),
response_synthesizer=MultiModalSynthesizer(
multimodal_model=mm_model
),
)

return index.as_query_engine(**kwargs)
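
For context, a minimal usage sketch (not part of this diff; the index construction and tool metadata are placeholders) showing how the engine returned here could be wrapped in the `QueryEngineTool` that this module imports:

```python
# Hypothetical wiring (not part of this PR).
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.tools import QueryEngineTool, ToolMetadata

index = VectorStoreIndex.from_documents([Document(text="Example content.")])
engine = create_query_engine(index, similarity_top_k=3)
tool = QueryEngineTool(
    query_engine=engine,
    metadata=ToolMetadata(
        name="query_index",
        description="Answer questions about the indexed documents (text and images).",
    ),
)
```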


16 changes: 14 additions & 2 deletions templates/components/settings/python/settings.py
@@ -1,8 +1,14 @@
import os
from typing import Dict
from contextvars import ContextVar
from typing import Dict, Optional

from llama_index.core.multi_modal_llms import MultiModalLLM
from llama_index.core.settings import Settings

multi_modal_llm: ContextVar[Optional[MultiModalLLM]] = ContextVar(
    "multi_modal_llm", default=None
)


def init_settings():
    model_provider = os.getenv("MODEL_PROVIDER")
@@ -60,14 +66,20 @@ def init_openai():
    from llama_index.core.constants import DEFAULT_TEMPERATURE
    from llama_index.embeddings.openai import OpenAIEmbedding
    from llama_index.llms.openai import OpenAI
    from llama_index.multi_modal_llms.openai import OpenAIMultiModal
    from llama_index.multi_modal_llms.openai.utils import GPT4V_MODELS

    max_tokens = os.getenv("LLM_MAX_TOKENS")
    model_name = os.getenv("MODEL", "gpt-4o-mini")
    Settings.llm = OpenAI(
        model=os.getenv("MODEL", "gpt-4o-mini"),
        model=model_name,
        temperature=float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
        max_tokens=int(max_tokens) if max_tokens is not None else None,
    )

    if model_name in GPT4V_MODELS:
        multi_modal_llm.set(OpenAIMultiModal(model=model_name))

    dimensions = os.getenv("EMBEDDING_DIM")
    Settings.embed_model = OpenAIEmbedding(
        model=os.getenv("EMBEDDING_MODEL", "text-embedding-3-small"),
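
For completeness, a hedged sketch (not in this diff) of how downstream code can read back the multimodal model that `init_openai` registers in the ContextVar:

```python
# Hypothetical consumer (not part of this PR): reads the ContextVar populated by init_openai().
from app.settings import init_settings, multi_modal_llm

init_settings()  # e.g. with MODEL_PROVIDER=openai and MODEL=gpt-4o in the environment
mm_model = multi_modal_llm.get()  # None when the configured model is not in GPT4V_MODELS
if mm_model is not None:
    print("Multimodal LLM configured:", type(mm_model).__name__)
```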