Skip to content

Commit

Permalink
feat: Runtime Capabilities API (#396)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonom authored Jul 4, 2024
1 parent 9610cd5 commit 05fd5d6
Show file tree
Hide file tree
Showing 18 changed files with 244 additions and 138 deletions.
7 changes: 7 additions & 0 deletions .changeset/clever-goats-learn.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@assistant-ui/react": patch
"@assistant-ui/react-ui": patch
"@assistant-ui/react-ai-sdk": patch
---

feat: runtime capabilities API
21 changes: 16 additions & 5 deletions apps/www/components/docs/parameters/context.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,12 @@ export const ComposerState: ParametersTableProps = {
type: "ComposerState",
parameters: [
...BaseComposerState.parameters,
{
name: "canCancel",
type: "true",
required: true,
description: "Whether the composer can be canceled.",
},
{
name: "isEditing",
type: "true",
Expand All @@ -237,10 +243,9 @@ export const ComposerState: ParametersTableProps = {
},
{
name: "cancel",
type: "() => boolean",
type: "() => void",
required: true,
description:
"A function to cancel the run. Returns true if the run was canceled.",
description: "A function to cancel the run.",
},
{
name: "focus",
Expand All @@ -261,6 +266,12 @@ export const EditComposerState: ParametersTableProps = {
type: "EditComposerState",
parameters: [
...BaseComposerState.parameters,
{
name: "canCancel",
type: "boolean",
required: true,
description: "Whether the composer can be canceled.",
},
{
name: "isEditing",
type: "boolean",
Expand All @@ -281,9 +292,9 @@ export const EditComposerState: ParametersTableProps = {
},
{
name: "cancel",
type: "() => boolean",
type: "() => void",
required: true,
description: "A function to cancel the edit mode. Always returns true.",
description: "A function to exit the edit mode.",
},
],
};
Expand Down
109 changes: 6 additions & 103 deletions packages/react-ai-sdk/src/rsc/VercelRSCRuntime.tsx
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
"use client";

import {
INTERNAL,
type AppendMessage,
type ReactThreadRuntime,
type ThreadMessage,
type Unsubscribe,
} from "@assistant-ui/react";
import { type StoreApi, type UseBoundStore, create } from "zustand";
import { INTERNAL } from "@assistant-ui/react";
import type { VercelRSCAdapter } from "./VercelRSCAdapter";
import type { VercelRSCMessage } from "./VercelRSCMessage";
import { useVercelRSCSync } from "./useVercelRSCSync";
import { ModelConfigProvider } from "@assistant-ui/react";
import {
VercelRSCThreadRuntime,
ProxyConfigProvider,
} from "./VercelRSCThreadRuntime";

const { ProxyConfigProvider, BaseAssistantRuntime } = INTERNAL;

const EMPTY_BRANCHES: readonly string[] = Object.freeze([]);
const { BaseAssistantRuntime } = INTERNAL;

export class VercelRSCRuntime<
T extends WeakKey = VercelRSCMessage,
Expand Down Expand Up @@ -46,94 +40,3 @@ export class VercelRSCRuntime<
throw new Error("VercelRSCRuntime does not support switching threads");
}
}

class VercelRSCThreadRuntime<T extends WeakKey = VercelRSCMessage>
implements ReactThreadRuntime
{
private useAdapter: UseBoundStore<StoreApi<{ adapter: VercelRSCAdapter<T> }>>;

private _subscriptions = new Set<() => void>();

public isRunning = false;
public messages: ThreadMessage[] = [];

constructor(public adapter: VercelRSCAdapter<T>) {
this.useAdapter = create(() => ({
adapter,
}));
}

private withRunning = (callback: Promise<unknown>) => {
this.isRunning = true;
return callback.finally(() => {
this.isRunning = false;
});
};

public getBranches(): readonly string[] {
return EMPTY_BRANCHES;
}

public switchToBranch(): void {
throw new Error(
"Branch switching is not supported by VercelRSCAssistantProvider.",
);
}

public async append(message: AppendMessage): Promise<void> {
if (message.parentId !== (this.messages.at(-1)?.id ?? null)) {
if (!this.adapter.edit)
throw new Error(
"Message editing is not enabled, please provide an edit callback to VercelRSCAssistantProvider.",
);
await this.withRunning(this.adapter.edit(message));
} else {
await this.withRunning(this.adapter.append(message));
}
}

public async startRun(parentId: string | null): Promise<void> {
if (!this.adapter.reload)
throw new Error(
"Message reloading is not enabled, please provide a reload callback to VercelRSCAssistantProvider.",
);
await this.withRunning(this.adapter.reload(parentId));
}

cancelRun(): void {
// in dev mode, log a warning
if (process.env["NODE_ENV"] === "development") {
console.warn(
"Run cancellation is not supported by VercelRSCAssistantProvider.",
);
}
}

public subscribe(callback: () => void): Unsubscribe {
this._subscriptions.add(callback);
return () => this._subscriptions.delete(callback);
}

public onAdapterUpdated() {
if (this.useAdapter.getState().adapter !== this.adapter) {
this.useAdapter.setState({ adapter: this.adapter });
}
}

private updateData = (messages: ThreadMessage[]) => {
this.messages = messages;
for (const callback of this._subscriptions) callback();
};

unstable_synchronizer = () => {
const { adapter } = this.useAdapter();

useVercelRSCSync(adapter, this.updateData);

return null;
};

addToolResult() {
throw new Error("VercelRSCRuntime does not support adding tool results");
}
}
109 changes: 109 additions & 0 deletions packages/react-ai-sdk/src/rsc/VercelRSCThreadRuntime.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"use client";
import {
INTERNAL,
type AppendMessage,
type ReactThreadRuntime,
type ThreadMessage,
type Unsubscribe,
} from "@assistant-ui/react";
import { type StoreApi, type UseBoundStore, create } from "zustand";
import type { VercelRSCAdapter } from "./VercelRSCAdapter";
import type { VercelRSCMessage } from "./VercelRSCMessage";
import { useVercelRSCSync } from "./useVercelRSCSync";

export const { ProxyConfigProvider } = INTERNAL;
const EMPTY_BRANCHES: readonly string[] = Object.freeze([]);
const CAPABILITIES = Object.freeze({
edit: false,
reload: false,
cancel: false,
copy: false,
});

export class VercelRSCThreadRuntime<T extends WeakKey = VercelRSCMessage>
implements ReactThreadRuntime
{
private useAdapter: UseBoundStore<StoreApi<{ adapter: VercelRSCAdapter<T> }>>;

public readonly capabilities = CAPABILITIES;

private _subscriptions = new Set<() => void>();

public isRunning = false;
public messages: ThreadMessage[] = [];

constructor(public adapter: VercelRSCAdapter<T>) {
this.useAdapter = create(() => ({
adapter,
}));
}

private withRunning = (callback: Promise<unknown>) => {
this.isRunning = true;
return callback.finally(() => {
this.isRunning = false;
});
};

public getBranches(): readonly string[] {
return EMPTY_BRANCHES;
}

public switchToBranch(): void {
throw new Error(
"Branch switching is not supported by VercelRSCAssistantProvider.",
);
}

public async append(message: AppendMessage): Promise<void> {
if (message.parentId !== (this.messages.at(-1)?.id ?? null)) {
if (!this.adapter.edit)
throw new Error(
"Message editing is not enabled, please provide an edit callback to VercelRSCAssistantProvider.",
);
await this.withRunning(this.adapter.edit(message));
} else {
await this.withRunning(this.adapter.append(message));
}
}

public async startRun(parentId: string | null): Promise<void> {
if (!this.adapter.reload)
throw new Error(
"Message reloading is not enabled, please provide a reload callback to VercelRSCAssistantProvider.",
);
await this.withRunning(this.adapter.reload(parentId));
}

cancelRun(): void {
throw new Error("VercelRSCRuntime does not support cancelling runs.");
}

public subscribe(callback: () => void): Unsubscribe {
this._subscriptions.add(callback);
return () => this._subscriptions.delete(callback);
}

public onAdapterUpdated() {
if (this.useAdapter.getState().adapter !== this.adapter) {
this.useAdapter.setState({ adapter: this.adapter });
}
}

private updateData = (messages: ThreadMessage[]) => {
this.messages = messages;
for (const callback of this._subscriptions) callback();
};

unstable_synchronizer = () => {
const { adapter } = this.useAdapter();

useVercelRSCSync(adapter, this.updateData);

return null;
};

addToolResult() {
throw new Error("VercelRSCRuntime does not support adding tool results");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,18 @@ import { hasUpcomingMessage } from "./VercelUseAssistantRuntime";

const EMPTY_BRANCHES: readonly string[] = Object.freeze([]);

const CAPABILITIES = Object.freeze({
edit: false,
reload: false,
cancel: false,
copy: true,
});

export class VercelUseAssistantThreadRuntime implements ReactThreadRuntime {
private _subscriptions = new Set<() => void>();

public readonly capabilities = CAPABILITIES;

private useVercel: UseBoundStore<StoreApi<{ vercel: UseAssistantHelpers }>>;

public messages: readonly ThreadMessage[] = [];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ export const hasUpcomingMessage = (
return isRunning && messages[messages.length - 1]?.role !== "assistant";
};

const CAPABILITIES = Object.freeze({
edit: true,
reload: true,
cancel: true,
copy: true,
});

export class VercelUseChatThreadRuntime implements ReactThreadRuntime {
private _subscriptions = new Set<() => void>();
private repository = new MessageRepository();
Expand All @@ -31,6 +38,8 @@ export class VercelUseChatThreadRuntime implements ReactThreadRuntime {
StoreApi<{ vercel: ReturnType<typeof useChat> }>
>;

public readonly capabilities = CAPABILITIES;

public messages: ThreadMessage[] = [];
public isRunning = false;

Expand Down
28 changes: 23 additions & 5 deletions packages/react-ui/src/components/assistant-action-bar.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
"use client";

import { ActionBarPrimitive, MessagePrimitive } from "@assistant-ui/react";
import {
ActionBarPrimitive,
MessagePrimitive,
useThreadContext,
} from "@assistant-ui/react";
import { forwardRef, type FC } from "react";
import { CheckIcon, CopyIcon, RefreshCwIcon } from "lucide-react";
import {
Expand All @@ -10,9 +14,23 @@ import {
import { styled } from "../styled";
import { useThreadConfig } from "./thread-config";

const useAllowCopy = () => {
const { assistantMessage: { allowCopy = true } = {} } = useThreadConfig();
const { useThreadActions } = useThreadContext();
const copySupported = useThreadActions((t) => t.capabilities.copy);
return copySupported && allowCopy;
};

const useAllowReload = () => {
const { assistantMessage: { allowReload = true } = {} } = useThreadConfig();
const { useThreadActions } = useThreadContext();
const reloadSupported = useThreadActions((t) => t.capabilities.reload);
return reloadSupported && allowReload;
};

export const AssistantActionBar: FC = () => {
const { assistantMessage: { allowCopy = true, allowReload = true } = {} } =
useThreadConfig();
const allowCopy = useAllowCopy();
const allowReload = useAllowReload();
if (!allowCopy && !allowReload) return null;
return (
<AssistantActionBarRoot
Expand All @@ -39,11 +57,11 @@ export const AssistantActionBarCopy = forwardRef<
Partial<TooltipIconButtonProps>
>((props, ref) => {
const {
assistantMessage: { allowCopy = true } = {},
strings: {
assistantMessage: { reload: { tooltip = "Copy" } = {} } = {},
} = {},
} = useThreadConfig();
const allowCopy = useAllowCopy();
if (!allowCopy) return null;
return (
<ActionBarPrimitive.Copy asChild>
Expand All @@ -66,11 +84,11 @@ export const AssistantActionBarReload = forwardRef<
Partial<TooltipIconButtonProps>
>((props, ref) => {
const {
assistantMessage: { allowReload = true } = {},
strings: {
assistantMessage: { reload: { tooltip = "Refresh" } = {} } = {},
} = {},
} = useThreadConfig();
const allowReload = useAllowReload();
if (!allowReload) return null;
return (
<ActionBarPrimitive.Reload asChild>
Expand Down
Loading

0 comments on commit 05fd5d6

Please sign in to comment.