Skip to content

Commit

Permalink
feat: Edge runtime finish reason, logprobs and token count streaming (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonom authored Jul 11, 2024
1 parent 5bcca69 commit a7a923f
Showing 8 changed files with 83 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -3,4 +3,5 @@ export enum AssistantStreamChunkType {
ToolCallBegin = "1",
ToolCallArgsTextDelta = "2",
Error = "E",
Finish = "F",
}
17 changes: 17 additions & 0 deletions packages/react/src/runtimes/edge/streams/AssistantStreamPart.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,31 @@
import {
LanguageModelV1FinishReason,
LanguageModelV1LogProbs,
} from "@ai-sdk/provider";

export type AssistantStreamFinishPart = {
type: "finish";
finishReason: LanguageModelV1FinishReason;
logprops?: LanguageModelV1LogProbs;
usage: {
promptTokens: number;
completionTokens: number;
};
};

export type AssistantStreamPart =
| {
type: "text-delta";
textDelta: string;
}
| {
type: "tool-call-delta";
toolCallType: "function";
toolCallId: string;
toolName: string;
argsTextDelta: string;
}
| AssistantStreamFinishPart
| {
type: "error";
error: unknown;
Original file line number Diff line number Diff line change
@@ -27,12 +27,20 @@ export function assistantDecoderStream() {
const delta = JSON.parse(value);
controller.enqueue({
type: "tool-call-delta",
toolCallType: "function",
toolCallId: currentToolCall!.id,
toolName: currentToolCall!.name,
argsTextDelta: delta,
});
break;
}
case AssistantStreamChunkType.Finish: {
controller.enqueue({
type: "finish",
...JSON.parse(value),
});
break;
}
case AssistantStreamChunkType.Error: {
controller.enqueue({
type: "error",
12 changes: 12 additions & 0 deletions packages/react/src/runtimes/edge/streams/assistantEncoderStream.ts
Original file line number Diff line number Diff line change
@@ -35,6 +35,18 @@ export function assistantEncoderStream() {
);
break;
}

case "finish": {
const { type, ...rest } = chunk;
controller.enqueue(
formatStreamPart(
AssistantStreamChunkType.Finish,
JSON.stringify(rest),
),
);
break;
}

case "error": {
controller.enqueue(
formatStreamPart(
11 changes: 3 additions & 8 deletions packages/react/src/runtimes/edge/streams/assistantStream.ts
Original file line number Diff line number Diff line change
@@ -9,19 +9,14 @@ export function assistantStream() {
switch (chunkType) {
// forward
case "text-delta":
case "error": {
case "error":
case "tool-call-delta":
case "finish": {
controller.enqueue(chunk);
break;
}

case "tool-call-delta": {
const { toolCallType, ...rest } = chunk;
controller.enqueue(rest);
break;
}

// ignore
case "finish":
case "tool-call": {
break;
}
24 changes: 23 additions & 1 deletion packages/react/src/runtimes/edge/streams/runResultStream.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import { AssistantStreamPart } from "./AssistantStreamPart";
import {
AssistantStreamFinishPart,
AssistantStreamPart,
} from "./AssistantStreamPart";
import { ChatModelRunResult } from "../../local/ChatModelAdapter";

export function runResultStream() {
@@ -34,6 +37,11 @@ export function runResultStream() {
controller.enqueue(message);
break;
}
case "finish": {
message = appendOrUpdateFinish(message, chunk);
controller.enqueue(message);
break;
}
case "error": {
throw chunk.error;
}
@@ -91,3 +99,17 @@ const appendOrUpdateToolCall = (
content: contentParts.concat([contentPart]),
};
};

const appendOrUpdateFinish = (
message: ChatModelRunResult,
chunk: AssistantStreamFinishPart,
): ChatModelRunResult => {
const { type, ...rest } = chunk;
return {
...message,
status: {
type: "done",
...rest,
},
};
};
10 changes: 8 additions & 2 deletions packages/react/src/runtimes/local/ChatModelAdapter.tsx
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
"use client";
import type {
MessageStatus,
ThreadAssistantContentPart,
ThreadMessage,
} from "../../types/AssistantTypes";
import type { ModelConfig } from "../../types/ModelConfigTypes";

export type ChatModelRunUpdate = {
content: ThreadAssistantContentPart[];
};

export type ChatModelRunResult = {
content: ThreadAssistantContentPart[];
status?: MessageStatus;
};

export type ChatModelRunOptions = {
messages: ThreadMessage[];
abortSignal: AbortSignal;
config: ModelConfig;
onUpdate: (result: ChatModelRunResult) => void;
onUpdate: (result: ChatModelRunUpdate) => void;
};

export type ChatModelAdapter = {
run: (options: ChatModelRunOptions) => Promise<ChatModelRunResult | void>;
run: (options: ChatModelRunOptions) => Promise<ChatModelRunResult>;
};
13 changes: 11 additions & 2 deletions packages/react/src/types/AssistantTypes.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { LanguageModelV1FinishReason, LanguageModelV1LogProbs } from "@ai-sdk/provider";
import type { ReactNode } from "react";

export type TextContentPart = {
@@ -42,8 +43,16 @@ type MessageCommonProps = {

export type MessageStatus =
| {
type: "in_progress" | "done";
error?: undefined;
type: "in_progress";
}
| {
type: "done";
finishReason?: LanguageModelV1FinishReason;
logprops?: LanguageModelV1LogProbs;
usage?: {
promptTokens: number;
completionTokens: number;
};
}
| {
type: "error";

0 comments on commit a7a923f

Please sign in to comment.