Skip to content

Commit

Permalink
improve code
Browse files Browse the repository at this point in the history
  • Loading branch information
leehuwuj committed Nov 29, 2024
1 parent b9d336c commit 235159e
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions templates/components/engines/python/agent/tools/query_engine.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import os
from typing import Any, List, Optional, Sequence

from llama_index.core import get_response_synthesizer
from llama_index.core.base.base_query_engine import BaseQueryEngine
from llama_index.core.base.response.schema import RESPONSE_TYPE, Response
from llama_index.core.multi_modal_llms import MultiModalLLM
from llama_index.core.prompts.base import BasePromptTemplate
from llama_index.core.prompts.default_prompt_selectors import (
DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
DEFAULT_TEXT_QA_PROMPT_SEL,
)
from llama_index.core.query_engine import (
RetrieverQueryEngine,
)
from llama_index.core.query_engine.multi_modal import _get_image_and_text_nodes
from llama_index.core.response_synthesizers import TreeSummarize
from llama_index.core.response_synthesizers.base import QueryTextType
from llama_index.core.response_synthesizers.base import BaseSynthesizer, QueryTextType
from llama_index.core.schema import (
ImageNode,
NodeWithScore,
Expand All @@ -23,21 +23,23 @@
from app.settings import get_multi_modal_llm


class MultiModalSynthesizer(TreeSummarize):
class MultiModalSynthesizer(BaseSynthesizer):
"""
A synthesizer that summarizes text nodes and uses a multi-modal LLM to generate a response.
"""

def __init__(
self,
multimodal_model: Optional[MultiModalLLM] = None,
multimodal_model: MultiModalLLM,
response_synthesizer: Optional[BaseSynthesizer],
text_qa_template: Optional[BasePromptTemplate] = None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self._multi_modal_llm = multimodal_model
self._text_qa_template = text_qa_template or DEFAULT_TREE_SUMMARIZE_PROMPT_SEL
self._response_synthesizer = response_synthesizer
self._text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT_SEL

async def asynthesize(
self,
Expand All @@ -49,11 +51,13 @@ async def asynthesize(
image_nodes, text_nodes = _get_image_and_text_nodes(nodes)

# Summarize the text nodes to avoid exceeding the token limit
text_response = str(await super().asynthesize(query, nodes))
text_response = str(
await self._response_synthesizer.asynthesize(query, text_nodes)
)

fmt_prompt = self._text_qa_template.format(
context_str=text_response,
query_str=query.query_str,
query_str=query.query_str, # type: ignore
)

llm_response = await self._multi_modal_llm.acomplete(
Expand Down Expand Up @@ -90,13 +94,13 @@ def create_query_engine(index, **kwargs) -> BaseQueryEngine:
retrieval_mode = kwargs.get("retrieval_mode")
if retrieval_mode is None:
kwargs["retrieval_mode"] = "auto_routed"
multi_modal_llm = get_multi_modal_llm()
if multi_modal_llm:
if get_multi_modal_llm():
kwargs["retrieve_image_nodes"] = True
return RetrieverQueryEngine(
retriever=index.as_retriever(**kwargs),
response_synthesizer=MultiModalSynthesizer(
multimodal_model=multi_modal_llm
multimodal_model=get_multi_modal_llm(),
response_synthesizer=get_response_synthesizer(),
),
)

Expand Down

0 comments on commit 235159e

Please sign in to comment.