Skip to content

Commit

Permalink
feat: edge runtime maxToolRoundtrips support (#490)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonom authored Jul 13, 2024
1 parent c63fd8c commit 26244cc
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 31 deletions.
55 changes: 51 additions & 4 deletions packages/react/src/runtimes/edge/EdgeChatAdapter.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { ThreadAssistantContentPart, ThreadMessage } from "../../types";
import { ChatModelAdapter, ChatModelRunOptions } from "../local";
import { ChatModelRunResult } from "../local/ChatModelAdapter";
import { toCoreMessages } from "./converters/toCoreMessages";
Expand Down Expand Up @@ -27,12 +28,16 @@ export function asAsyncIterable<T>(
}
export type EdgeChatAdapterOptions = {
  /** URL of the edge runtime chat endpoint; POSTed to on every roundtrip. */
  api: string;
  /**
   * How many automatic tool-call roundtrips the adapter may perform after the
   * initial request before returning. Defaults to 1 (see `run`'s `?? 1`).
   */
  maxToolRoundtrips?: number;
};

export class EdgeChatAdapter implements ChatModelAdapter {
constructor(private options: EdgeChatAdapterOptions) {}

async run({ messages, abortSignal, config, onUpdate }: ChatModelRunOptions) {
async roundtrip(
initialContent: ThreadAssistantContentPart[],
{ messages, abortSignal, config, onUpdate }: ChatModelRunOptions,
) {
const result = await fetch(this.options.api, {
method: "POST",
headers: {
Expand All @@ -53,14 +58,56 @@ export class EdgeChatAdapter implements ChatModelAdapter {
.pipeThrough(chunkByLineStream())
.pipeThrough(assistantDecoderStream())
.pipeThrough(toolResultStream(config.tools))
.pipeThrough(runResultStream());
.pipeThrough(runResultStream(initialContent));

let message: ThreadMessage | undefined;
let update: ChatModelRunResult | undefined;
for await (update of asAsyncIterable(stream)) {
onUpdate(update);
message = onUpdate(update);
}
if (update === undefined)
throw new Error("No data received from Edge Runtime");
return update;

return [message, update] as const;
}

async run({ messages, abortSignal, config, onUpdate }: ChatModelRunOptions) {
  // Extra roundtrips permitted beyond the first request (default: 1).
  let remainingRoundtrips = this.options.maxToolRoundtrips ?? 1;
  // Token usage accumulated across every completed roundtrip.
  const totalUsage = {
    promptTokens: 0,
    completionTokens: 0,
  };

  let result;
  let assistantMessage;
  while (true) {
    // After the first roundtrip, append the streamed assistant message so the
    // model sees its own tool calls (and their results) in the history.
    const history = assistantMessage
      ? [...messages, assistantMessage]
      : messages;
    [assistantMessage, result] = await this.roundtrip(result?.content ?? [], {
      messages: history,
      abortSignal,
      config,
      onUpdate,
    });

    if (result.status?.type === "done") {
      totalUsage.promptTokens += result.status.usage?.promptTokens ?? 0;
      totalUsage.completionTokens += result.status.usage?.completionTokens ?? 0;
    }

    // Loop again only when the model stopped for tool calls, every tool call
    // already has a result, and the roundtrip budget is not exhausted. The
    // decrement is intentionally short-circuited behind the other checks.
    const continueRoundtrips =
      result.status?.type === "done" &&
      result.status.finishReason === "tool-calls" &&
      result.content.every((c) => c.type !== "tool-call" || !!c.result) &&
      remainingRoundtrips-- > 0;
    if (!continueRoundtrips) break;
  }

  // Report usage summed across all roundtrips rather than just the last one.
  if (result.status?.type === "done" && totalUsage.promptTokens > 0) {
    result = {
      ...result,
      status: {
        ...result.status,
        usage: totalUsage,
      },
    };
  }

  return result;
}
}
48 changes: 25 additions & 23 deletions packages/react/src/runtimes/edge/createEdgeRuntimeAPI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -128,29 +128,31 @@ export const createEdgeRuntimeAPI = ({
let serverStream = tees[1];

if (onFinish) {
serverStream = serverStream.pipeThrough(runResultStream()).pipeThrough(
new TransformStream({
transform(chunk) {
if (chunk.status?.type !== "done") return;
const resultingMessages = [
...messages,
{
role: "assistant",
content: chunk.content,
} as CoreAssistantMessage,
];
onFinish({
finishReason: chunk.status.finishReason!,
usage: chunk.status.usage!,
messages: resultingMessages,
logProbs: chunk.status.logprops,
warnings: streamResult.warnings,
rawCall: streamResult.rawCall,
rawResponse: streamResult.rawResponse,
});
},
}),
);
serverStream = serverStream
.pipeThrough(runResultStream([]))
.pipeThrough(
new TransformStream({
transform(chunk) {
if (chunk.status?.type !== "done") return;
const resultingMessages = [
...messages,
{
role: "assistant",
content: chunk.content,
} as CoreAssistantMessage,
];
onFinish({
finishReason: chunk.status.finishReason!,
usage: chunk.status.usage!,
messages: resultingMessages,
logProbs: chunk.status.logprops,
warnings: streamResult.warnings,
rawCall: streamResult.rawCall,
rawResponse: streamResult.rawResponse,
});
},
}),
);
}

// drain the server stream
Expand Down
5 changes: 3 additions & 2 deletions packages/react/src/runtimes/edge/streams/runResultStream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ import { ChatModelRunResult } from "../../local/ChatModelAdapter";
import { parsePartialJson } from "../partial-json/parse-partial-json";
import { LanguageModelV1StreamPart } from "@ai-sdk/provider";
import { ToolResultStreamPart } from "./toolResultStream";
import { ThreadAssistantContentPart } from "../../../types";

export function runResultStream() {
export function runResultStream(initialContent: ThreadAssistantContentPart[]) {
let message: ChatModelRunResult = {
content: [],
content: initialContent,
};
const currentToolCall = { toolCallId: "", argsText: "" };

Expand Down
2 changes: 1 addition & 1 deletion packages/react/src/runtimes/local/ChatModelAdapter.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ export type ChatModelRunOptions = {
messages: ThreadMessage[];
abortSignal: AbortSignal;
config: ModelConfig;
onUpdate: (result: ChatModelRunUpdate) => void;
onUpdate: (result: ChatModelRunUpdate) => ThreadMessage;
};

export type ChatModelAdapter = {
Expand Down
4 changes: 3 additions & 1 deletion packages/react/src/runtimes/local/LocalRuntime.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,10 @@ class LocalThreadRuntime implements ThreadRuntime {
try {
// Streams partial results onto the in-flight assistant message.
const updateHandler = ({ content }: ChatModelRunResult) => {
  message.content = content;
  // Clone so the repository and subscribers observe a fresh object identity
  // on every update (avoids mutating a previously shared reference).
  const newMessage = { ...message };
  this.repository.addOrUpdateMessage(parentId, newMessage);
  this.notifySubscribers();
  // Returned to the adapter, which threads this message into the next
  // tool roundtrip's history (see ChatModelRunOptions.onUpdate).
  return newMessage;
};
const result = await this.adapter.run({
messages,
Expand Down

0 comments on commit 26244cc

Please sign in to comment.