Skip to content

Commit

Permalink
fix: Vercel useAssistant BranchPicker duplicates bug (#356)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonom authored Jun 29, 2024
1 parent 4057cb4 commit be2c26b
Show file tree
Hide file tree
Showing 7 changed files with 240 additions and 68 deletions.
6 changes: 6 additions & 0 deletions .changeset/nice-baboons-peel.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@assistant-ui/react-ai-sdk": patch
"@assistant-ui/react": patch
---

fix: Vercel useAssistant BranchPicker duplicates bug
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import { type ThreadMessage, INTERNAL } from "@assistant-ui/react";
import { ModelConfigProvider } from "@assistant-ui/react";
import { UseAssistantHelpers } from "@ai-sdk/react";
import { VercelUseAssistantThreadRuntime } from "./VercelUseAssistantThreadRuntime";

const { ProxyConfigProvider, BaseAssistantRuntime } = INTERNAL;

export const hasUpcomingMessage = (
isRunning: boolean,
messages: ThreadMessage[],
) => {
return isRunning && messages[messages.length - 1]?.role !== "assistant";
};

export class VercelUseAssistantRuntime extends BaseAssistantRuntime<VercelUseAssistantThreadRuntime> {
private readonly _proxyConfigProvider = new ProxyConfigProvider();

constructor(vercel: UseAssistantHelpers) {
super(new VercelUseAssistantThreadRuntime(vercel));
}

public set vercel(vercel: UseAssistantHelpers) {
this.thread.vercel = vercel;
}

public onVercelUpdated() {
return this.thread.onVercelUpdated();
}

public getModelConfig() {
return this._proxyConfigProvider.getModelConfig();
}

public registerModelConfigProvider(provider: ModelConfigProvider) {
return this._proxyConfigProvider.registerModelConfigProvider(provider);
}

public switchToThread(threadId: string | null) {
if (threadId) {
throw new Error("VercelAIRuntime does not yet support switching threads");
}

// clear the vercel state (otherwise, it will be captured by the MessageRepository)
this.thread.vercel.messages = [];
this.thread.vercel.input = "";
this.thread.vercel.setMessages([]);
this.thread.vercel.setInput("");

this.thread = new VercelUseAssistantThreadRuntime(this.thread.vercel);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import {
type ReactThreadRuntime,
type Unsubscribe,
type AppendMessage,
type ThreadMessage,
} from "@assistant-ui/react";
import { type StoreApi, type UseBoundStore, create } from "zustand";
import { useVercelAIComposerSync } from "../utils/useVercelAIComposerSync";
import { useVercelAIThreadSync } from "../utils/useVercelAIThreadSync";
import { UseAssistantHelpers } from "@ai-sdk/react";
import { hasUpcomingMessage } from "./VercelUseAssistantRuntime";

const EMPTY_BRANCHES: readonly string[] = Object.freeze([]);

export class VercelUseAssistantThreadRuntime implements ReactThreadRuntime {
private _subscriptions = new Set<() => void>();

private useVercel: UseBoundStore<StoreApi<{ vercel: UseAssistantHelpers }>>;

public messages: readonly ThreadMessage[] = [];
public isRunning = false;

constructor(public vercel: UseAssistantHelpers) {
this.useVercel = create(() => ({
vercel,
}));
}

public getBranches(): readonly string[] {
return EMPTY_BRANCHES;
}

public switchToBranch(): void {
throw new Error(
"VercelUseAssistantRuntime does not support switching branches.",
);
}

public async append(message: AppendMessage): Promise<void> {
// add user message
if (message.content.length !== 1 || message.content[0]?.type !== "text")
throw new Error("VercelUseAssistantRuntime only supports text content.");

if (message.parentId !== (this.messages.at(-1)?.id ?? null))
throw new Error(
"VercelUseAssistantRuntime does not support editing messages.",
);

await this.vercel.append({
role: "user",
content: message.content[0].text,
});
}

public async startRun(): Promise<void> {
throw new Error("VercelUseAssistantRuntime does not support reloading.");
}

public cancelRun(): void {
const previousMessage = this.vercel.messages.at(-1);

this.vercel.stop();
if (previousMessage?.role === "user") {
this.vercel.setInput(previousMessage.content);
}
}

public subscribe(callback: () => void): Unsubscribe {
this._subscriptions.add(callback);
return () => this._subscriptions.delete(callback);
}

public onVercelUpdated() {
if (this.useVercel.getState().vercel !== this.vercel) {
this.useVercel.setState({ vercel: this.vercel });
}
}

private updateData = (isRunning: boolean, vm: ThreadMessage[]) => {
if (hasUpcomingMessage(isRunning, vm)) {
vm.push({
id: "__optimistic__result",
createdAt: new Date(),
status: "in_progress",
role: "assistant",
content: [{ type: "text", text: "" }],
});
}

this.messages = vm;
this.isRunning = isRunning;

for (const callback of this._subscriptions) callback();
};

unstable_synchronizer = () => {
const { vercel } = this.useVercel();

useVercelAIThreadSync(vercel, this.updateData);
useVercelAIComposerSync(vercel);

return null;
};

addToolResult() {
throw new Error(
"VercelUseAssistantRuntime does not support adding tool results.",
);
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import type { UseAssistantHelpers } from "@ai-sdk/react";
import { useEffect, useInsertionEffect, useState } from "react";
import { VercelAIRuntime } from "../VercelAIRuntime";
import { VercelUseAssistantRuntime } from "./VercelUseAssistantRuntime";

export const useVercelUseAssistantRuntime = (
assistantHelpers: UseAssistantHelpers,
) => {
const [runtime] = useState(() => new VercelAIRuntime(assistantHelpers));
const [runtime] = useState(
() => new VercelUseAssistantRuntime(assistantHelpers),
);

useInsertionEffect(() => {
runtime.vercel = assistantHelpers;
Expand Down
46 changes: 46 additions & 0 deletions packages/react-ai-sdk/src/ui/use-chat/VercelUseChatRuntime.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import { INTERNAL } from "@assistant-ui/react";
import { ModelConfigProvider } from "@assistant-ui/react";
import { useChat } from "@ai-sdk/react";
import { VercelUseChatThreadRuntime } from "./VercelUseChatThreadRuntime";

const { ProxyConfigProvider, BaseAssistantRuntime } = INTERNAL;

export class VercelUseChatRuntime extends BaseAssistantRuntime<VercelUseChatThreadRuntime> {
private readonly _proxyConfigProvider = new ProxyConfigProvider();

constructor(vercel: ReturnType<typeof useChat>) {
super(new VercelUseChatThreadRuntime(vercel));
}

public set vercel(vercel: ReturnType<typeof useChat>) {
this.thread.vercel = vercel;
}

public onVercelUpdated() {
return this.thread.onVercelUpdated();
}

public getModelConfig() {
return this._proxyConfigProvider.getModelConfig();
}

public registerModelConfigProvider(provider: ModelConfigProvider) {
return this._proxyConfigProvider.registerModelConfigProvider(provider);
}

public switchToThread(threadId: string | null) {
if (threadId) {
throw new Error(
"VercelAIRuntime does not yet support switching threads.",
);
}

// clear the vercel state (otherwise, it will be captured by the MessageRepository)
this.thread.vercel.messages = [];
this.thread.vercel.input = "";
this.thread.vercel.setMessages([]);
this.thread.vercel.setInput("");

this.thread = new VercelUseChatThreadRuntime(this.thread.vercel);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,69 +7,34 @@ import {
} from "@assistant-ui/react";
import type { Message } from "ai";
import { type StoreApi, type UseBoundStore, create } from "zustand";
import { getVercelAIMessage } from "./getVercelAIMessage";
import type { VercelHelpers } from "./utils/VercelHelpers";
import { sliceMessagesUntil } from "./utils/sliceMessagesUntil";
import { useVercelAIComposerSync } from "./utils/useVercelAIComposerSync";
import { useVercelAIThreadSync } from "./utils/useVercelAIThreadSync";
import { ModelConfigProvider } from "@assistant-ui/react";

const { ProxyConfigProvider, MessageRepository, BaseAssistantRuntime } =
INTERNAL;

const hasUpcomingMessage = (isRunning: boolean, messages: ThreadMessage[]) => {
import { getVercelAIMessage } from "../getVercelAIMessage";
import { sliceMessagesUntil } from "../utils/sliceMessagesUntil";
import { useVercelAIComposerSync } from "../utils/useVercelAIComposerSync";
import { useVercelAIThreadSync } from "../utils/useVercelAIThreadSync";
import { useChat } from "@ai-sdk/react";

const { MessageRepository } = INTERNAL;

export const hasUpcomingMessage = (
isRunning: boolean,
messages: ThreadMessage[],
) => {
return isRunning && messages[messages.length - 1]?.role !== "assistant";
};

export class VercelAIRuntime extends BaseAssistantRuntime<VercelAIThreadRuntime> {
private readonly _proxyConfigProvider = new ProxyConfigProvider();

constructor(vercel: VercelHelpers) {
super(new VercelAIThreadRuntime(vercel));
}

public set vercel(vercel: VercelHelpers) {
this.thread.vercel = vercel;
}

public onVercelUpdated() {
return this.thread.onVercelUpdated();
}

public getModelConfig() {
return this._proxyConfigProvider.getModelConfig();
}

public registerModelConfigProvider(provider: ModelConfigProvider) {
return this._proxyConfigProvider.registerModelConfigProvider(provider);
}

public switchToThread(threadId: string | null) {
if (threadId) {
throw new Error("VercelAIRuntime does not yet support switching threads");
}

// clear the vercel state (otherwise, it will be captured by the MessageRepository)
this.thread.vercel.messages = [];
this.thread.vercel.input = "";
this.thread.vercel.setMessages([]);
this.thread.vercel.setInput("");

this.thread = new VercelAIThreadRuntime(this.thread.vercel);
}
}

class VercelAIThreadRuntime implements ReactThreadRuntime {
export class VercelUseChatThreadRuntime implements ReactThreadRuntime {
private _subscriptions = new Set<() => void>();
private repository = new MessageRepository();
private assistantOptimisticId: string | null = null;

private useVercel: UseBoundStore<StoreApi<{ vercel: VercelHelpers }>>;
private useVercel: UseBoundStore<
StoreApi<{ vercel: ReturnType<typeof useChat> }>
>;

public messages: ThreadMessage[] = [];
public isRunning = false;

constructor(public vercel: VercelHelpers) {
constructor(public vercel: ReturnType<typeof useChat>) {
this.useVercel = create(() => ({
vercel,
}));
Expand All @@ -87,7 +52,9 @@ class VercelAIThreadRuntime implements ReactThreadRuntime {
public async append(message: AppendMessage): Promise<void> {
// add user message
if (message.content.length !== 1 || message.content[0]?.type !== "text")
throw new Error("Only text content is supported by Vercel AI SDK.");
throw new Error(
"Only text content is supported by VercelUseChatRuntime.",
);

const newMessages = sliceMessagesUntil(
this.vercel.messages,
Expand All @@ -102,17 +69,10 @@ class VercelAIThreadRuntime implements ReactThreadRuntime {
}

public async startRun(parentId: string | null): Promise<void> {
const reloadMaybe =
"reload" in this.vercel ? this.vercel.reload : undefined;
if (!reloadMaybe)
throw new Error(
"Reload is not supported by Vercel AI SDK's useAssistant.",
);

const newMessages = sliceMessagesUntil(this.vercel.messages, parentId);
this.vercel.setMessages(newMessages);

await reloadMaybe();
await this.vercel.reload();
}

public cancelRun(): void {
Expand Down Expand Up @@ -203,9 +163,6 @@ class VercelAIThreadRuntime implements ReactThreadRuntime {
};

addToolResult(toolCallId: string, result: any) {
if (!("addToolResult" in this.vercel)) {
throw new Error("VercelAIRuntime does not support adding tool results");
}
this.vercel.addToolResult({ toolCallId, result });
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import type { useChat } from "@ai-sdk/react";
import { useEffect, useInsertionEffect, useState } from "react";
import { VercelAIRuntime } from "../VercelAIRuntime";
import { VercelUseChatRuntime } from "./VercelUseChatRuntime";

export const useVercelUseChatRuntime = (
chatHelpers: ReturnType<typeof useChat>,
) => {
const [runtime] = useState(() => new VercelAIRuntime(chatHelpers));
const [runtime] = useState(() => new VercelUseChatRuntime(chatHelpers));

useInsertionEffect(() => {
runtime.vercel = chatHelpers;
Expand Down

0 comments on commit be2c26b

Please sign in to comment.