From f9a057dddee618747d95e9cbe3dddbc2359b634f Mon Sep 17 00:00:00 2001
From: Huu Le <39040748+leehuwuj@users.noreply.github.com>
Date: Fri, 29 Nov 2024 18:02:14 +0700
Subject: [PATCH] feat: add support for multimodal indexes (#453)

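When the configured OpenAI model supports image input, `init_openai()` registers
a multi-modal LLM and `create_query_engine()` synthesizes answers with the new
`MultiModalSynthesizer`, combining the summarized text nodes with the retrieved
image nodes. For a `LlamaCloudIndex`, `retrieval_mode="auto_routed"` and
`retrieve_image_nodes=True` are set automatically. A minimal usage sketch; the
index name, project name, and the generated-project import path are placeholders
for illustration, not part of this patch:

```python
import os

from llama_index.indices.managed.llama_cloud import LlamaCloudIndex

from app.settings import get_multi_modal_llm, init_settings
# generated-project path of this tool module; adjust to your template layout
from app.engine.tools.query_engine import create_query_engine

os.environ["MODEL_PROVIDER"] = "openai"
os.environ["MODEL"] = "gpt-4o"  # listed in GPT4V_MODELS, so a multi-modal LLM is registered
init_settings()
assert get_multi_modal_llm() is not None

# Placeholder names; credentials are read from LLAMA_CLOUD_API_KEY
index = LlamaCloudIndex(name="my-multimodal-index", project_name="Default")

# auto_routed retrieval + image node retrieval + MultiModalSynthesizer
query_engine = create_query_engine(index)
print(query_engine.query("What does the chart in the report show?"))
```

Models outside `GPT4V_MODELS` keep the previous text-only behavior.
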
---------
Co-authored-by: thucpn <thucsh2@gmail.com>
Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de>
---
 .changeset/blue-hornets-boil.md               |   5 +
 .../python/agent/tools/query_engine.py        | 148 +++++++++++++++++-
 .../components/settings/python/settings.py    |  20 ++-
 3 files changed, 164 insertions(+), 9 deletions(-)
 create mode 100644 .changeset/blue-hornets-boil.md

diff --git a/.changeset/blue-hornets-boil.md b/.changeset/blue-hornets-boil.md
new file mode 100644
index 000000000..e8c2928d4
--- /dev/null
+++ b/.changeset/blue-hornets-boil.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Add support for multimodal indexes (e.g. from LlamaCloud)
diff --git a/templates/components/engines/python/agent/tools/query_engine.py b/templates/components/engines/python/agent/tools/query_engine.py
index e78ae0442..396fb1d6e 100644
--- a/templates/components/engines/python/agent/tools/query_engine.py
+++ b/templates/components/engines/python/agent/tools/query_engine.py
@@ -1,10 +1,27 @@
 import os
-from typing import Optional
+from typing import Any, Dict, List, Optional, Sequence
 
+from llama_index.core import get_response_synthesizer
+from llama_index.core.base.base_query_engine import BaseQueryEngine
+from llama_index.core.base.response.schema import RESPONSE_TYPE, Response
+from llama_index.core.multi_modal_llms import MultiModalLLM
+from llama_index.core.prompts.base import BasePromptTemplate
+from llama_index.core.prompts.default_prompt_selectors import (
+    DEFAULT_TEXT_QA_PROMPT_SEL,
+)
+from llama_index.core.query_engine.multi_modal import _get_image_and_text_nodes
+from llama_index.core.response_synthesizers.base import BaseSynthesizer, QueryTextType
+from llama_index.core.schema import (
+    ImageNode,
+    NodeWithScore,
+)
 from llama_index.core.tools.query_engine import QueryEngineTool
+from llama_index.core.types import RESPONSE_TEXT_TYPE
 
+from app.settings import get_multi_modal_llm
 
-def create_query_engine(index, **kwargs):
+
+def create_query_engine(index, **kwargs) -> BaseQueryEngine:
     """
     Create a query engine for the given index.
 
@@ -12,16 +29,23 @@ def create_query_engine(index, **kwargs):
         index: The index to create a query engine for.
         params (optional): Additional parameters for the query engine, e.g: similarity_top_k
     """
+
     top_k = int(os.getenv("TOP_K", 0))
     if top_k != 0 and kwargs.get("filters") is None:
         kwargs["similarity_top_k"] = top_k
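+    # If a multi-modal LLM is configured (see settings), synthesize the final
+    # response from both the retrieved text and image nodes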
+    multimodal_llm = get_multi_modal_llm()
+    if multimodal_llm:
+        kwargs["response_synthesizer"] = MultiModalSynthesizer(
+            multimodal_model=multimodal_llm,
+        )
+
     # If the index is a LlamaCloudIndex,
     # use auto_routed mode for better query results
-    if (
-        index.__class__.__name__ == "LlamaCloudIndex"
-        and kwargs.get("auto_routed") is None
-    ):
-        kwargs["auto_routed"] = True
+    if index.__class__.__name__ == "LlamaCloudIndex":
+        if kwargs.get("retrieval_mode") is None:
+            kwargs["retrieval_mode"] = "auto_routed"
+        if multimodal_llm:
+            kwargs["retrieve_image_nodes"] = True
     return index.as_query_engine(**kwargs)
 
 
@@ -51,3 +75,113 @@ def get_query_engine_tool(
         name=name,
         description=description,
     )
+
+
+class MultiModalSynthesizer(BaseSynthesizer):
+    """
+    A synthesizer that first summarizes the retrieved text nodes, then passes the
+    summary together with any retrieved image nodes to a multi-modal LLM to
+    generate the final response.
+    """
+
+    def __init__(
+        self,
+        multimodal_model: MultiModalLLM,
+        response_synthesizer: Optional[BaseSynthesizer] = None,
+        text_qa_template: Optional[BasePromptTemplate] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self._multi_modal_llm = multimodal_model
+        self._response_synthesizer = response_synthesizer or get_response_synthesizer()
+        self._text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT_SEL
+
+    def _get_prompts(self, **kwargs) -> Dict[str, Any]:
+        return {
+            "text_qa_template": self._text_qa_template,
+        }
+
+    def _update_prompts(self, prompts: Dict[str, Any]) -> None:
+        if "text_qa_template" in prompts:
+            self._text_qa_template = prompts["text_qa_template"]
+
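+    # BaseSynthesizer requires (a)get_response; plain text synthesis is
+    # delegated to the wrapped response synthesizer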
+    async def aget_response(
+        self,
+        *args,
+        **response_kwargs: Any,
+    ) -> RESPONSE_TEXT_TYPE:
+        return await self._response_synthesizer.aget_response(*args, **response_kwargs)
+
+    def get_response(self, *args, **kwargs) -> RESPONSE_TEXT_TYPE:
+        return self._response_synthesizer.get_response(*args, **kwargs)
+
+    async def asynthesize(
+        self,
+        query: QueryTextType,
+        nodes: List[NodeWithScore],
+        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
+        **response_kwargs: Any,
+    ) -> RESPONSE_TYPE:
+        image_nodes, text_nodes = _get_image_and_text_nodes(nodes)
+
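+        # No image nodes were retrieved: fall back to text-only synthesis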
+        if len(image_nodes) == 0:
+            return await self._response_synthesizer.asynthesize(query, text_nodes)
+
+        # Summarize the text nodes to avoid exceeding the token limit
+        text_response = str(
+            await self._response_synthesizer.asynthesize(query, text_nodes)
+        )
+
+        fmt_prompt = self._text_qa_template.format(
+            context_str=text_response,
+            query_str=query.query_str,  # type: ignore
+        )
+
+        llm_response = await self._multi_modal_llm.acomplete(
+            prompt=fmt_prompt,
+            image_documents=[
+                image_node.node
+                for image_node in image_nodes
+                if isinstance(image_node.node, ImageNode)
+            ],
+        )
+
+        return Response(
+            response=str(llm_response),
+            source_nodes=nodes,
+            metadata={"text_nodes": text_nodes, "image_nodes": image_nodes},
+        )
+
+    def synthesize(
+        self,
+        query: QueryTextType,
+        nodes: List[NodeWithScore],
+        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
+        **response_kwargs: Any,
+    ) -> RESPONSE_TYPE:
+        image_nodes, text_nodes = _get_image_and_text_nodes(nodes)
+
+        if len(image_nodes) == 0:
+            return self._response_synthesizer.synthesize(query, text_nodes)
+
+        # Summarize the text nodes to avoid exceeding the token limit
+        text_response = str(self._response_synthesizer.synthesize(query, text_nodes))
+
+        fmt_prompt = self._text_qa_template.format(
+            context_str=text_response,
+            query_str=query.query_str,  # type: ignore
+        )
+
+        llm_response = self._multi_modal_llm.complete(
+            prompt=fmt_prompt,
+            image_documents=[
+                image_node.node
+                for image_node in image_nodes
+                if isinstance(image_node.node, ImageNode)
+            ],
+        )
+
+        return Response(
+            response=str(llm_response),
+            source_nodes=nodes,
+            metadata={"text_nodes": text_nodes, "image_nodes": image_nodes},
+        )
diff --git a/templates/components/settings/python/settings.py b/templates/components/settings/python/settings.py
index bc7270bd8..ff647560c 100644
--- a/templates/components/settings/python/settings.py
+++ b/templates/components/settings/python/settings.py
@@ -1,8 +1,17 @@
 import os
-from typing import Dict
+from typing import Dict, Optional
 
+from llama_index.core.multi_modal_llms import MultiModalLLM
 from llama_index.core.settings import Settings
 
+# `Settings` does not support setting a `MultiModalLLM`,
+# so we store it in a module-level variable instead
+_multi_modal_llm: Optional[MultiModalLLM] = None
+
+
+def get_multi_modal_llm() -> Optional[MultiModalLLM]:
+    return _multi_modal_llm
+
 
 def init_settings():
     model_provider = os.getenv("MODEL_PROVIDER")
@@ -60,14 +69,21 @@ def init_openai():
     from llama_index.core.constants import DEFAULT_TEMPERATURE
     from llama_index.embeddings.openai import OpenAIEmbedding
     from llama_index.llms.openai import OpenAI
+    from llama_index.multi_modal_llms.openai import OpenAIMultiModal
+    from llama_index.multi_modal_llms.openai.utils import GPT4V_MODELS
 
     max_tokens = os.getenv("LLM_MAX_TOKENS")
+    model_name = os.getenv("MODEL", "gpt-4o-mini")
     Settings.llm = OpenAI(
-        model=os.getenv("MODEL", "gpt-4o-mini"),
+        model=model_name,
         temperature=float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
         max_tokens=int(max_tokens) if max_tokens is not None else None,
     )
 
+    # Register a multi-modal LLM only when the configured model accepts image input
+    if model_name in GPT4V_MODELS:
+        global _multi_modal_llm
+        _multi_modal_llm = OpenAIMultiModal(model=model_name)
+
     dimensions = os.getenv("EMBEDDING_DIM")
     Settings.embed_model = OpenAIEmbedding(
         model=os.getenv("EMBEDDING_MODEL", "text-embedding-3-small"),