diff --git a/.changeset/kind-mice-repair.md b/.changeset/kind-mice-repair.md new file mode 100644 index 000000000..0da4c2147 --- /dev/null +++ b/.changeset/kind-mice-repair.md @@ -0,0 +1,5 @@ +--- +"create-llama": patch +--- + +Standardize the code of the workflow use case (Python) diff --git a/templates/components/agents/python/deep_research/app/workflows/deep_research.py b/templates/components/agents/python/deep_research/app/workflows/deep_research.py index 59648b524..6af650826 100644 --- a/templates/components/agents/python/deep_research/app/workflows/deep_research.py +++ b/templates/components/agents/python/deep_research/app/workflows/deep_research.py @@ -32,7 +32,6 @@ def create_workflow( - chat_history: Optional[List[ChatMessage]] = None, params: Optional[Dict[str, Any]] = None, **kwargs, ) -> Workflow: @@ -45,7 +44,6 @@ def create_workflow( return DeepResearchWorkflow( index=index, - chat_history=chat_history, timeout=120.0, ) @@ -73,19 +71,13 @@ class DeepResearchWorkflow(Workflow): def __init__( self, index: BaseIndex, - chat_history: Optional[List[ChatMessage]] = None, - stream: bool = True, **kwargs, ): super().__init__(**kwargs) self.index = index self.context_nodes = [] - self.stream = stream - self.chat_history = chat_history self.memory = SimpleComposableMemory.from_defaults( - primary_memory=ChatMemoryBuffer.from_defaults( - chat_history=chat_history, - ), + primary_memory=ChatMemoryBuffer.from_defaults(), ) @step @@ -93,8 +85,15 @@ async def retrieve(self, ctx: Context, ev: StartEvent) -> PlanResearchEvent: """ Initiate the workflow: memory, tools, agent """ + self.stream = ev.get("stream", True) + self.user_request = ev.get("user_msg") + chat_history = ev.get("chat_history") + if chat_history is not None: + self.memory.put_messages(chat_history) + await ctx.set("total_questions", 0) - self.user_request = ev.get("input") + + # Add user message to memory self.memory.put_messages( messages=[ ChatMessage( @@ -319,7 +318,6 @@ async def report(self, ctx: Context, ev: ReportEvent) -> StopEvent: """ Report the answers """ - logger.info("Writing the report") res = await write_report( memory=self.memory, user_request=self.user_request, diff --git a/templates/components/agents/python/financial_report/app/workflows/financial_report.py b/templates/components/agents/python/financial_report/app/workflows/financial_report.py index adfdb27da..8f2abb9ee 100644 --- a/templates/components/agents/python/financial_report/app/workflows/financial_report.py +++ b/templates/components/agents/python/financial_report/app/workflows/financial_report.py @@ -1,13 +1,5 @@ from typing import Any, Dict, List, Optional -from app.engine.index import IndexConfig, get_index -from app.engine.tools import ToolFactory -from app.engine.tools.query_engine import get_query_engine_tool -from app.workflows.events import AgentRunEvent -from app.workflows.tools import ( - call_tools, - chat_with_tools, -) from llama_index.core import Settings from llama_index.core.base.llms.types import ChatMessage, MessageRole from llama_index.core.llms.function_calling import FunctionCallingLLM @@ -22,9 +14,17 @@ step, ) +from app.engine.index import IndexConfig, get_index +from app.engine.tools import ToolFactory +from app.engine.tools.query_engine import get_query_engine_tool +from app.workflows.events import AgentRunEvent +from app.workflows.tools import ( + call_tools, + chat_with_tools, +) + def create_workflow( - chat_history: Optional[List[ChatMessage]] = None, params: Optional[Dict[str, Any]] = None, **kwargs, ) -> Workflow: 
@@ -45,7 +45,6 @@ def create_workflow( query_engine_tool=query_engine_tool, code_interpreter_tool=code_interpreter_tool, document_generator_tool=document_generator_tool, - chat_history=chat_history, ) @@ -91,6 +90,7 @@ class FinancialReportWorkflow(Workflow): It's good to using appropriate tools for the user request and always use the information from the tools, don't make up anything yourself. For the query engine tool, you should break down the user request into a list of queries and call the tool with the queries. """ + stream: bool = True def __init__( self, @@ -99,12 +99,10 @@ def __init__( document_generator_tool: FunctionTool, llm: Optional[FunctionCallingLLM] = None, timeout: int = 360, - chat_history: Optional[List[ChatMessage]] = None, system_prompt: Optional[str] = None, ): super().__init__(timeout=timeout) self.system_prompt = system_prompt or self._default_system_prompt - self.chat_history = chat_history or [] self.query_engine_tool = query_engine_tool self.code_interpreter_tool = code_interpreter_tool self.document_generator_tool = document_generator_tool @@ -122,13 +120,19 @@ def __init__( ] self.llm: FunctionCallingLLM = llm or Settings.llm assert isinstance(self.llm, FunctionCallingLLM) - self.memory = ChatMemoryBuffer.from_defaults( - llm=self.llm, chat_history=self.chat_history - ) + self.memory = ChatMemoryBuffer.from_defaults(llm=self.llm) @step() async def prepare_chat_history(self, ctx: Context, ev: StartEvent) -> InputEvent: - ctx.data["input"] = ev.input + self.stream = ev.get("stream", True) + user_msg = ev.get("user_msg") + chat_history = ev.get("chat_history") + + if chat_history is not None: + self.memory.put_messages(chat_history) + + # Add user message to memory + self.memory.put(ChatMessage(role=MessageRole.USER, content=user_msg)) if self.system_prompt: system_msg = ChatMessage( @@ -136,9 +140,6 @@ async def prepare_chat_history(self, ctx: Context, ev: StartEvent) -> InputEvent ) self.memory.put(system_msg) - # Add user input to memory - self.memory.put(ChatMessage(role=MessageRole.USER, content=ev.input)) - return InputEvent(input=self.memory.get()) @step() @@ -160,8 +161,10 @@ async def handle_llm_input( # type: ignore chat_history, ) if not response.has_tool_calls(): - # If no tool call, return the response generator - return StopEvent(result=response.generator) + if self.stream: + return StopEvent(result=response.generator) + else: + return StopEvent(result=await response.full_response()) # calling different tools at the same time is not supported at the moment # add an error message to tell the AI to process step by step if response.is_calling_different_tools(): diff --git a/templates/components/agents/python/form_filling/app/workflows/form_filling.py b/templates/components/agents/python/form_filling/app/workflows/form_filling.py index da728744d..6078a5072 100644 --- a/templates/components/agents/python/form_filling/app/workflows/form_filling.py +++ b/templates/components/agents/python/form_filling/app/workflows/form_filling.py @@ -25,7 +25,6 @@ def create_workflow( - chat_history: Optional[List[ChatMessage]] = None, params: Optional[Dict[str, Any]] = None, **kwargs, ) -> Workflow: @@ -45,7 +44,6 @@ def create_workflow( query_engine_tool=query_engine_tool, extractor_tool=extractor_tool, # type: ignore filling_tool=filling_tool, # type: ignore - chat_history=chat_history, ) return workflow @@ -88,6 +86,7 @@ class FormFillingWorkflow(Workflow): Only use provided data - never make up any information yourself. Fill N/A if an answer is not found. 
If there is no query engine tool or the gathered information has many N/A values indicating the questions don't match the data, respond with a warning and ask the user to upload a different file or connect to a knowledge base. """ + stream: bool = True def __init__( self, @@ -96,12 +95,10 @@ def __init__( filling_tool: FunctionTool, llm: Optional[FunctionCallingLLM] = None, timeout: int = 360, - chat_history: Optional[List[ChatMessage]] = None, system_prompt: Optional[str] = None, ): super().__init__(timeout=timeout) self.system_prompt = system_prompt or self._default_system_prompt - self.chat_history = chat_history or [] self.query_engine_tool = query_engine_tool self.extractor_tool = extractor_tool self.filling_tool = filling_tool @@ -113,13 +110,18 @@ def __init__( self.llm: FunctionCallingLLM = llm or Settings.llm if not isinstance(self.llm, FunctionCallingLLM): raise ValueError("FormFillingWorkflow only supports FunctionCallingLLM.") - self.memory = ChatMemoryBuffer.from_defaults( - llm=self.llm, chat_history=self.chat_history - ) + self.memory = ChatMemoryBuffer.from_defaults(llm=self.llm) @step() async def start(self, ctx: Context, ev: StartEvent) -> InputEvent: - ctx.data["input"] = ev.input + self.stream = ev.get("stream", True) + user_msg = ev.get("user_msg", "") + chat_history = ev.get("chat_history", []) + + if chat_history: + self.memory.put_messages(chat_history) + + self.memory.put(ChatMessage(role=MessageRole.USER, content=user_msg)) if self.system_prompt: system_msg = ChatMessage( @@ -127,12 +129,7 @@ async def start(self, ctx: Context, ev: StartEvent) -> InputEvent: ) self.memory.put(system_msg) - user_input = ev.input - user_msg = ChatMessage(role=MessageRole.USER, content=user_input) - self.memory.put(user_msg) - - chat_history = self.memory.get() - return InputEvent(input=chat_history) + return InputEvent(input=self.memory.get()) @step() async def handle_llm_input( # type: ignore @@ -150,7 +147,10 @@ async def handle_llm_input( # type: ignore chat_history, ) if not response.has_tool_calls(): - return StopEvent(result=response.generator) + if self.stream: + return StopEvent(result=response.generator) + else: + return StopEvent(result=await response.full_response()) # calling different tools at the same time is not supported at the moment # add an error message to tell the AI to process step by step if response.is_calling_different_tools(): diff --git a/templates/components/multiagent/python/app/api/callbacks/__init__.py b/templates/components/multiagent/python/app/api/callbacks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/templates/components/multiagent/python/app/api/callbacks/base.py b/templates/components/multiagent/python/app/api/callbacks/base.py new file mode 100644 index 000000000..0b979d171 --- /dev/null +++ b/templates/components/multiagent/python/app/api/callbacks/base.py @@ -0,0 +1,32 @@ +import logging +from abc import ABC, abstractmethod +from typing import Any + +logger = logging.getLogger("uvicorn") + + +class EventCallback(ABC): + """ + Base class for event callbacks during event streaming. + """ + + async def run(self, event: Any) -> Any: + """ + Called for each event in the stream. + Default behavior: pass through the event unchanged. + """ + return event + + async def on_complete(self, final_response: str) -> Any: + """ + Called when the stream is complete. + Default behavior: return None. 
+ """ + return None + + @abstractmethod + def from_default(self, *args, **kwargs) -> "EventCallback": + """ + Create a new instance of the processor from default values. + """ + pass diff --git a/templates/components/multiagent/python/app/api/callbacks/llamacloud.py b/templates/components/multiagent/python/app/api/callbacks/llamacloud.py new file mode 100644 index 000000000..a28b736f8 --- /dev/null +++ b/templates/components/multiagent/python/app/api/callbacks/llamacloud.py @@ -0,0 +1,42 @@ +import logging +from typing import Any, List + +from fastapi import BackgroundTasks +from llama_index.core.schema import NodeWithScore + +from app.api.callbacks.base import EventCallback + +logger = logging.getLogger("uvicorn") + + +class LlamaCloudFileDownload(EventCallback): + """ + Processor for handling LlamaCloud file downloads from source nodes. + Only work if LlamaCloud service code is available. + """ + + def __init__(self, background_tasks: BackgroundTasks): + self.background_tasks = background_tasks + + async def run(self, event: Any) -> Any: + if hasattr(event, "to_response"): + event_response = event.to_response() + if event_response.get("type") == "sources" and hasattr(event, "nodes"): + await self._process_response_nodes(event.nodes) + return event + + async def _process_response_nodes(self, source_nodes: List[NodeWithScore]): + try: + from app.engine.service import LLamaCloudFileService # type: ignore + + LLamaCloudFileService.download_files_from_nodes( + source_nodes, self.background_tasks + ) + except ImportError: + pass + + @classmethod + def from_default( + cls, background_tasks: BackgroundTasks + ) -> "LlamaCloudFileDownload": + return cls(background_tasks=background_tasks) diff --git a/templates/components/multiagent/python/app/api/callbacks/next_question.py b/templates/components/multiagent/python/app/api/callbacks/next_question.py new file mode 100644 index 000000000..57c223e98 --- /dev/null +++ b/templates/components/multiagent/python/app/api/callbacks/next_question.py @@ -0,0 +1,34 @@ +import logging +from typing import Any + +from app.api.callbacks.base import EventCallback +from app.api.routers.models import ChatData +from app.api.services.suggestion import NextQuestionSuggestion + +logger = logging.getLogger("uvicorn") + + +class SuggestNextQuestions(EventCallback): + """Processor for generating next question suggestions.""" + + def __init__(self, chat_data: ChatData): + self.chat_data = chat_data + self.accumulated_text = "" + + async def on_complete(self, final_response: str) -> Any: + if final_response == "": + return None + + questions = await NextQuestionSuggestion.suggest_next_questions( + self.chat_data.messages, final_response + ) + if questions: + return { + "type": "suggested_questions", + "data": questions, + } + return None + + @classmethod + def from_default(cls, chat_data: ChatData) -> "SuggestNextQuestions": + return cls(chat_data=chat_data) diff --git a/templates/components/multiagent/python/app/api/callbacks/stream_handler.py b/templates/components/multiagent/python/app/api/callbacks/stream_handler.py new file mode 100644 index 000000000..0167a85bb --- /dev/null +++ b/templates/components/multiagent/python/app/api/callbacks/stream_handler.py @@ -0,0 +1,66 @@ +import logging +from typing import List, Optional + +from llama_index.core.workflow.handler import WorkflowHandler + +from app.api.callbacks.base import EventCallback + +logger = logging.getLogger("uvicorn") + + +class StreamHandler: + """ + Streams events from a workflow handler through a chain of 
callbacks. + """ + + def __init__( + self, + workflow_handler: WorkflowHandler, + callbacks: Optional[List[EventCallback]] = None, + ): + self.workflow_handler = workflow_handler + self.callbacks = callbacks or [] + self.accumulated_text = "" + + def vercel_stream(self): + """Create a streaming response with Vercel format.""" + from app.api.routers.vercel_response import VercelStreamResponse + + return VercelStreamResponse(stream_handler=self) + + async def cancel_run(self): + """Cancel the workflow handler.""" + await self.workflow_handler.cancel_run() + + async def stream_events(self): + """Stream events through the processor chain.""" + try: + async for event in self.workflow_handler.stream_events(): + # Process the event through each processor + for callback in self.callbacks: + event = await callback.run(event) + yield event + + # After all events are processed, call on_complete for each callback + for callback in self.callbacks: + result = await callback.on_complete(self.accumulated_text) + if result: + yield result + + except Exception as e: + # Make sure to cancel the workflow on error + await self.workflow_handler.cancel_run() + raise e + + async def accumulate_text(self, text: str): + """Accumulate text from the workflow handler.""" + self.accumulated_text += text + + @classmethod + def from_default( + cls, + handler: WorkflowHandler, + callbacks: Optional[List[EventCallback]] = None, + ) -> "StreamHandler": + """Create a new instance with the given workflow handler and callbacks.""" + return cls(workflow_handler=handler, callbacks=callbacks) diff --git a/templates/components/multiagent/python/app/api/routers/chat.py b/templates/components/multiagent/python/app/api/routers/chat.py index a16545f71..d7a44e691 100644 --- a/templates/components/multiagent/python/app/api/routers/chat.py +++ b/templates/components/multiagent/python/app/api/routers/chat.py @@ -2,10 +2,12 @@ from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status +from app.api.callbacks.llamacloud import LlamaCloudFileDownload +from app.api.callbacks.next_question import SuggestNextQuestions +from app.api.callbacks.stream_handler import StreamHandler from app.api.routers.models import ( ChatData, ) -from app.api.routers.vercel_response import VercelStreamResponse from app.engine.query_filter import generate_filters from app.workflows import create_workflow @@ -29,19 +31,22 @@ async def chat( params = data.data or {} workflow = create_workflow( - chat_history=messages, params=params, filters=filters, ) - event_handler = workflow.run(input=last_message_content, streaming=True) - return VercelStreamResponse( - request=request, - chat_data=data, - background_tasks=background_tasks, - event_handler=event_handler, - events=workflow.stream_events(), + handler = workflow.run( + user_msg=last_message_content, + chat_history=messages, + stream=True, ) + return StreamHandler.from_default( + handler=handler, + callbacks=[ + LlamaCloudFileDownload.from_default(background_tasks), + SuggestNextQuestions.from_default(data), + ], + ).vercel_stream() except Exception as e: logger.exception("Error in chat engine", exc_info=True) raise HTTPException( diff --git a/templates/components/multiagent/python/app/api/routers/vercel_response.py b/templates/components/multiagent/python/app/api/routers/vercel_response.py index d7609ee43..a5f1d7a01 100644 --- a/templates/components/multiagent/python/app/api/routers/vercel_response.py +++ b/templates/components/multiagent/python/app/api/routers/vercel_response.py @@ -1,22 +1,20 @@ 
import asyncio import json import logging -from typing import AsyncGenerator, Awaitable, List +from typing import AsyncGenerator -from aiostream import stream -from fastapi import BackgroundTasks, Request from fastapi.responses import StreamingResponse -from llama_index.core.schema import NodeWithScore +from llama_index.core.agent.workflow.workflow_events import AgentStream +from llama_index.core.workflow import StopEvent -from app.api.routers.models import ChatData, Message -from app.api.services.suggestion import NextQuestionSuggestion +from app.api.callbacks.stream_handler import StreamHandler logger = logging.getLogger("uvicorn") class VercelStreamResponse(StreamingResponse): """ - Base class to convert the response from the chat engine to the streaming format expected by Vercel + Converts preprocessed events into Vercel-compatible streaming response format. """ TEXT_PREFIX = "0:" @@ -25,136 +23,77 @@ class VercelStreamResponse(StreamingResponse): def __init__( self, - request: Request, - chat_data: ChatData, - background_tasks: BackgroundTasks, + stream_handler: StreamHandler, *args, **kwargs, ): - self.request = request - self.chat_data = chat_data - self.background_tasks = background_tasks - content = self.content_generator(*args, **kwargs) - super().__init__(content=content) + self.handler = stream_handler + super().__init__(content=self.content_generator()) - async def content_generator(self, event_handler, events): - stream = self._create_stream( - self.request, self.chat_data, event_handler, events - ) - is_stream_started = False + async def content_generator(self): + """Generate Vercel-formatted content from preprocessed events.""" + stream_started = False try: - async with stream.stream() as streamer: - async for output in streamer: - if not is_stream_started: - is_stream_started = True - # Stream a blank message to start the stream - yield self.convert_text("") + async for event in self.handler.stream_events(): + if not stream_started: + # Start the stream with an empty message + stream_started = True + yield self.convert_text("") + + # Handle different types of events + if isinstance(event, (AgentStream, StopEvent)): + async for chunk in self._stream_text(event): + await self.handler.accumulate_text(chunk) + yield self.convert_text(chunk) + elif isinstance(event, dict): + yield self.convert_data(event) + elif hasattr(event, "to_response"): + event_response = event.to_response() + yield self.convert_data(event_response) + else: + yield self.convert_data(event.model_dump()) - yield output except asyncio.CancelledError: - logger.warning("Workflow has been cancelled!") + logger.warning("Client cancelled the request!") + await self.handler.cancel_run() except Exception as e: - logger.error( - f"Unexpected error in content_generator: {str(e)}", exc_info=True - ) - yield self.convert_error( - "An unexpected error occurred while processing your request, preventing the creation of a final answer. Please try again." 
- ) - finally: - await event_handler.cancel_run() - logger.info("The stream has been stopped!") - - def _create_stream( - self, - request: Request, - chat_data: ChatData, - event_handler: Awaitable, - events: AsyncGenerator, - verbose: bool = True, - ): - # Yield the text response - async def _chat_response_generator(): - result = await event_handler - final_response = "" - - if isinstance(result, AsyncGenerator): - async for token in result: - final_response += str(token.delta) - yield self.convert_text(token.delta) - else: - if hasattr(result, "response"): - content = result.response.message.content - if content: - for token in content: - final_response += str(token) - yield self.convert_text(token) - else: - final_response += str(result) - yield self.convert_text(result) - - # Generate next questions if next question prompt is configured - question_data = await self._generate_next_questions( - chat_data.messages, final_response - ) - if question_data: - yield self.convert_data(question_data) - - # Yield the events from the event handler - async def _event_generator(): - async for event in events: - event_response = event.to_response() - if verbose: - logger.debug(event_response) - if event_response is not None: - yield self.convert_data(event_response) - if event_response.get("type") == "sources": - self._process_response_nodes(event.nodes, self.background_tasks) - - combine = stream.merge(_chat_response_generator(), _event_generator()) - return combine - - @staticmethod - def _process_response_nodes( - source_nodes: List[NodeWithScore], - background_tasks: BackgroundTasks, - ): - try: - # Start background tasks to download documents from LlamaCloud if needed - from app.engine.service import LLamaCloudFileService # type: ignore - - LLamaCloudFileService.download_files_from_nodes( - source_nodes, background_tasks - ) - except ImportError: - logger.debug( - "LlamaCloud is not configured. 
Skipping post processing of nodes" - ) - pass + logger.error(f"Error in stream response: {e}") + yield self.convert_error(str(e)) + await self.handler.cancel_run() + + async def _stream_text( + self, event: AgentStream | StopEvent + ) -> AsyncGenerator[str, None]: + """ + Accept stream text from either AgentStream or StopEvent with string or AsyncGenerator result + """ + if isinstance(event, AgentStream): + yield event.delta  # yield the raw delta; content_generator applies convert_text once + elif isinstance(event, StopEvent): + if isinstance(event.result, str): + yield event.result + elif isinstance(event.result, AsyncGenerator): + async for chunk in event.result: + if isinstance(chunk, str): + yield chunk + elif hasattr(chunk, "delta"): + yield chunk.delta @classmethod - def convert_text(cls, token: str): + def convert_text(cls, token: str) -> str: + """Convert text event to Vercel format.""" # Escape newlines and double quotes to avoid breaking the stream token = json.dumps(token) return f"{cls.TEXT_PREFIX}{token}\n" @classmethod - def convert_data(cls, data: dict): + def convert_data(cls, data: dict) -> str: + """Convert data event to Vercel format.""" data_str = json.dumps(data) return f"{cls.DATA_PREFIX}[{data_str}]\n" @classmethod - def convert_error(cls, error: str): + def convert_error(cls, error: str) -> str: + """Convert error event to Vercel format.""" error_str = json.dumps(error) return f"{cls.ERROR_PREFIX}{error_str}\n" - - @staticmethod - async def _generate_next_questions(chat_history: List[Message], response: str): - questions = await NextQuestionSuggestion.suggest_next_questions( - chat_history, response - ) - if questions: - return { - "type": "suggested_questions", - "data": questions, - } - return None diff --git a/templates/components/multiagent/python/app/workflows/tools.py b/templates/components/multiagent/python/app/workflows/tools.py index 75058c493..faab45955 100644 --- a/templates/components/multiagent/python/app/workflows/tools.py +++ b/templates/components/multiagent/python/app/workflows/tools.py @@ -3,7 +3,6 @@ from abc import ABC, abstractmethod from typing import Any, AsyncGenerator, Callable, Optional -from app.workflows.events import AgentRunEvent, AgentRunEventType from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole from llama_index.core.llms.function_calling import FunctionCallingLLM from llama_index.core.tools import ( @@ -15,6 +14,8 @@ from llama_index.core.workflow import Context from pydantic import BaseModel, ConfigDict +from app.workflows.events import AgentRunEvent, AgentRunEventType + logger = logging.getLogger("uvicorn") @@ -51,7 +52,9 @@ async def full_response(self) -> str: assert self.generator is not None full_response = "" async for chunk in self.generator: - full_response += chunk.message.content + content = chunk.message.content + if content: + full_response += content return full_response
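With this change, per-request state (user message, chat history, streaming flag) travels through the `StartEvent` instead of the workflow constructor. A minimal sketch of driving one of these workflows outside the FastAPI router, using the financial report template as the example; the query text and chat history are placeholders, and it assumes the template's settings and index are already initialized:

```python
import asyncio

from llama_index.core.base.llms.types import ChatMessage, MessageRole

from app.workflows import create_workflow


async def main():
    # Same factory the chat router uses; params/filters omitted for brevity.
    workflow = create_workflow(params={})

    # user_msg, chat_history and stream are read from the StartEvent by the
    # workflow's first step (see retrieve/prepare_chat_history/start above).
    handler = workflow.run(
        user_msg="What was the revenue trend last quarter?",
        chat_history=[
            ChatMessage(role=MessageRole.USER, content="Hi"),
            ChatMessage(role=MessageRole.ASSISTANT, content="Hello, how can I help?"),
        ],
        stream=False,  # with stream=True the StopEvent carries an async generator instead
    )

    # Intermediate events (AgentRunEvent, source nodes, ...) are still observable.
    async for event in handler.stream_events():
        print(type(event).__name__)

    # With stream=False the financial report / form filling workflows resolve
    # to the full response string.
    print(await handler)


asyncio.run(main())
```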
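The callback chain is the new extension point for response post-processing: anything that subclasses `EventCallback` can be appended to the `StreamHandler`. A hedged sketch (the `LogFinalResponse` class and `stream_with_logging` helper are hypothetical names for illustration; the wiring otherwise mirrors `chat.py` above):

```python
import logging
from typing import Any, List

from fastapi import BackgroundTasks
from llama_index.core.base.llms.types import ChatMessage

from app.api.callbacks.base import EventCallback
from app.api.callbacks.llamacloud import LlamaCloudFileDownload
from app.api.callbacks.next_question import SuggestNextQuestions
from app.api.callbacks.stream_handler import StreamHandler
from app.api.routers.models import ChatData
from app.workflows import create_workflow

logger = logging.getLogger("uvicorn")


class LogFinalResponse(EventCallback):
    """Hypothetical callback: log the final answer once the stream completes."""

    async def on_complete(self, final_response: str) -> Any:
        logger.info("Final response (%d chars)", len(final_response))
        # Returning a dict here would be forwarded to the client as a data event;
        # returning None adds nothing to the stream.
        return None

    @classmethod
    def from_default(cls) -> "LogFinalResponse":
        return cls()


def stream_with_logging(
    data: ChatData,
    last_message_content: str,
    messages: List[ChatMessage],
    background_tasks: BackgroundTasks,
):
    """Same wiring as the chat router, with the extra callback appended (filters omitted)."""
    handler = create_workflow(params=data.data or {}).run(
        user_msg=last_message_content,
        chat_history=messages,
        stream=True,
    )
    return StreamHandler.from_default(
        handler=handler,
        callbacks=[
            LlamaCloudFileDownload.from_default(background_tasks),
            SuggestNextQuestions.from_default(data),
            LogFinalResponse.from_default(),
        ],
    ).vercel_stream()
```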