feat: support multi agent for ts (#300)
---------
Co-authored-by: Marcus Schiesser <[email protected]>
thucpn authored Sep 26, 2024
1 parent 70f7dca commit ef070c0
Showing 15 changed files with 638 additions and 191 deletions.
5 changes: 5 additions & 0 deletions .changeset/yellow-jokes-protect.md
@@ -0,0 +1,5 @@
---
"create-llama": patch
---

Add multi-agent template for TypeScript
12 changes: 6 additions & 6 deletions e2e/multiagent_template.spec.ts
@@ -10,19 +10,19 @@ import type {
} from "../helpers";
import { createTestDir, runCreateLlama, type AppType } from "./utils";

-const templateFramework: TemplateFramework = "fastapi";
+const templateFramework: TemplateFramework = process.env.FRAMEWORK
+  ? (process.env.FRAMEWORK as TemplateFramework)
+  : "fastapi";
const dataSource: string = "--example-file";
const templateUI: TemplateUI = "shadcn";
const templatePostInstallAction: TemplatePostInstallAction = "runApp";
const appType: AppType = "--frontend";
const appType: AppType = templateFramework === "nextjs" ? "" : "--frontend";
const userMessage = "Write a blog post about physical standards for letters";

test.describe(`Test multiagent template ${templateFramework} ${dataSource} ${templateUI} ${appType} ${templatePostInstallAction}`, async () => {
test.skip(
-    process.platform !== "linux" ||
-      process.env.FRAMEWORK !== "fastapi" ||
-      process.env.DATASOURCE === "--no-files",
-    "The multiagent template currently only works with FastAPI and files. We also only run on Linux to speed up tests.",
+    process.platform !== "linux" || process.env.DATASOURCE === "--no-files",
+    "The multiagent template currently only works with files. We also only run on Linux to speed up tests.",
);
let port: number;
let externalPort: number;
32 changes: 30 additions & 2 deletions helpers/typescript.ts
@@ -33,8 +33,7 @@ export const installTSTemplate = async ({
* Copy the template files to the target directory.
*/
console.log("\nInitializing project with template:", template, "\n");
-  const type = template === "multiagent" ? "streaming" : template; // use nextjs streaming template for multiagent
-  const templatePath = path.join(templatesDir, "types", type, framework);
+  const templatePath = path.join(templatesDir, "types", "streaming", framework);
const copySource = ["**"];

await copy(copySource, root, {
@@ -124,6 +123,30 @@ export const installTSTemplate = async ({
cwd: path.join(compPath, "vectordbs", "typescript", vectorDb ?? "none"),
});

if (template === "multiagent") {
const multiagentPath = path.join(compPath, "multiagent", "typescript");

// copy workflow code for multiagent template
await copy("**", path.join(root, relativeEngineDestPath, "workflow"), {
parents: true,
cwd: path.join(multiagentPath, "workflow"),
});

if (framework === "nextjs") {
// patch route.ts file
await copy("**", path.join(root, relativeEngineDestPath), {
parents: true,
cwd: path.join(multiagentPath, "nextjs"),
});
} else if (framework === "express") {
// patch chat.controller.ts file
await copy("**", path.join(root, relativeEngineDestPath), {
parents: true,
cwd: path.join(multiagentPath, "express"),
});
}
}

// copy loader component (TS only supports llama_parse and file for now)
const loaderFolder = useLlamaParse ? "llama_parse" : "file";
await copy("**", enginePath, {
@@ -145,6 +168,11 @@ export const installTSTemplate = async ({
cwd: path.join(compPath, "engines", "typescript", engine),
});

// copy settings to engine folder
await copy("**", enginePath, {
cwd: path.join(compPath, "settings", "typescript"),
});

/**
* Copy the selected UI files to the target directory and reference it.
*/
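Note: for orientation, a sketch of where these copy steps land in a generated Next.js multiagent project. The layout is an assumption inferred from the imports in route.ts below (./engine/settings, ./workflow/factory); the value of relativeEngineDestPath itself is not shown in this diff.

// Hypothetical generated layout (assumption):
//
// app/api/chat/
// ├── engine/
// │   ├── index.ts      <- vector-store code from components/vectordbs
// │   └── settings.ts   <- copied from components/settings/typescript
// ├── workflow/
// │   ├── factory.ts    <- copied from components/multiagent/typescript/workflow
// │   ├── agents.ts
// │   └── ...
// └── route.ts          <- patched from components/multiagent/typescript/nextjs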
5 changes: 1 addition & 4 deletions questions.ts
@@ -410,10 +410,7 @@ export const askQuestions = async (
return; // early return - no further questions needed for llamapack projects
}

-  if (program.template === "multiagent") {
-    // TODO: multi-agents currently only supports FastAPI
-    program.framework = preferences.framework = "fastapi";
-  } else if (program.template === "extractor") {
+  if (program.template === "extractor") {
// Extractor template only supports FastAPI, empty data sources, and llamacloud
// So we just use example file for extractor template, this allows user to choose vector database later
program.dataSources = [EXAMPLE_FILE];
41 changes: 41 additions & 0 deletions templates/components/multiagent/typescript/express/chat.controller.ts
@@ -0,0 +1,41 @@
import { StopEvent } from "@llamaindex/core/workflow";
import { Message, streamToResponse } from "ai";
import { Request, Response } from "express";
import { ChatMessage, ChatResponseChunk } from "llamaindex";
import { createWorkflow } from "./workflow/factory";
import { toDataStream, workflowEventsToStreamData } from "./workflow/stream";

export const chat = async (req: Request, res: Response) => {
try {
const { messages }: { messages: Message[] } = req.body;
const userMessage = messages?.pop();
if (!messages || !userMessage || userMessage.role !== "user") {
return res.status(400).json({
error:
"messages are required in the request body and the last message must be from the user",
});
}

const chatHistory = messages as ChatMessage[];
const agent = createWorkflow(chatHistory);
const result = agent.run<AsyncGenerator<ChatResponseChunk>>(
userMessage.content,
) as unknown as Promise<StopEvent<AsyncGenerator<ChatResponseChunk>>>;

// convert the workflow events to a vercel AI stream data object
const agentStreamData = await workflowEventsToStreamData(
agent.streamEvents(),
);
// convert the workflow result to a vercel AI content stream
const stream = toDataStream(result, {
onFinal: () => agentStreamData.close(),
});

return streamToResponse(stream, res, {}, agentStreamData);
} catch (error) {
console.error("[LlamaIndex]", error);
return res.status(500).json({
detail: (error as Error).message,
});
}
};
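Note: a minimal sketch of mounting this controller in an Express app. The import path and the /api/chat route are assumptions; the generated server wiring is not part of this diff.

import express from "express";
// Hypothetical location of the generated controller (assumption):
import { chat } from "./src/controllers/chat.controller";

const app = express();
app.use(express.json());
// The chat UI posts { messages } to this endpoint (assumption: /api/chat).
app.post("/api/chat", chat);
app.listen(Number(process.env.PORT ?? 8000));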
57 changes: 57 additions & 0 deletions templates/components/multiagent/typescript/nextjs/route.ts
@@ -0,0 +1,57 @@
import { initObservability } from "@/app/observability";
import { StopEvent } from "@llamaindex/core/workflow";
import { Message, StreamingTextResponse } from "ai";
import { ChatMessage, ChatResponseChunk } from "llamaindex";
import { NextRequest, NextResponse } from "next/server";
import { initSettings } from "./engine/settings";
import { createWorkflow } from "./workflow/factory";
import { toDataStream, workflowEventsToStreamData } from "./workflow/stream";

initObservability();
initSettings();

export const runtime = "nodejs";
export const dynamic = "force-dynamic";

export async function POST(request: NextRequest) {
try {
const body = await request.json();
const { messages }: { messages: Message[] } = body;
const userMessage = messages?.pop();
if (!messages || !userMessage || userMessage.role !== "user") {
return NextResponse.json(
{
error:
"messages are required in the request body and the last message must be from the user",
},
{ status: 400 },
);
}

const chatHistory = messages as ChatMessage[];
const agent = createWorkflow(chatHistory);
// TODO: fix type in agent.run in LITS
const result = agent.run<AsyncGenerator<ChatResponseChunk>>(
userMessage.content,
) as unknown as Promise<StopEvent<AsyncGenerator<ChatResponseChunk>>>;
// convert the workflow events to a vercel AI stream data object
const agentStreamData = await workflowEventsToStreamData(
agent.streamEvents(),
);
// convert the workflow result to a vercel AI content stream
const stream = toDataStream(result, {
onFinal: () => agentStreamData.close(),
});
return new StreamingTextResponse(stream, {}, agentStreamData);
} catch (error) {
console.error("[LlamaIndex]", error);
return NextResponse.json(
{
detail: (error as Error).message,
},
{
status: 500,
},
);
}
}
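Note: a client-side sketch of consuming this endpoint. It assumes the route is served at /api/chat; the response body is a Vercel AI data stream, read here as raw text.

// Minimal consumer sketch (assumption: route mounted at /api/chat).
const res = await fetch("/api/chat", {
  method: "POST",
  headers: { "Content-Type": "application/json" },
  body: JSON.stringify({
    messages: [{ role: "user", content: "Write a blog post about letters" }],
  }),
});

// Read the streamed response incrementally.
const reader = res.body!.getReader();
const decoder = new TextDecoder();
while (true) {
  const { done, value } = await reader.read();
  if (done) break;
  process.stdout.write(decoder.decode(value, { stream: true }));
}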
51 changes: 51 additions & 0 deletions templates/components/multiagent/typescript/workflow/agents.ts
@@ -0,0 +1,51 @@
import { ChatMessage, QueryEngineTool } from "llamaindex";
import { getDataSource } from "../engine";
import { FunctionCallingAgent } from "./single-agent";

const getQueryEngineTool = async () => {
const index = await getDataSource();
if (!index) {
throw new Error(
"StorageContext is empty - call 'npm run generate' to generate the storage first.",
);
}

const topK = process.env.TOP_K ? parseInt(process.env.TOP_K) : undefined;
return new QueryEngineTool({
queryEngine: index.asQueryEngine({
similarityTopK: topK,
}),
metadata: {
name: "query_index",
description: `Use this tool to retrieve information about the text corpus from the index.`,
},
});
};

export const createResearcher = async (chatHistory: ChatMessage[]) => {
return new FunctionCallingAgent({
name: "researcher",
tools: [await getQueryEngineTool()],
systemPrompt:
"You are a researcher agent. You are given a researching task. You must use your tools to complete the research.",
chatHistory,
});
};

export const createWriter = (chatHistory: ChatMessage[]) => {
return new FunctionCallingAgent({
name: "writer",
systemPrompt:
"You are an expert in writing blog posts. You are given a task to write a blog post. Don't make up any information yourself.",
chatHistory,
});
};

export const createReviewer = (chatHistory: ChatMessage[]) => {
return new FunctionCallingAgent({
name: "reviewer",
systemPrompt:
"You are an expert in reviewing blog posts. You are given a task to review a blog post. Review the post for logical inconsistencies, ask critical questions, and provide suggestions for improvement. Furthermore, proofread the post for grammar and spelling errors. Only if the post is good enough for publishing, then you MUST return 'The post is good.'. In all other cases return your review.",
chatHistory,
});
};
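Note: these factories wrap FunctionCallingAgent from ./single-agent, which is not shown in this diff. A sketch of driving one agent directly, mirroring the review step in factory.ts below:

import { StartEvent } from "@llamaindex/core/workflow";
import { createReviewer } from "./agents";
import { AgentInput } from "./type";

// Run the reviewer standalone with an empty chat history (sketch).
const reviewer = createReviewer([]);
const result = await reviewer.run(
  new StartEvent<AgentInput>({ input: { message: "Review this draft: ..." } }),
);
console.log(result.data.result);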
133 changes: 133 additions & 0 deletions templates/components/multiagent/typescript/workflow/factory.ts
@@ -0,0 +1,133 @@
import {
Context,
StartEvent,
StopEvent,
Workflow,
WorkflowEvent,
} from "@llamaindex/core/workflow";
import { ChatMessage, ChatResponseChunk } from "llamaindex";
import { createResearcher, createReviewer, createWriter } from "./agents";
import { AgentInput, AgentRunEvent } from "./type";

const TIMEOUT = 360 * 1000;
const MAX_ATTEMPTS = 2;

class ResearchEvent extends WorkflowEvent<{ input: string }> {}
class WriteEvent extends WorkflowEvent<{
input: string;
isGood: boolean;
}> {}
class ReviewEvent extends WorkflowEvent<{ input: string }> {}

export const createWorkflow = (chatHistory: ChatMessage[]) => {
const runAgent = async (
context: Context,
agent: Workflow,
input: AgentInput,
) => {
const run = agent.run(new StartEvent({ input }));
for await (const event of agent.streamEvents()) {
if (event.data instanceof AgentRunEvent) {
context.writeEventToStream(event.data);
}
}
return await run;
};

const start = async (context: Context, ev: StartEvent) => {
context.set("task", ev.data.input);
return new ResearchEvent({
input: `Research for this task: ${ev.data.input}`,
});
};

const research = async (context: Context, ev: ResearchEvent) => {
const researcher = await createResearcher(chatHistory);
const researchRes = await runAgent(context, researcher, {
message: ev.data.input,
});
const researchResult = researchRes.data.result;
return new WriteEvent({
input: `Write a blog post given this task: ${context.get("task")} using this research content: ${researchResult}`,
isGood: false,
});
};

const write = async (context: Context, ev: WriteEvent) => {
context.set("attempts", context.get("attempts", 0) + 1);
const tooManyAttempts = context.get("attempts") > MAX_ATTEMPTS;
if (tooManyAttempts) {
context.writeEventToStream(
new AgentRunEvent({
name: "writer",
msg: `Too many attempts (${MAX_ATTEMPTS}) to write the blog post. Proceeding with the current version.`,
}),
);
}

if (ev.data.isGood || tooManyAttempts) {
// The text is ready for publication; just use the writer to stream the output
const writer = createWriter(chatHistory);
const content = context.get("result");

return (await runAgent(context, writer, {
message: `Your blog post is ready for publication. Please respond with just the blog post. Blog post: \`\`\`${content}\`\`\``,
streaming: true,
})) as unknown as StopEvent<AsyncGenerator<ChatResponseChunk>>;
}

const writer = createWriter(chatHistory);
const writeRes = await runAgent(context, writer, {
message: ev.data.input,
});
const writeResult = writeRes.data.result;
context.set("result", writeResult); // store the last result
return new ReviewEvent({ input: writeResult });
};

const review = async (context: Context, ev: ReviewEvent) => {
const reviewer = createReviewer(chatHistory);
const reviewRes = await reviewer.run(
new StartEvent<AgentInput>({ input: { message: ev.data.input } }),
);
const reviewResult = reviewRes.data.result;
const oldContent = context.get("result");
const postIsGood = reviewResult.toLowerCase().includes("post is good");
context.writeEventToStream(
new AgentRunEvent({
name: "reviewer",
msg: `The post is ${postIsGood ? "" : "not "}good enough for publishing. Sending back to the writer${
postIsGood ? " for publication." : "."
}`,
}),
);
if (postIsGood) {
return new WriteEvent({
input: "",
isGood: true,
});
}

return new WriteEvent({
input: `Improve the writing of a given blog post by using a given review.
Blog post:
\`\`\`
${oldContent}
\`\`\`
Review:
\`\`\`
${reviewResult}
\`\`\``,
isGood: false,
});
};

const workflow = new Workflow({ timeout: TIMEOUT, validate: true });
workflow.addStep(StartEvent, start, { outputs: ResearchEvent });
workflow.addStep(ResearchEvent, research, { outputs: WriteEvent });
workflow.addStep(WriteEvent, write, { outputs: [ReviewEvent, StopEvent] });
workflow.addStep(ReviewEvent, review, { outputs: WriteEvent });

return workflow;
};
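Note: an end-to-end sketch of running the workflow outside an HTTP handler. The run/StopEvent cast mirrors the route handlers above, and the .data.result access and chunk.delta field follow the usage elsewhere in this diff; treat it as a sketch, not the template's API surface.

import { StopEvent } from "@llamaindex/core/workflow";
import { ChatResponseChunk } from "llamaindex";
import { createWorkflow } from "./factory";

// Create the workflow with an empty chat history and run it on a task (sketch).
const agent = createWorkflow([]);
const result = agent.run<AsyncGenerator<ChatResponseChunk>>(
  "Write a blog post about physical standards for letters",
) as unknown as Promise<StopEvent<AsyncGenerator<ChatResponseChunk>>>;

// The StopEvent wraps an async generator of chunks; print deltas as they arrive.
for await (const chunk of (await result).data.result) {
  process.stdout.write(chunk.delta);
}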