feat: Edge runtime serverside tools (#487)
Yonom authored Jul 13, 2024
1 parent 619edbd commit 0ec4695
Showing 5 changed files with 197 additions and 36 deletions.
packages/react/src/runtimes/edge/createEdgeRuntimeAPI.ts (192 changes: 161 additions & 31 deletions)

@@ -5,39 +5,183 @@ import {
   LanguageModelV1Prompt,
   LanguageModelV1CallOptions,
   LanguageModelV1CallWarning,
+  LanguageModelV1FinishReason,
+  LanguageModelV1LogProbs,
 } from "@ai-sdk/provider";
-import { CoreMessage } from "../../types/AssistantTypes";
+import { CoreAssistantMessage, CoreMessage } from "../../types/AssistantTypes";
 import { assistantEncoderStream } from "./streams/assistantEncoderStream";
 import { EdgeRuntimeRequestOptions } from "./EdgeRuntimeRequestOptions";
 import { toLanguageModelMessages } from "./converters/toLanguageModelMessages";
+import { z } from "zod";
+import { Tool } from "../../types";
+import { toLanguageModelTools } from "./converters/toLanguageModelTools";
+import {
+  toolResultStream,
+  ToolResultStreamPart,
+} from "./streams/toolResultStream";
+import { runResultStream } from "./streams/runResultStream";

+const LanguageModelSettingsSchema = z.object({
+  maxTokens: z.number().int().positive().optional(),
+  temperature: z.number().optional(),
+  topP: z.number().optional(),
+  presencePenalty: z.number().optional(),
+  frequencyPenalty: z.number().optional(),
+  seed: z.number().int().optional(),
+  headers: z.record(z.string().optional()).optional(),
+});
+
+type LanguageModelSettings = z.infer<typeof LanguageModelSettingsSchema>;
+
+type FinishResult = {
+  finishReason: LanguageModelV1FinishReason;
+  usage: {
+    promptTokens: number;
+    completionTokens: number;
+  };
+  logProbs?: LanguageModelV1LogProbs | undefined;
+  messages: CoreMessage[];
+  rawCall: {
+    rawPrompt: unknown;
+    rawSettings: Record<string, unknown>;
+  };
+  warnings?: LanguageModelV1CallWarning[] | undefined;
+  rawResponse?:
+    | {
+        headers?: Record<string, string>;
+      }
+    | undefined;
+};
+
+type CreateEdgeRuntimeAPIOptions = LanguageModelSettings & {
+  model: LanguageModelV1;
+  system?: string;
+  tools?: Record<string, Tool<any, any>>;
+  toolChoice?: LanguageModelV1ToolChoice;
+  onFinish?: (result: FinishResult) => void;
+};
+
+const voidStream = () => {
+  return new WritableStream({
+    abort(reason) {
+      console.error("Server stream processing aborted:", reason);
+    },
+  });
+};
+
+export const createEdgeRuntimeAPI = ({
+  model,
+  system: serverSystem,
+  tools: serverTools = {},
+  toolChoice,
+  onFinish,
+  ...unsafeSettings
+}: CreateEdgeRuntimeAPIOptions) => {
+  const settings = LanguageModelSettingsSchema.parse(unsafeSettings);
+  const lmServerTools = toLanguageModelTools(serverTools);
+  const hasServerTools = Object.values(serverTools).some((v) => !!v.execute);

-export const createEdgeRuntimeAPI = ({ model }: { model: LanguageModelV1 }) => {
   const POST = async (request: Request) => {
-    const { system, messages, tools } =
-      (await request.json()) as EdgeRuntimeRequestOptions;
+    const {
+      system: clientSystem,
+      tools: clientTools,
+      messages,
+    } = (await request.json()) as EdgeRuntimeRequestOptions;
+
+    const systemMessages = [];
+    if (serverSystem) systemMessages.push(serverSystem);
+    if (clientSystem) systemMessages.push(clientSystem);
+    const system = systemMessages.join("\n\n");
+
+    for (const clientTool of clientTools) {
+      if (serverTools?.[clientTool.name]) {
+        throw new Error(
+          `Tool ${clientTool.name} was defined in both the client and server tools. This is not allowed.`,
+        );
+      }
+    }
+
+    let stream: ReadableStream<ToolResultStreamPart>;
+    const streamResult = await streamMessage({
+      ...(settings as Partial<StreamMessageOptions>),
+
-    const { stream } = await streamMessage({
       model,
       abortSignal: request.signal,

-      ...(system ? { system } : undefined),
+      ...(!!system ? { system } : undefined),
       messages,
-      tools,
+      tools: lmServerTools.concat(clientTools),
+      ...(toolChoice ? { toolChoice } : undefined),
     });
+    stream = streamResult.stream;
+
+    // add tool results if we have server tools
+    const canExecuteTools = hasServerTools && toolChoice?.type !== "none";
+    if (canExecuteTools) {
+      stream = stream.pipeThrough(toolResultStream(serverTools));
+    }
+
+    if (canExecuteTools || onFinish) {
+      // tee the stream to process server tools and onFinish asap
+      const tees = stream.tee();
+      stream = tees[0];
+      let serverStream = tees[1];
+
+      if (onFinish) {
+        serverStream = serverStream.pipeThrough(runResultStream()).pipeThrough(
+          new TransformStream({
+            transform(chunk) {
+              if (chunk.status?.type !== "done") return;
+              const resultingMessages = [
+                ...messages,
+                {
+                  role: "assistant",
+                  content: chunk.content,
+                } as CoreAssistantMessage,
+              ];
+              onFinish({
+                finishReason: chunk.status.finishReason!,
+                usage: chunk.status.usage!,
+                messages: resultingMessages,
+                logProbs: chunk.status.logprops,
+                warnings: streamResult.warnings,
+                rawCall: streamResult.rawCall,
+                rawResponse: streamResult.rawResponse,
+              });
+            },
+          }),
+        );
+      }
+
-    return new Response(stream, {
-      headers: {
-        contentType: "text/plain; charset=utf-8",
+      // drain the server stream
+      serverStream.pipeTo(voidStream()).catch((e) => {
+        console.error("Server stream processing error:", e);
+      });
+    }
+
+    return new Response(
+      stream
+        .pipeThrough(assistantEncoderStream())
+        .pipeThrough(new TextEncoderStream()),
+      {
+        headers: {
+          contentType: "text/plain; charset=utf-8",
+        },
+      },
-    });
+    );
   };
   return { POST };
 };

-type StreamMessageResult = {
-  stream: ReadableStream<Uint8Array>;
-  warnings: LanguageModelV1CallWarning[] | undefined;
-  rawResponse: unknown;
+type StreamMessageOptions = Omit<
+  LanguageModelV1CallOptions,
+  "inputFormat" | "mode" | "prompt"
+> & {
+  model: LanguageModelV1;
+  system?: string;
+  messages: CoreMessage[];
+  tools?: LanguageModelV1FunctionTool[];
+  toolChoice?: LanguageModelV1ToolChoice;
 };

 async function streamMessage({
@@ -47,14 +191,8 @@ async function streamMessage({
   tools,
   toolChoice,
   ...options
-}: Omit<LanguageModelV1CallOptions, "inputFormat" | "mode" | "prompt"> & {
-  model: LanguageModelV1;
-  system?: string;
-  messages: CoreMessage[];
-  tools?: LanguageModelV1FunctionTool[];
-  toolChoice?: LanguageModelV1ToolChoice;
-}): Promise<StreamMessageResult> {
-  const { stream, warnings, rawResponse } = await model.doStream({
+}: StreamMessageOptions) {
+  return model.doStream({
     inputFormat: "messages",
     mode: {
       type: "regular",
@@ -64,14 +202,6 @@ async function streamMessage({
     prompt: convertToLanguageModelPrompt(system, messages),
     ...options,
   });
-
-  return {
-    stream: stream
-      .pipeThrough(assistantEncoderStream())
-      .pipeThrough(new TextEncoderStream()),
-    warnings,
-    rawResponse,
-  };
 }

 export function convertToLanguageModelPrompt(
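Taken together, the new options let tools be declared on the server and, when they carry an execute function, run there as well. Below is a minimal usage sketch; the import path, model setup, and the get_weather tool (name, schema, execute body) are illustrative assumptions, not part of this commit:

// Hypothetical route file, e.g. app/api/chat/route.ts
import { createEdgeRuntimeAPI } from "@assistant-ui/react/edge"; // assumed export path
import { openai } from "@ai-sdk/openai";
import { z } from "zod";

export const { POST } = createEdgeRuntimeAPI({
  model: openai("gpt-4o"),
  system: "You are a helpful assistant.",
  tools: {
    // the presence of `execute` marks this as a server-side tool;
    // tools without `execute` keep executing on the client
    get_weather: {
      description: "Look up the current weather for a city",
      parameters: z.object({ city: z.string() }),
      execute: async ({ city }) => ({ city, temperatureC: 21 }),
    },
  },
  onFinish: ({ finishReason, usage, messages }) => {
    console.log("run done:", finishReason, usage, messages.length);
  },
});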
packages/react/src/runtimes/edge/streams/AssistantStreamChunkType.ts (8 changes: 8 additions & 0 deletions)

@@ -4,6 +4,7 @@ export enum AssistantStreamChunkType {
TextDelta = "0",
ToolCallBegin = "1",
ToolCallArgsTextDelta = "2",
ToolCallResult = "3",
Error = "E",
Finish = "F",
}
@@ -18,6 +19,13 @@ export type AssistantStreamChunkTuple =
       },
     ]
   | [AssistantStreamChunkType.ToolCallArgsTextDelta, string]
+  | [
+      AssistantStreamChunkType.ToolCallResult,
+      {
+        id: string;
+        result: any;
+      },
+    ]
   | [AssistantStreamChunkType.Error, unknown]
   | [
       AssistantStreamChunkType.Finish,
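Each chunk tuple is serialized as a single code-prefixed JSON line (the encoder's formatStreamPart, shown further down, pairs a chunk code with a JSON payload). Assuming the new chunk uses the same code:json framing as the existing chunk types, a ToolCallResult chunk would look roughly like this; the id and result values are illustrative:

// Hedged sketch of the wire framing for the new chunk type
const line =
  AssistantStreamChunkType.ToolCallResult + // "3"
  ":" +
  JSON.stringify({ id: "call_abc123", result: { temperatureC: 21 } }) +
  "\n";
// => '3:{"id":"call_abc123","result":{"temperatureC":21}}\n'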
packages/react/src/runtimes/edge/streams/assistantDecoderStream.ts (16 changes: 14 additions & 2 deletions)
@@ -1,15 +1,16 @@
-import { LanguageModelV1StreamPart } from "@ai-sdk/provider";
 import {
   AssistantStreamChunkTuple,
   AssistantStreamChunkType,
 } from "./AssistantStreamChunkType";
+import { ToolResultStreamPart } from "./toolResultStream";

 export function assistantDecoderStream() {
+  const toolCallNames = new Map<string, string>();
   let currentToolCall:
     | { id: string; name: string; argsText: string }
     | undefined;

-  return new TransformStream<string, LanguageModelV1StreamPart>({
+  return new TransformStream<string, ToolResultStreamPart>({
     transform(chunk, controller) {
       const [code, value] = parseStreamPart(chunk);
@@ -38,6 +39,7 @@ export function assistantDecoderStream() {
       }
       case AssistantStreamChunkType.ToolCallBegin: {
         const { id, name } = value;
+        toolCallNames.set(id, name);
         currentToolCall = { id, name, argsText: "" };
         break;
       }
@@ -53,6 +55,16 @@ export function assistantDecoderStream() {
         });
         break;
       }
+      case AssistantStreamChunkType.ToolCallResult: {
+        controller.enqueue({
+          type: "tool-result",
+          toolCallType: "function",
+          toolCallId: value.id,
+          toolName: toolCallNames.get(value.id)!,
+          result: value.result,
+        });
+        break;
+      }
       case AssistantStreamChunkType.Finish: {
         controller.enqueue({
           type: "finish",
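On the client, the decoder now reconstructs server tool results as "tool-result" parts (the new toolCallNames map is what lets it attach the tool name, since the result chunk only carries the call id). A consumption sketch, assuming some line-splitting transform sits between the text decoder and the parser (splitIntoLines below is hypothetical):

// Hedged client-side sketch
const res = await fetch("/api/chat", { method: "POST", body: requestBody });
const parts = res.body!
  .pipeThrough(new TextDecoderStream())
  .pipeThrough(splitIntoLines()) // hypothetical: emits one stream part per chunk
  .pipeThrough(assistantDecoderStream());

const reader = parts.getReader();
for (;;) {
  const { done, value } = await reader.read();
  if (done) break;
  if (value.type === "tool-result") console.log(value.toolName, value.result);
}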
packages/react/src/runtimes/edge/streams/assistantEncoderStream.ts (14 changes: 12 additions & 2 deletions)

@@ -2,11 +2,11 @@ import {
   AssistantStreamChunkTuple,
   AssistantStreamChunkType,
 } from "./AssistantStreamChunkType";
-import { LanguageModelV1StreamPart } from "@ai-sdk/provider";
+import { ToolResultStreamPart } from "./toolResultStream";

 export function assistantEncoderStream() {
   const toolCalls = new Set<string>();
-  return new TransformStream<LanguageModelV1StreamPart, string>({
+  return new TransformStream<ToolResultStreamPart, string>({
     transform(chunk, controller) {
       const chunkType = chunk.type;
       switch (chunkType) {
@@ -43,6 +43,16 @@ export function assistantEncoderStream() {
case "tool-call":
break;

case "tool-result": {
controller.enqueue(
formatStreamPart(AssistantStreamChunkType.ToolCallResult, {
id: chunk.toolCallId,
result: chunk.result,
}),
);
break;
}

case "finish": {
const { type, ...rest } = chunk;
controller.enqueue(
packages/react/src/runtimes/edge/streams/toolResultStream.ts (3 changes: 2 additions & 1 deletion)

@@ -16,7 +16,7 @@ export type ToolResultStreamPart =
 export function toolResultStream(tools: Record<string, Tool> | undefined) {
   const toolCallExecutions = new Map<string, Promise<any>>();

-  return new TransformStream<LanguageModelV1StreamPart, ToolResultStreamPart>({
+  return new TransformStream<ToolResultStreamPart, ToolResultStreamPart>({
     transform(chunk, controller) {
       // forward everything
       controller.enqueue(chunk);
@@ -70,6 +70,7 @@ export function toolResultStream(tools: Record<string, Tool> | undefined) {
       // ignore other parts
       case "text-delta":
       case "tool-call-delta":
+      case "tool-result":
       case "finish":
       case "error":
         break;
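Widening the input type to ToolResultStreamPart means the transform can also sit behind the decoder, and listing "tool-result" among the ignored parts keeps already-resolved results from being executed a second time. A one-line sketch of where it slots in on the server, mirroring createEdgeRuntimeAPI above (modelStream and serverTools are stand-ins):

// tool-call parts whose Tool has an `execute` function gain a matching
// tool-result part; everything else passes through unchanged
const withResults = modelStream.pipeThrough(toolResultStream(serverTools));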
