Skip to content

Commit

Permalink
feat: useAssistantActions (#337)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonom authored Jun 27, 2024
1 parent 62e9f19 commit 611fdcc
Show file tree
Hide file tree
Showing 12 changed files with 81 additions and 24 deletions.
6 changes: 6 additions & 0 deletions .changeset/eleven-crews-cross.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@assistant-ui/react-ai-sdk": patch
"@assistant-ui/react": patch
---

feat: useAssistantActions
21 changes: 20 additions & 1 deletion apps/www/components/docs/parameters/context.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
import { ParametersTableProps } from "../ParametersTable";

export const AssistantActionsState: ParametersTableProps = {
type: "AssistantActionsState",
parameters: [
{
name: "switchToThread",
type: "(threadId: string | null) => void",
description: "Switch to a new thread.",
required: true,
},
],
};

export const AssistantModelConfigState: ParametersTableProps = {
type: "AssistantModelConfigState",
parameters: [
Expand Down Expand Up @@ -101,6 +113,12 @@ export const AssistantToolUIsState: ParametersTableProps = {
export const AssistantContextValue: ParametersTableProps = {
type: "AssistantContextValue",
parameters: [
{
name: "useAssistantActions",
type: "ReadonlyStore<AssistantActionsState>",
required: true,
description: "Provides functions to perform actions on the assistant.",
},
{
name: "useModelConfig",
type: "ReadonlyStore<AssistantModelConfigState>",
Expand Down Expand Up @@ -215,7 +233,8 @@ export const ComposerState: ParametersTableProps = {
name: "cancel",
type: "() => boolean",
required: true,
description: "A function to cancel the run. Returns true if the run was canceled.",
description:
"A function to cancel the run. Returns true if the run was canceled.",
},
{
name: "focus",
Expand Down
14 changes: 13 additions & 1 deletion apps/www/pages/reference/context.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ The context is split into four hierarchies:
import { ParametersTable } from "@/components/docs";
import {
AssistantContextValue,
AssistantModelConfigState,
AssistantActionsState,
AssistantModelConfigState,
AssistantToolUIsState,
ThreadState,
ThreadActionsState,
Expand Down Expand Up @@ -43,6 +44,17 @@ const { useModelConfig, useToolUIs } = useAssistantContext();

<ParametersTable {...AssistantContextValue} />

#### `useAssistantActions`

```tsx
import { useAssistantActions } from "@assistant-ui/react";

const switchToNewThread = useAssistantActions(m => m.switchToThread);
const switchToNewThread = useAssistantActions.getState().switchToThread;
```

<ParametersTable {...AssistantActionsState} />

### `useModelConfig`

```tsx
Expand Down
12 changes: 7 additions & 5 deletions packages/react-ai-sdk/src/rsc/VercelRSCRuntime.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,14 @@ export class VercelRSCRuntime<
return () => {};
}

public newThread() {
this.thread = new VercelRSCThreadRuntime(this.thread.adapter);
}
public switchToThread(threadId: string | null) {
if (threadId) {
throw new Error(
"VercelRSCRuntime does not yet support switching threads",
);
}

public switchToThread() {
throw new Error("VercelRSCRuntime does not yet support switching threads");
this.thread = new VercelRSCThreadRuntime(this.thread.adapter);
}
}

Expand Down
10 changes: 5 additions & 5 deletions packages/react-ai-sdk/src/ui/VercelAIRuntime.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ export class VercelAIRuntime extends BaseAssistantRuntime<VercelAIThreadRuntime>
return () => {};
}

public newThread() {
this.thread = new VercelAIThreadRuntime(this.thread.vercel);
}
public switchToThread(threadId: string | null) {
if (threadId) {
throw new Error("VercelAIRuntime does not yet support switching threads");
}

public switchToThread() {
throw new Error("VercelAIRuntime does not yet support switching threads");
this.thread = new VercelAIThreadRuntime(this.thread.vercel);
}
}

Expand Down
4 changes: 3 additions & 1 deletion packages/react/src/context/providers/AssistantProvider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { AssistantContext } from "../react/AssistantContext";
import { makeAssistantModelConfigStore } from "../stores/AssistantModelConfig";
import { makeAssistantToolUIsStore } from "../stores/AssistantToolUIs";
import { ThreadProvider } from "./ThreadProvider";
import { makeAssistantActionsStore } from "../stores/AssistantActions";

type AssistantProviderProps = {
runtime: AssistantRuntime;
Expand All @@ -21,8 +22,9 @@ export const AssistantProvider: FC<
const [context] = useState(() => {
const useModelConfig = makeAssistantModelConfigStore();
const useToolUIs = makeAssistantToolUIsStore();
const useAssistantActions = makeAssistantActionsStore(runtimeRef);

return { useModelConfig, useToolUIs };
return { useModelConfig, useToolUIs, useAssistantActions };
});

const getModelCOnfig = context.useModelConfig((c) => c.getModelConfig);
Expand Down
2 changes: 2 additions & 0 deletions packages/react/src/context/react/AssistantContext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ import { createContext, useContext } from "react";
import type { AssistantModelConfigState } from "../stores/AssistantModelConfig";
import type { AssistantToolUIsState } from "../stores/AssistantToolUIs";
import { ReadonlyStore } from "../ReadonlyStore";
import { AssistantActionsState } from "../stores/AssistantActions";

export type AssistantContextValue = {
useModelConfig: ReadonlyStore<AssistantModelConfigState>;
useToolUIs: ReadonlyStore<AssistantToolUIsState>;
useAssistantActions: ReadonlyStore<AssistantActionsState>;
};

export const AssistantContext = createContext<AssistantContextValue | null>(
Expand Down
16 changes: 16 additions & 0 deletions packages/react/src/context/stores/AssistantActions.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { create } from "zustand";
import { AssistantRuntime } from "../../runtime";
import { MutableRefObject } from "react";

export type AssistantActionsState = Readonly<{
switchToThread: (threadId: string | null) => void;
}>;

export const makeAssistantActionsStore = (
runtimeRef: MutableRefObject<AssistantRuntime>,
) =>
create<AssistantActionsState>(() =>
Object.freeze({
switchToThread: () => runtimeRef.current.switchToThread(null),
}),
);
3 changes: 2 additions & 1 deletion packages/react/src/context/stores/ThreadActions.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import type { MutableRefObject } from "react";
import { create } from "zustand";
import type { AppendMessage } from "../../types/AssistantTypes";
import { ThreadRuntime } from "../../runtime";

export type ThreadActionsState = Readonly<{
getBranches: (messageId: string) => readonly string[];
Expand All @@ -14,7 +15,7 @@ export type ThreadActionsState = Readonly<{
}>;

export const makeThreadActionStore = (
runtimeRef: MutableRefObject<ThreadActionsState>,
runtimeRef: MutableRefObject<ThreadRuntime>,
) => {
return create<ThreadActionsState>(() =>
Object.freeze({
Expand Down
4 changes: 1 addition & 3 deletions packages/react/src/runtime/core/AssistantRuntime.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ import type { Unsubscribe } from "../../types/Unsubscribe";
import type { ThreadRuntime } from "./ThreadRuntime";

export type AssistantRuntime = ThreadRuntime & {
newThread: () => void;
switchToThread: (threadId: string) => void;

switchToThread: (threadId: string | null) => void;
registerModelConfigProvider: (provider: ModelConfigProvider) => Unsubscribe;
};
3 changes: 1 addition & 2 deletions packages/react/src/runtime/core/BaseAssistantRuntime.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ export abstract class BaseAssistantRuntime<TThreadRuntime extends ThreadRuntime>
public abstract registerModelConfigProvider(
provider: ModelConfigProvider,
): Unsubscribe;
public abstract newThread(): void;
public abstract switchToThread(threadId: string): void;
public abstract switchToThread(threadId: string | null): void;

public get messages() {
return this.thread.messages;
Expand Down
10 changes: 5 additions & 5 deletions packages/react/src/runtime/local/LocalRuntime.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@ export class LocalRuntime extends BaseAssistantRuntime<LocalThreadRuntime> {
return () => this._configProviders.delete(provider);
}

public newThread() {
public switchToThread(threadId: string | null) {
if (threadId) {
throw new Error("LocalRuntime does not yet support switching threads");
}

return (this.thread = new LocalThreadRuntime(
this._configProviders,
this.thread.adapter,
));
}

public switchToThread() {
throw new Error("LocalRuntime does not yet support switching threads");
}
}

class LocalThreadRuntime implements ThreadRuntime {
Expand Down

0 comments on commit 611fdcc

Please sign in to comment.