Skip to content

Commit

Permalink
rename writer to deep research
Browse files Browse the repository at this point in the history
  • Loading branch information
leehuwuj committed Jan 21, 2025
1 parent ade3e2b commit c4020cb
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ The workflow writes blog posts based on documents in the [data](./data) director
After starting the server, go to [http://localhost:8000](http://localhost:8000) and send a message to the agent to write a blog post.
E.g: "Write a post about AI investment in 2024"

To update the workflow, you can edit the [writer.py](./app/workflows/writer.py) file.
To update the workflow, you can edit the [deep_research.py](./app/workflows/deep_research.py) file.

By default, the workflow retrieves 10 results from your documents. To customize the amount of information covered in the answer, you can adjust the `TOP_K` environment variable in the `.env` file. A higher value will retrieve more results from your documents, potentially providing more comprehensive answers.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@

from app.engine.index import IndexConfig, get_index
from app.workflows.agents import plan_research, research, write_report
from app.workflows.events import SourceNodesEvent
from app.workflows.models import (
CollectAnswersEvent,
DataEvent,
PlanResearchEvent,
ReportEvent,
ResearchEvent,
WriteReportEvent,
)
from app.workflows.events import SourceNodesEvent

logger = logging.getLogger("uvicorn")
logger.setLevel(logging.INFO)
Expand All @@ -43,16 +43,16 @@ def create_workflow(
"Index is not found. Try run generation script to create the index first."
)

return WriterWorkflow(
return DeepResearchWorkflow(
index=index,
chat_history=chat_history,
timeout=120.0,
)


class WriterWorkflow(Workflow):
class DeepResearchWorkflow(Workflow):
"""
A workflow to research and write a post for a specific topic.
A workflow to research and analyze documents from multiple perspectives and write a comprehensive report.
Requirements:
- An indexed documents containing the knowledge base related to the topic
Expand All @@ -61,7 +61,7 @@ class WriterWorkflow(Workflow):
1. Retrieve information from the knowledge base
2. Analyze the retrieved information and provide questions for answering
3. Answer the questions
4. Write the post based on the research results
4. Write the report based on the research results
"""

memory: SimpleComposableMemory
Expand Down Expand Up @@ -104,7 +104,7 @@ def retrieve(self, ctx: Context, ev: StartEvent) -> PlanResearchEvent:
)
ctx.write_event_to_stream(
DataEvent(
type="writer_card",
type="deep_research_event",
data={
"event": "retrieve",
"state": "inprogress",
Expand All @@ -118,7 +118,7 @@ def retrieve(self, ctx: Context, ev: StartEvent) -> PlanResearchEvent:
self.context_nodes.extend(nodes)
ctx.write_event_to_stream(
DataEvent(
type="writer_card",
type="deep_research_event",
data={
"event": "retrieve",
"state": "done",
Expand All @@ -139,14 +139,14 @@ def retrieve(self, ctx: Context, ev: StartEvent) -> PlanResearchEvent:
@step
async def analyze(
self, ctx: Context, ev: PlanResearchEvent
) -> ResearchEvent | WriteReportEvent | StopEvent:
) -> ResearchEvent | ReportEvent | StopEvent:
"""
Analyze the retrieved information
"""
logger.info("Analyzing the retrieved information")
ctx.write_event_to_stream(
DataEvent(
type="writer_card",
type="deep_research_event",
data={
"event": "analyze",
"state": "inprogress",
Expand All @@ -169,7 +169,7 @@ async def analyze(
content="No more idea to analyze. We should report the answers.",
)
)
ctx.send_event(WriteReportEvent())
ctx.send_event(ReportEvent())
else:
await ctx.set("n_questions", len(res.research_questions))
self.memory.put(
Expand All @@ -183,7 +183,7 @@ async def analyze(
question_id = str(uuid.uuid4())
ctx.write_event_to_stream(
DataEvent(
type="writer_card",
type="deep_research_event",
data={
"event": "answer",
"state": "pending",
Expand All @@ -202,7 +202,7 @@ async def analyze(
)
ctx.write_event_to_stream(
DataEvent(
type="writer_card",
type="deep_research_event",
data={
"event": "analyze",
"state": "done",
Expand All @@ -218,7 +218,7 @@ async def answer(self, ctx: Context, ev: ResearchEvent) -> CollectAnswersEvent:
"""
ctx.write_event_to_stream(
DataEvent(
type="writer_card",
type="deep_research_event",
data={
"event": "answer",
"state": "inprogress",
Expand All @@ -237,7 +237,7 @@ async def answer(self, ctx: Context, ev: ResearchEvent) -> CollectAnswersEvent:
answer = f"Got error when answering the question: {ev.question}"
ctx.write_event_to_stream(
DataEvent(
type="writer_card",
type="deep_research_event",
data={
"event": "answer",
"state": "done",
Expand All @@ -257,7 +257,7 @@ async def answer(self, ctx: Context, ev: ResearchEvent) -> CollectAnswersEvent:
@step
async def collect_answers(
self, ctx: Context, ev: CollectAnswersEvent
) -> WriteReportEvent:
) -> ReportEvent:
"""
Collect answers to all questions
"""
Expand Down Expand Up @@ -285,7 +285,7 @@ async def collect_answers(
return PlanResearchEvent()

@step
async def report(self, ctx: Context, ev: WriteReportEvent) -> StopEvent:
async def report(self, ctx: Context, ev: ReportEvent) -> StopEvent:
"""
Report the answers
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ class CollectAnswersEvent(Event):
answer: str


class WriteReportEvent(Event):
class ReportEvent(Event):
pass


# Events that are streamed to the frontend and rendered there
class WriterEventData(BaseModel):
class DeepResearchEventData(BaseModel):
event: Literal["retrieve", "analyze", "answer"]
state: Literal["pending", "inprogress", "done", "error"]
id: Optional[str] = None
Expand All @@ -36,8 +36,8 @@ class WriterEventData(BaseModel):


class DataEvent(Event):
type: Literal["writer_card"]
data: WriterEventData
type: Literal["deep_research_event"]
data: DeepResearchEventData

def to_response(self):
return self.model_dump()
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import {
useChatMessage,
useChatUI,
} from "@llamaindex/chat-ui";
import { DeepResearchCard } from "./custom/deep-research-card";
import { Markdown } from "./custom/markdown";
import { WriterCard } from "./custom/writer-card";
import { ToolAnnotations } from "./tools/chat-tools";

export function ChatMessageContent() {
Expand All @@ -23,10 +23,10 @@ export function ChatMessageContent() {
/>
),
},
// add the writer card
// add the deep research card
{
position: ContentPosition.CHAT_EVENTS,
component: <WriterCard message={message} />,
component: <DeepResearchCard message={message} />,
},
{
// add the tool annotations after events
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ import { Markdown } from "./markdown";
// Streaming event types
type EventState = "pending" | "inprogress" | "done" | "error";

type WriterEvent = {
type: "writer_card";
type DeepResearchEvent = {
type: "deep_research_event";
data: {
event: "retrieve" | "analyze" | "answer";
state: EventState;
Expand All @@ -42,7 +42,7 @@ type QuestionState = {
isOpen: boolean;
};

type WriterState = {
type DeepResearchCardState = {
retrieve: {
state: EventState | null;
};
Expand All @@ -52,7 +52,7 @@ type WriterState = {
};
};

interface WriterCardProps {
interface DeepResearchCardProps {
message: Message;
className?: string;
}
Expand All @@ -66,9 +66,9 @@ const stateIcon: Record<EventState, React.ReactNode> = {

// Transform the state based on the event without mutations
const transformState = (
state: WriterState,
event: WriterEvent,
): WriterState => {
state: DeepResearchCardState,
event: DeepResearchEvent,
): DeepResearchCardState => {
switch (event.data.event) {
case "answer": {
const { id, question, answer } = event.data;
Expand Down Expand Up @@ -119,35 +119,46 @@ const transformState = (
}
};

// Convert writer events to state
const writeEventsToState = (events: WriterEvent[] | undefined): WriterState => {
// Convert deep research events to state
const deepResearchEventsToState = (
events: DeepResearchEvent[] | undefined,
): DeepResearchCardState => {
if (!events?.length) {
return {
retrieve: { state: null },
analyze: { state: null, questions: [] },
};
}

const initialState: WriterState = {
const initialState: DeepResearchCardState = {
retrieve: { state: null },
analyze: { state: null, questions: [] },
};

return events.reduce(
(acc: WriterState, event: WriterEvent) => transformState(acc, event),
(acc: DeepResearchCardState, event: DeepResearchEvent) =>
transformState(acc, event),
initialState,
);
};

export function WriterCard({ message, className }: WriterCardProps) {
const writerEvents = message.annotations as WriterEvent[] | undefined;
const hasWriterEvents = writerEvents?.some(
(event) => event.type === "writer_card",
export function DeepResearchCard({
message,
className,
}: DeepResearchCardProps) {
const deepResearchEvents = message.annotations as
| DeepResearchEvent[]
| undefined;
const hasDeepResearchEvents = deepResearchEvents?.some(
(event) => event.type === "deep_research_event",
);

const state = useMemo(() => writeEventsToState(writerEvents), [writerEvents]);
const state = useMemo(
() => deepResearchEventsToState(deepResearchEvents),
[deepResearchEvents],
);

if (!hasWriterEvents) {
if (!hasDeepResearchEvents) {
return null;
}

Expand Down

0 comments on commit c4020cb

Please sign in to comment.