Standardize the code of workflow use cases #495

Merged: 17 commits, Feb 5, 2025
Changes from 16 commits
5 changes: 5 additions & 0 deletions .changeset/kind-mice-repair.md
@@ -0,0 +1,5 @@
---
"create-llama": patch
---

Standardize the code of the workflow use case (Python)
@@ -32,7 +32,6 @@


def create_workflow(
chat_history: Optional[List[ChatMessage]] = None,
params: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Workflow:
@@ -45,7 +44,6 @@ def create_workflow(

return DeepResearchWorkflow(
index=index,
chat_history=chat_history,
timeout=120.0,
)

@@ -73,28 +71,29 @@ class DeepResearchWorkflow(Workflow):
def __init__(
self,
index: BaseIndex,
chat_history: Optional[List[ChatMessage]] = None,
stream: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self.index = index
self.context_nodes = []
self.stream = stream
self.chat_history = chat_history
self.memory = SimpleComposableMemory.from_defaults(
primary_memory=ChatMemoryBuffer.from_defaults(
chat_history=chat_history,
),
primary_memory=ChatMemoryBuffer.from_defaults(),
)

@step
async def retrieve(self, ctx: Context, ev: StartEvent) -> PlanResearchEvent:
"""
Initiate the workflow: memory, tools, agent
"""
self.stream = ev.get("stream", True)
self.user_request = ev.get("user_msg")
Collaborator (author) comment: Also use user_msg to be consistent with the AgentWorkflow.
chat_history = ev.get("chat_history")
if chat_history is not None:
self.memory.put_messages(chat_history)

await ctx.set("total_questions", 0)
self.user_request = ev.get("input")

# Add user message to memory
self.memory.put_messages(
messages=[
ChatMessage(
@@ -319,7 +318,6 @@ async def report(self, ctx: Context, ev: ReportEvent) -> StopEvent:
"""
Report the answers
"""
logger.info("Writing the report")
res = await write_report(
memory=self.memory,
user_request=self.user_request,
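The diffs above (and the financial-report and form-filling diffs below) move chat_history, user_msg, and the stream flag out of the constructors and into the StartEvent. As a rough illustration, here is a hedged sketch of how a caller might invoke one of these standardized workflows after this change; it assumes llama_index's Workflow.run() forwards keyword arguments into the StartEvent read via ev.get(...) above, and the question text and prior messages are made up.

```python
import asyncio

from llama_index.core.base.llms.types import ChatMessage, MessageRole

# create_workflow is the factory shown in the diff above.
workflow = create_workflow()


async def answer(question: str) -> str:
    handler = workflow.run(
        user_msg=question,
        chat_history=[
            ChatMessage(role=MessageRole.USER, content="Earlier question"),
            ChatMessage(role=MessageRole.ASSISTANT, content="Earlier answer"),
        ],
        stream=False,  # with stream=False the workflow resolves to a full response
    )
    return str(await handler)


if __name__ == "__main__":
    print(asyncio.run(answer("Summarize the uploaded documents.")))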
@@ -1,13 +1,5 @@
from typing import Any, Dict, List, Optional

from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from app.engine.tools.query_engine import get_query_engine_tool
from app.workflows.events import AgentRunEvent
from app.workflows.tools import (
call_tools,
chat_with_tools,
)
from llama_index.core import Settings
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.llms.function_calling import FunctionCallingLLM
@@ -22,9 +14,17 @@
step,
)

from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from app.engine.tools.query_engine import get_query_engine_tool
from app.workflows.events import AgentRunEvent
from app.workflows.tools import (
call_tools,
chat_with_tools,
)


def create_workflow(
chat_history: Optional[List[ChatMessage]] = None,
params: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Workflow:
@@ -45,7 +45,6 @@ def create_workflow(
query_engine_tool=query_engine_tool,
code_interpreter_tool=code_interpreter_tool,
document_generator_tool=document_generator_tool,
chat_history=chat_history,
)


@@ -91,6 +90,7 @@ class FinancialReportWorkflow(Workflow):
It's good to using appropriate tools for the user request and always use the information from the tools, don't make up anything yourself.
For the query engine tool, you should break down the user request into a list of queries and call the tool with the queries.
"""
stream: bool = True

def __init__(
self,
@@ -99,12 +99,10 @@ def __init__(
document_generator_tool: FunctionTool,
llm: Optional[FunctionCallingLLM] = None,
timeout: int = 360,
chat_history: Optional[List[ChatMessage]] = None,
system_prompt: Optional[str] = None,
):
super().__init__(timeout=timeout)
self.system_prompt = system_prompt or self._default_system_prompt
self.chat_history = chat_history or []
self.query_engine_tool = query_engine_tool
self.code_interpreter_tool = code_interpreter_tool
self.document_generator_tool = document_generator_tool
@@ -122,23 +120,26 @@ def __init__(
]
self.llm: FunctionCallingLLM = llm or Settings.llm
assert isinstance(self.llm, FunctionCallingLLM)
self.memory = ChatMemoryBuffer.from_defaults(
llm=self.llm, chat_history=self.chat_history
)
self.memory = ChatMemoryBuffer.from_defaults(llm=self.llm)

@step()
async def prepare_chat_history(self, ctx: Context, ev: StartEvent) -> InputEvent:
ctx.data["input"] = ev.input
self.stream = ev.get("stream", True)
user_msg = ev.get("user_msg")
chat_history = ev.get("chat_history")

if chat_history is not None:
self.memory.put_messages(chat_history)

# Add user message to memory
self.memory.put(ChatMessage(role=MessageRole.USER, content=user_msg))

if self.system_prompt:
system_msg = ChatMessage(
role=MessageRole.SYSTEM, content=self.system_prompt
)
self.memory.put(system_msg)

# Add user input to memory
self.memory.put(ChatMessage(role=MessageRole.USER, content=ev.input))

return InputEvent(input=self.memory.get())

@step()
@@ -160,8 +161,10 @@ async def handle_llm_input( # type: ignore
chat_history,
)
if not response.has_tool_calls():
# If no tool call, return the response generator
return StopEvent(result=response.generator)
if self.stream:
return StopEvent(result=response.generator)
else:
return StopEvent(result=await response.full_response())
# calling different tools at the same time is not supported at the moment
# add an error message to tell the AI to process step by step
if response.is_calling_different_tools():
@@ -25,7 +25,6 @@


def create_workflow(
chat_history: Optional[List[ChatMessage]] = None,
params: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Workflow:
@@ -45,7 +44,6 @@ def create_workflow(
query_engine_tool=query_engine_tool,
extractor_tool=extractor_tool, # type: ignore
filling_tool=filling_tool, # type: ignore
chat_history=chat_history,
)

return workflow
@@ -88,6 +86,7 @@ class FormFillingWorkflow(Workflow):
Only use provided data - never make up any information yourself. Fill N/A if an answer is not found.
If there is no query engine tool or the gathered information has many N/A values indicating the questions don't match the data, respond with a warning and ask the user to upload a different file or connect to a knowledge base.
"""
stream: bool = True

def __init__(
self,
@@ -96,12 +95,10 @@ def __init__(
filling_tool: FunctionTool,
llm: Optional[FunctionCallingLLM] = None,
timeout: int = 360,
chat_history: Optional[List[ChatMessage]] = None,
system_prompt: Optional[str] = None,
):
super().__init__(timeout=timeout)
self.system_prompt = system_prompt or self._default_system_prompt
self.chat_history = chat_history or []
self.query_engine_tool = query_engine_tool
self.extractor_tool = extractor_tool
self.filling_tool = filling_tool
@@ -113,26 +110,26 @@ def __init__(
self.llm: FunctionCallingLLM = llm or Settings.llm
if not isinstance(self.llm, FunctionCallingLLM):
raise ValueError("FormFillingWorkflow only supports FunctionCallingLLM.")
self.memory = ChatMemoryBuffer.from_defaults(
llm=self.llm, chat_history=self.chat_history
)
self.memory = ChatMemoryBuffer.from_defaults(llm=self.llm)

@step()
async def start(self, ctx: Context, ev: StartEvent) -> InputEvent:
ctx.data["input"] = ev.input
self.stream = ev.get("stream", True)
user_msg = ev.get("user_msg", "")
chat_history = ev.get("chat_history", [])

if chat_history:
self.memory.put_messages(chat_history)

self.memory.put(ChatMessage(role=MessageRole.USER, content=user_msg))

if self.system_prompt:
system_msg = ChatMessage(
role=MessageRole.SYSTEM, content=self.system_prompt
)
self.memory.put(system_msg)

user_input = ev.input
user_msg = ChatMessage(role=MessageRole.USER, content=user_input)
self.memory.put(user_msg)

chat_history = self.memory.get()
return InputEvent(input=chat_history)
return InputEvent(input=self.memory.get())

@step()
async def handle_llm_input( # type: ignore
@@ -150,7 +147,10 @@ async def handle_llm_input( # type: ignore
chat_history,
)
if not response.has_tool_calls():
return StopEvent(result=response.generator)
if self.stream:
return StopEvent(result=response.generator)
else:
return StopEvent(result=await response.full_response())
# calling different tools at the same time is not supported at the moment
# add an error message to tell the AI to process step by step
if response.is_calling_different_tools():
Empty file.
32 changes: 32 additions & 0 deletions templates/components/multiagent/python/app/api/callbacks/base.py
@@ -0,0 +1,32 @@
import logging
from abc import ABC, abstractmethod
from typing import Any

logger = logging.getLogger("uvicorn")


class EventCallback(ABC):
"""
Base class for event callbacks during event streaming.
"""

async def run(self, event: Any) -> Any:
"""
Called for each event in the stream.
Default behavior: pass through the event unchanged.
"""
return event

async def on_complete(self, final_response: str) -> Any:
"""
Called when the stream is complete.
Default behavior: return None.
"""
return None

@abstractmethod
def from_default(self, *args, **kwargs) -> "EventCallback":
"""
Create a new instance of the processor from default values.
"""
pass
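
For orientation, a minimal custom callback following this EventCallback contract might look like the sketch below; the class name and the logging behavior are illustrative assumptions, not part of this PR.

```python
import logging
from typing import Any

from app.api.callbacks.base import EventCallback

logger = logging.getLogger("uvicorn")


class LogFinalResponse(EventCallback):
    """Illustrative callback: passes events through and logs the final response size."""

    async def run(self, event: Any) -> Any:
        # Keep the default pass-through behavior; the event is not modified.
        return event

    async def on_complete(self, final_response: str) -> Any:
        logger.info("Stream completed with %d characters", len(final_response))
        return None

    @classmethod
    def from_default(cls) -> "LogFinalResponse":
        return cls()
```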
@@ -0,0 +1,42 @@
import logging
from typing import Any, List

from fastapi import BackgroundTasks
from llama_index.core.schema import NodeWithScore

from app.api.callbacks.base import EventCallback

logger = logging.getLogger("uvicorn")


class LlamaCloudFileDownload(EventCallback):
"""
Processor for handling LlamaCloud file downloads from source nodes.
Only work if LlamaCloud service code is available.
"""

def __init__(self, background_tasks: BackgroundTasks):
self.background_tasks = background_tasks

async def run(self, event: Any) -> Any:
if hasattr(event, "to_response"):
event_response = event.to_response()
if event_response.get("type") == "sources" and hasattr(event, "nodes"):
await self._process_response_nodes(event.nodes)
return event

async def _process_response_nodes(self, source_nodes: List[NodeWithScore]):
try:
from app.engine.service import LLamaCloudFileService # type: ignore

LLamaCloudFileService.download_files_from_nodes(
source_nodes, self.background_tasks
)
except ImportError:
pass

@classmethod
def from_default(
cls, background_tasks: BackgroundTasks
) -> "LlamaCloudFileDownload":
return cls(background_tasks=background_tasks)
@@ -0,0 +1,34 @@
import logging
from typing import Any

from app.api.callbacks.base import EventCallback
from app.api.routers.models import ChatData
from app.api.services.suggestion import NextQuestionSuggestion

logger = logging.getLogger("uvicorn")


class SuggestNextQuestions(EventCallback):
"""Processor for generating next question suggestions."""

def __init__(self, chat_data: ChatData):
self.chat_data = chat_data
self.accumulated_text = ""

async def on_complete(self, final_response: str) -> Any:
if final_response == "":
return None

questions = await NextQuestionSuggestion.suggest_next_questions(
self.chat_data.messages, final_response
)
if questions:
return {
"type": "suggested_questions",
"data": questions,
}
return None

@classmethod
def from_default(cls, chat_data: ChatData) -> "SuggestNextQuestions":
return cls(chat_data=chat_data)
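
This PR excerpt does not show the stream handler that drives these callbacks, so the loop below is only an assumption about how run() and on_complete() could be wired together (including how the two concrete callbacks above might be constructed); it is not the project's actual handler.

```python
from typing import Any, AsyncGenerator, List

from app.api.callbacks.base import EventCallback

# Hypothetical construction, mirroring the from_default() factories above:
# callbacks = [
#     LlamaCloudFileDownload.from_default(background_tasks),
#     SuggestNextQuestions.from_default(chat_data),
# ]


async def apply_callbacks(
    events: AsyncGenerator[Any, None],
    callbacks: List[EventCallback],
) -> AsyncGenerator[Any, None]:
    """Run each event through every callback, then emit any on_complete() results."""
    final_response = ""
    async for event in events:
        for callback in callbacks:
            event = await callback.run(event)
        if isinstance(event, str):
            # Assumption: plain string events are text deltas of the final answer.
            final_response += event
        yield event
    for callback in callbacks:
        extra = await callback.on_complete(final_response)
        if extra is not None:
            yield extra
```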