Skip to content

Commit

Permalink
rename writer to deep research
Browse files Browse the repository at this point in the history
  • Loading branch information
leehuwuj committed Jan 21, 2025
1 parent ade3e2b commit 1162efd
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@

from app.engine.index import IndexConfig, get_index
from app.workflows.agents import plan_research, research, write_report
from app.workflows.events import SourceNodesEvent
from app.workflows.models import (
CollectAnswersEvent,
DataEvent,
PlanResearchEvent,
ReportEvent,
ResearchEvent,
WriteReportEvent,
)
from app.workflows.events import SourceNodesEvent

logger = logging.getLogger("uvicorn")
logger.setLevel(logging.INFO)
Expand All @@ -43,16 +43,16 @@ def create_workflow(
"Index is not found. Try run generation script to create the index first."
)

return WriterWorkflow(
return DeepResearchWorkflow(
index=index,
chat_history=chat_history,
timeout=120.0,
)


class WriterWorkflow(Workflow):
class DeepResearchWorkflow(Workflow):
"""
A workflow to research and write a post for a specific topic.
A workflow to research and analyze documents from multiple perspectives and write a comprehensive report.
Requirements:
- An indexed documents containing the knowledge base related to the topic
Expand All @@ -61,7 +61,7 @@ class WriterWorkflow(Workflow):
1. Retrieve information from the knowledge base
2. Analyze the retrieved information and provide questions for answering
3. Answer the questions
4. Write the post based on the research results
4. Write the report based on the research results
"""

memory: SimpleComposableMemory
Expand Down Expand Up @@ -104,7 +104,7 @@ def retrieve(self, ctx: Context, ev: StartEvent) -> PlanResearchEvent:
)
ctx.write_event_to_stream(
DataEvent(
type="writer_card",
type="deep_research_event",
data={
"event": "retrieve",
"state": "inprogress",
Expand All @@ -118,7 +118,7 @@ def retrieve(self, ctx: Context, ev: StartEvent) -> PlanResearchEvent:
self.context_nodes.extend(nodes)
ctx.write_event_to_stream(
DataEvent(
type="writer_card",
type="deep_research_event",
data={
"event": "retrieve",
"state": "done",
Expand All @@ -139,14 +139,14 @@ def retrieve(self, ctx: Context, ev: StartEvent) -> PlanResearchEvent:
@step
async def analyze(
self, ctx: Context, ev: PlanResearchEvent
) -> ResearchEvent | WriteReportEvent | StopEvent:
) -> ResearchEvent | ReportEvent | StopEvent:
"""
Analyze the retrieved information
"""
logger.info("Analyzing the retrieved information")
ctx.write_event_to_stream(
DataEvent(
type="writer_card",
type="deep_research_event",
data={
"event": "analyze",
"state": "inprogress",
Expand All @@ -169,7 +169,7 @@ async def analyze(
content="No more idea to analyze. We should report the answers.",
)
)
ctx.send_event(WriteReportEvent())
ctx.send_event(ReportEvent())
else:
await ctx.set("n_questions", len(res.research_questions))
self.memory.put(
Expand All @@ -183,7 +183,7 @@ async def analyze(
question_id = str(uuid.uuid4())
ctx.write_event_to_stream(
DataEvent(
type="writer_card",
type="deep_research_event",
data={
"event": "answer",
"state": "pending",
Expand All @@ -202,7 +202,7 @@ async def analyze(
)
ctx.write_event_to_stream(
DataEvent(
type="writer_card",
type="deep_research_event",
data={
"event": "analyze",
"state": "done",
Expand All @@ -218,7 +218,7 @@ async def answer(self, ctx: Context, ev: ResearchEvent) -> CollectAnswersEvent:
"""
ctx.write_event_to_stream(
DataEvent(
type="writer_card",
type="deep_research_event",
data={
"event": "answer",
"state": "inprogress",
Expand All @@ -237,7 +237,7 @@ async def answer(self, ctx: Context, ev: ResearchEvent) -> CollectAnswersEvent:
answer = f"Got error when answering the question: {ev.question}"
ctx.write_event_to_stream(
DataEvent(
type="writer_card",
type="deep_research_event",
data={
"event": "answer",
"state": "done",
Expand All @@ -257,7 +257,7 @@ async def answer(self, ctx: Context, ev: ResearchEvent) -> CollectAnswersEvent:
@step
async def collect_answers(
self, ctx: Context, ev: CollectAnswersEvent
) -> WriteReportEvent:
) -> ReportEvent:
"""
Collect answers to all questions
"""
Expand Down Expand Up @@ -285,7 +285,7 @@ async def collect_answers(
return PlanResearchEvent()

@step
async def report(self, ctx: Context, ev: WriteReportEvent) -> StopEvent:
async def report(self, ctx: Context, ev: ReportEvent) -> StopEvent:
"""
Report the answers
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ class CollectAnswersEvent(Event):
answer: str


class WriteReportEvent(Event):
class ReportEvent(Event):
pass


# Events that are streamed to the frontend and rendered there
class WriterEventData(BaseModel):
class DeepResearchEventData(BaseModel):
event: Literal["retrieve", "analyze", "answer"]
state: Literal["pending", "inprogress", "done", "error"]
id: Optional[str] = None
Expand All @@ -36,8 +36,8 @@ class WriterEventData(BaseModel):


class DataEvent(Event):
type: Literal["writer_card"]
data: WriterEventData
type: Literal["deep_research_event"]
data: DeepResearchEventData

def to_response(self):
return self.model_dump()
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import {
useChatMessage,
useChatUI,
} from "@llamaindex/chat-ui";
import { DeepResearchCard } from "./custom/deep-research-card";
import { Markdown } from "./custom/markdown";
import { WriterCard } from "./custom/writer-card";
import { ToolAnnotations } from "./tools/chat-tools";

export function ChatMessageContent() {
Expand All @@ -26,7 +26,7 @@ export function ChatMessageContent() {
// add the writer card
{
position: ContentPosition.CHAT_EVENTS,
component: <WriterCard message={message} />,
component: <DeepResearchCard message={message} />,
},
{
// add the tool annotations after events
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ import { Markdown } from "./markdown";
// Streaming event types
type EventState = "pending" | "inprogress" | "done" | "error";

type WriterEvent = {
type: "writer_card";
type DeepResearchEvent = {
type: "deep_research_event";
data: {
event: "retrieve" | "analyze" | "answer";
state: EventState;
Expand All @@ -42,7 +42,7 @@ type QuestionState = {
isOpen: boolean;
};

type WriterState = {
type DeepResearchCardState = {
retrieve: {
state: EventState | null;
};
Expand All @@ -52,7 +52,7 @@ type WriterState = {
};
};

interface WriterCardProps {
interface DeepResearchCardProps {
message: Message;
className?: string;
}
Expand All @@ -66,9 +66,9 @@ const stateIcon: Record<EventState, React.ReactNode> = {

// Transform the state based on the event without mutations
const transformState = (
state: WriterState,
event: WriterEvent,
): WriterState => {
state: DeepResearchCardState,
event: DeepResearchEvent,
): DeepResearchCardState => {
switch (event.data.event) {
case "answer": {
const { id, question, answer } = event.data;
Expand Down Expand Up @@ -119,35 +119,46 @@ const transformState = (
}
};

// Convert writer events to state
const writeEventsToState = (events: WriterEvent[] | undefined): WriterState => {
// Convert deep research events to state
const deepResearchEventsToState = (
events: DeepResearchEvent[] | undefined,
): DeepResearchCardState => {
if (!events?.length) {
return {
retrieve: { state: null },
analyze: { state: null, questions: [] },
};
}

const initialState: WriterState = {
const initialState: DeepResearchCardState = {
retrieve: { state: null },
analyze: { state: null, questions: [] },
};

return events.reduce(
(acc: WriterState, event: WriterEvent) => transformState(acc, event),
(acc: DeepResearchCardState, event: DeepResearchEvent) =>
transformState(acc, event),
initialState,
);
};

export function WriterCard({ message, className }: WriterCardProps) {
const writerEvents = message.annotations as WriterEvent[] | undefined;
const hasWriterEvents = writerEvents?.some(
(event) => event.type === "writer_card",
export function DeepResearchCard({
message,
className,
}: DeepResearchCardProps) {
const deepResearchEvents = message.annotations as
| DeepResearchEvent[]
| undefined;
const hasDeepResearchEvents = deepResearchEvents?.some(
(event) => event.type === "deep_research_event",
);

const state = useMemo(() => writeEventsToState(writerEvents), [writerEvents]);
const state = useMemo(
() => deepResearchEventsToState(deepResearchEvents),
[deepResearchEvents],
);

if (!hasWriterEvents) {
if (!hasDeepResearchEvents) {
return null;
}

Expand Down

0 comments on commit 1162efd

Please sign in to comment.