From f9a057dddee618747d95e9cbe3dddbc2359b634f Mon Sep 17 00:00:00 2001
From: Huu Le <39040748+leehuwuj@users.noreply.github.com>
Date: Fri, 29 Nov 2024 18:02:14 +0700
Subject: [PATCH] feat: add support for multimodal indexes (#453)

---------

Co-authored-by: thucpn <thucsh2@gmail.com>
Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de>
---
 .changeset/blue-hornets-boil.md                |   5 +
 .../python/agent/tools/query_engine.py         | 148 +++++++++++++++++-
 .../components/settings/python/settings.py     |  20 ++-
 3 files changed, 164 insertions(+), 9 deletions(-)
 create mode 100644 .changeset/blue-hornets-boil.md

diff --git a/.changeset/blue-hornets-boil.md b/.changeset/blue-hornets-boil.md
new file mode 100644
index 000000000..e8c2928d4
--- /dev/null
+++ b/.changeset/blue-hornets-boil.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Add support for multimodal indexes (e.g. from LlamaCloud)
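
The next file carries the main change: create_query_engine now attaches a MultiModalSynthesizer whenever the settings module has registered a multi-modal LLM, and for LlamaCloud indexes it also switches the retriever to auto_routed mode and asks it to return image nodes. A rough usage sketch of the factory, assuming a generated project where the index helper lives at app.engine.index and this module at app.engine.tools.query_engine (both module paths are assumptions, not part of this patch):

    # Sketch only: how the factory below is expected to be called at runtime.
    import os

    from app.engine.index import get_index  # assumed helper in the generated app
    from app.engine.tools.query_engine import create_query_engine  # assumed path
    from app.settings import init_settings

    os.environ.setdefault("TOP_K", "5")  # forwarded to the retriever as similarity_top_k
    init_settings()                      # may register a multi-modal LLM (see settings.py below)

    query_engine = create_query_engine(get_index())
    print(query_engine.query("What does the architecture diagram show?"))

With a text-only model the same call still works: create_query_engine only attaches the multimodal pieces when get_multi_modal_llm() returns a model.
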
diff --git a/templates/components/engines/python/agent/tools/query_engine.py b/templates/components/engines/python/agent/tools/query_engine.py
index e78ae0442..396fb1d6e 100644
--- a/templates/components/engines/python/agent/tools/query_engine.py
+++ b/templates/components/engines/python/agent/tools/query_engine.py
@@ -1,10 +1,27 @@
 import os
-from typing import Optional
+from typing import Any, Dict, List, Optional, Sequence
 
+from llama_index.core import get_response_synthesizer
+from llama_index.core.base.base_query_engine import BaseQueryEngine
+from llama_index.core.base.response.schema import RESPONSE_TYPE, Response
+from llama_index.core.multi_modal_llms import MultiModalLLM
+from llama_index.core.prompts.base import BasePromptTemplate
+from llama_index.core.prompts.default_prompt_selectors import (
+    DEFAULT_TEXT_QA_PROMPT_SEL,
+)
+from llama_index.core.query_engine.multi_modal import _get_image_and_text_nodes
+from llama_index.core.response_synthesizers.base import BaseSynthesizer, QueryTextType
+from llama_index.core.schema import (
+    ImageNode,
+    NodeWithScore,
+)
 from llama_index.core.tools.query_engine import QueryEngineTool
+from llama_index.core.types import RESPONSE_TEXT_TYPE
 
+from app.settings import get_multi_modal_llm
 
-def create_query_engine(index, **kwargs):
+
+def create_query_engine(index, **kwargs) -> BaseQueryEngine:
     """
     Create a query engine for the given index.
 
@@ -12,16 +29,23 @@ def create_query_engine(index, **kwargs):
         index: The index to create a query engine for.
         params (optional): Additional parameters for the query engine, e.g. similarity_top_k
     """
+
     top_k = int(os.getenv("TOP_K", 0))
     if top_k != 0 and kwargs.get("filters") is None:
         kwargs["similarity_top_k"] = top_k
+
+    multimodal_llm = get_multi_modal_llm()
+    if multimodal_llm:
+        kwargs["response_synthesizer"] = MultiModalSynthesizer(
+            multimodal_model=multimodal_llm,
+        )
+
     # If the index is a LlamaCloudIndex,
     # use auto_routed mode for better query results
-    if (
-        index.__class__.__name__ == "LlamaCloudIndex"
-        and kwargs.get("auto_routed") is None
-    ):
-        kwargs["auto_routed"] = True
+    if index.__class__.__name__ == "LlamaCloudIndex":
+        if kwargs.get("retrieval_mode") is None:
+            kwargs["retrieval_mode"] = "auto_routed"
+        if multimodal_llm:
+            kwargs["retrieve_image_nodes"] = True
 
     return index.as_query_engine(**kwargs)
 
@@ -51,3 +75,113 @@ def get_query_engine_tool(
         name=name,
         description=description,
     )
+
+
+class MultiModalSynthesizer(BaseSynthesizer):
+    """
+    A synthesizer that summarizes text nodes and uses a multi-modal LLM to generate a response.
+    """
+
+    def __init__(
+        self,
+        multimodal_model: MultiModalLLM,
+        response_synthesizer: Optional[BaseSynthesizer] = None,
+        text_qa_template: Optional[BasePromptTemplate] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self._multi_modal_llm = multimodal_model
+        self._response_synthesizer = response_synthesizer or get_response_synthesizer()
+        self._text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT_SEL
+
+    def _get_prompts(self, **kwargs) -> Dict[str, Any]:
+        return {
+            "text_qa_template": self._text_qa_template,
+        }
+
+    def _update_prompts(self, prompts: Dict[str, Any]) -> None:
+        if "text_qa_template" in prompts:
+            self._text_qa_template = prompts["text_qa_template"]
+
+    async def aget_response(
+        self,
+        *args,
+        **response_kwargs: Any,
+    ) -> RESPONSE_TEXT_TYPE:
+        return await self._response_synthesizer.aget_response(*args, **response_kwargs)
+
+    def get_response(self, *args, **kwargs) -> RESPONSE_TEXT_TYPE:
+        return self._response_synthesizer.get_response(*args, **kwargs)
+
+    async def asynthesize(
+        self,
+        query: QueryTextType,
+        nodes: List[NodeWithScore],
+        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
+        **response_kwargs: Any,
+    ) -> RESPONSE_TYPE:
+        image_nodes, text_nodes = _get_image_and_text_nodes(nodes)
+
+        if len(image_nodes) == 0:
+            return await self._response_synthesizer.asynthesize(query, text_nodes)
+
+        # Summarize the text nodes to avoid exceeding the token limit
+        text_response = str(
+            await self._response_synthesizer.asynthesize(query, text_nodes)
+        )
+
+        fmt_prompt = self._text_qa_template.format(
+            context_str=text_response,
+            query_str=query.query_str,  # type: ignore
+        )
+
+        llm_response = await self._multi_modal_llm.acomplete(
+            prompt=fmt_prompt,
+            image_documents=[
+                image_node.node
+                for image_node in image_nodes
+                if isinstance(image_node.node, ImageNode)
+            ],
+        )
+
+        return Response(
+            response=str(llm_response),
+            source_nodes=nodes,
+            metadata={"text_nodes": text_nodes, "image_nodes": image_nodes},
+        )
+
+    def synthesize(
+        self,
+        query: QueryTextType,
+        nodes: List[NodeWithScore],
+        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
+        **response_kwargs: Any,
+    ) -> RESPONSE_TYPE:
+        image_nodes, text_nodes = _get_image_and_text_nodes(nodes)
+
+        if len(image_nodes) == 0:
+            return self._response_synthesizer.synthesize(query, text_nodes)
+
+        # Summarize the text nodes to avoid exceeding the token limit
+        text_response = str(self._response_synthesizer.synthesize(query, text_nodes))
+
+        fmt_prompt = self._text_qa_template.format(
+            context_str=text_response,
+            query_str=query.query_str,  # type: ignore
+        )
+
+        llm_response = self._multi_modal_llm.complete(
+            prompt=fmt_prompt,
+            image_documents=[
+                image_node.node
+                for image_node in image_nodes
+                if isinstance(image_node.node, ImageNode)
+            ],
+        )
+
+        return Response(
+            response=str(llm_response),
+            source_nodes=nodes,
+            metadata={"text_nodes": text_nodes, "image_nodes": image_nodes},
+        )
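
The synthesizer above works in two stages: the wrapped text synthesizer first condenses the retrieved text nodes into a single answer, and that summary is then formatted through the text QA template and sent to the multi-modal LLM together with the retrieved image nodes; when no image nodes are present it simply delegates to the wrapped synthesizer. A minimal sketch of standalone use, assuming the module path app.engine.tools.query_engine (an assumption) and OpenAI credentials in the environment:

    # Sketch only: attaching MultiModalSynthesizer to an ordinary vector index.
    from llama_index.core import Document, VectorStoreIndex
    from llama_index.multi_modal_llms.openai import OpenAIMultiModal

    from app.engine.tools.query_engine import MultiModalSynthesizer  # assumed path

    index = VectorStoreIndex.from_documents(
        [Document(text="Q3 revenue grew 12%, as shown in the bar chart.")]
    )
    synthesizer = MultiModalSynthesizer(
        multimodal_model=OpenAIMultiModal(model="gpt-4o"),
    )
    # With no image nodes in the results this falls back to the default text
    # synthesizer; with a multimodal index (e.g. LlamaCloud retrieving image
    # nodes) the images are passed to the multi-modal LLM.
    engine = index.as_query_engine(response_synthesizer=synthesizer, similarity_top_k=3)
    print(engine.query("How did revenue change in Q3?"))

The settings change below is what makes a multi-modal LLM available to create_query_engine in the first place.
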
diff --git a/templates/components/settings/python/settings.py b/templates/components/settings/python/settings.py
index bc7270bd8..ff647560c 100644
--- a/templates/components/settings/python/settings.py
+++ b/templates/components/settings/python/settings.py
@@ -1,8 +1,17 @@
 import os
-from typing import Dict
+from typing import Dict, Optional
 
+from llama_index.core.multi_modal_llms import MultiModalLLM
 from llama_index.core.settings import Settings
 
+# `Settings` does not support setting `MultiModalLLM`,
+# so we use a global variable to store it
+_multi_modal_llm: Optional[MultiModalLLM] = None
+
+
+def get_multi_modal_llm():
+    return _multi_modal_llm
+
 
 def init_settings():
     model_provider = os.getenv("MODEL_PROVIDER")
@@ -60,14 +69,21 @@ def init_openai():
     from llama_index.core.constants import DEFAULT_TEMPERATURE
     from llama_index.embeddings.openai import OpenAIEmbedding
     from llama_index.llms.openai import OpenAI
+    from llama_index.multi_modal_llms.openai import OpenAIMultiModal
+    from llama_index.multi_modal_llms.openai.utils import GPT4V_MODELS
 
     max_tokens = os.getenv("LLM_MAX_TOKENS")
+    model_name = os.getenv("MODEL", "gpt-4o-mini")
     Settings.llm = OpenAI(
-        model=os.getenv("MODEL", "gpt-4o-mini"),
+        model=model_name,
         temperature=float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
         max_tokens=int(max_tokens) if max_tokens is not None else None,
    )
 
+    if model_name in GPT4V_MODELS:
+        global _multi_modal_llm
+        _multi_modal_llm = OpenAIMultiModal(model=model_name)
+
     dimensions = os.getenv("EMBEDDING_DIM")
     Settings.embed_model = OpenAIEmbedding(
         model=os.getenv("EMBEDDING_MODEL", "text-embedding-3-small"),
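
Taken together, the flow is: init_settings() registers an OpenAIMultiModal next to the regular LLM whenever MODEL is one of the vision-capable GPT4V_MODELS, and create_query_engine() later picks it up via get_multi_modal_llm(). A small sketch of that hand-off, assuming the OpenAI provider and example environment values:

    # Sketch only: verifying that a multi-modal LLM was registered by init_settings().
    import os

    from app.settings import get_multi_modal_llm, init_settings

    os.environ["MODEL_PROVIDER"] = "openai"
    os.environ["MODEL"] = "gpt-4o"           # expected to be in GPT4V_MODELS
    os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder

    init_settings()

    mm_llm = get_multi_modal_llm()
    if mm_llm is not None:
        # create_query_engine() will now wrap responses in a MultiModalSynthesizer
        print("multi-modal LLM registered:", type(mm_llm).__name__)
    else:
        print("text-only model; the default synthesizer is used")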