diff --git a/.changeset/silly-chefs-rhyme.md b/.changeset/silly-chefs-rhyme.md
new file mode 100644
index 000000000..54cd17364
--- /dev/null
+++ b/.changeset/silly-chefs-rhyme.md
@@ -0,0 +1,6 @@
+---
+"@assistant-ui/react-hook-form": patch
+"@assistant-ui/react": patch
+---
+
+feat: Tool Render functions
diff --git a/apps/www/pages/docs/primitives/ContentPart.mdx b/apps/www/pages/docs/primitives/ContentPart.mdx
index 99dd50276..6f12398b9 100644
--- a/apps/www/pages/docs/primitives/ContentPart.mdx
+++ b/apps/www/pages/docs/primitives/ContentPart.mdx
@@ -23,12 +23,6 @@ const ImageContentPart = () => {
);
};
-const ToolCallContentPart = () => {
- return (
-
- );
-};
-
const UIContentPart = () => {
return (
@@ -52,4 +46,4 @@ Renders the image content of an image content part as an `` element.
### Display
-Renders the display content of a tool call or UI content part. This feature allows for colocation of tool call definition and corresponding UI elements.
+Renders the display content of a UI content part. This feature is used by the Vercel RSC runtime.
diff --git a/examples/with-react-hook-form/app/page.tsx b/examples/with-react-hook-form/app/page.tsx
index cde60523f..fefa086c4 100644
--- a/examples/with-react-hook-form/app/page.tsx
+++ b/examples/with-react-hook-form/app/page.tsx
@@ -7,6 +7,22 @@ import { useAssistantForm } from "@assistant-ui/react-hook-form";
import { useAssistantInstructions } from "@assistant-ui/react/experimental";
import Link from "next/link";
+const SetFormFieldTool = () => {
+ return (
+
+ set_form_field(...)
+
+ );
+};
+
+const SubmitFormTool = () => {
+ return (
+
+ submit_form(...)
+
+ );
+};
+
export default function Home() {
useAssistantInstructions("Help users sign up for Simon's hackathon.");
const form = useAssistantForm({
@@ -18,6 +34,16 @@ export default function Home() {
projectIdea: "",
proficientTechnologies: "",
},
+ assistant: {
+ tools: {
+ set_form_field: {
+ render: SetFormFieldTool,
+ },
+ submit_form: {
+ render: SubmitFormTool,
+ },
+ },
+ },
});
return (
diff --git a/examples/with-react-hook-form/components/SignupForm.tsx b/examples/with-react-hook-form/components/SignupForm.tsx
index a25a4153f..086840d4e 100644
--- a/examples/with-react-hook-form/components/SignupForm.tsx
+++ b/examples/with-react-hook-form/components/SignupForm.tsx
@@ -22,7 +22,6 @@ export const SignupForm: FC = () => {
const onSubmit = async (values: object) => {
try {
setIsSubmitting(true);
- console.log(values);
await submitSignup(values);
setIsSubmitted(true);
} finally {
diff --git a/examples/with-react-hook-form/components/ui/assistant-ui/thread.tsx b/examples/with-react-hook-form/components/ui/assistant-ui/thread.tsx
index cd96ab158..50f1cc997 100644
--- a/examples/with-react-hook-form/components/ui/assistant-ui/thread.tsx
+++ b/examples/with-react-hook-form/components/ui/assistant-ui/thread.tsx
@@ -17,7 +17,6 @@ import {
TooltipTrigger,
} from "@/components/ui/tooltip";
import { cn } from "@/lib/utils";
-import type { ToolCallContentPart } from "@assistant-ui/react/experimental";
import { TooltipProvider } from "@radix-ui/react-tooltip";
import {
ArrowDownIcon,
@@ -178,16 +177,7 @@ const AssistantMessage: FC = () => {
-
+
@@ -221,22 +211,6 @@ const AssistantMessage: FC = () => {
);
};
-const SetFormFieldTool: FC<{ part: ToolCallContentPart }> = () => {
- return (
-
- set_form_field(...)
-
- );
-};
-
-const SubmitFormTool: FC<{ part: ToolCallContentPart }> = () => {
- return (
-
- submit_form(...)
-
- );
-};
-
const BranchPicker: FC = () => {
return (
= UseFormProps & {
+ assistant?: {
+ tools?: {
+ set_form_field?: {
+ render?: ToolRenderComponent<
+ z.ZodType,
+ unknown
+ >;
+ };
+ submit_form?: {
+ render?: ToolRenderComponent<
+ z.ZodType,
+ unknown
+ >;
+ };
+ };
+ };
+};
+
export const useAssistantForm = <
TFieldValues extends FieldValues = FieldValues,
// biome-ignore lint/suspicious/noExplicitAny:
TContext = any,
TTransformedValues extends FieldValues | undefined = undefined,
>(
- props?: UseFormProps,
+ props?: UseAssistantFormProps,
): UseFormReturn => {
const form = useForm(props);
@@ -26,53 +54,77 @@ export const useAssistantForm = <
);
useEffect(() => {
- return registerModelConfigProvider(() => {
- return {
- system: `Form State:\n${JSON.stringify(form.getValues())}`,
+ const value: ModelConfig = {
+ system: `Form State:\n${JSON.stringify(form.getValues())}`,
- tools: {
- set_form_field: {
- ...formTools.set_form_field,
- execute: async (args) => {
- // biome-ignore lint/suspicious/noExplicitAny: TODO
- form.setValue(args.name as any, args.value as any);
+ tools: {
+ set_form_field: {
+ ...formTools.set_form_field,
+ execute: async (args) => {
+ // biome-ignore lint/suspicious/noExplicitAny: TODO
+ form.setValue(args.name as any, args.value as any);
- return { success: true };
- },
+ return { success: true };
},
- submit_form: {
- ...formTools.submit_form,
- execute: async () => {
- const { _names, _fields } = form.control;
- for (const name of _names.mount) {
- const field = _fields[name];
- if (field?._f) {
- const fieldReference = Array.isArray(field._f.refs)
- ? field._f.refs[0]
- : field._f.ref;
+ },
+ submit_form: {
+ ...formTools.submit_form,
+ execute: async () => {
+ const { _names, _fields } = form.control;
+ for (const name of _names.mount) {
+ const field = _fields[name];
+ if (field?._f) {
+ const fieldReference = Array.isArray(field._f.refs)
+ ? field._f.refs[0]
+ : field._f.ref;
- if (fieldReference instanceof HTMLElement) {
- const form = fieldReference.closest("form");
- if (form) {
- form.requestSubmit();
+ if (fieldReference instanceof HTMLElement) {
+ const form = fieldReference.closest("form");
+ if (form) {
+ form.requestSubmit();
- return { success: true };
- }
+ return { success: true };
}
}
}
+ }
- return {
- success: false,
- message:
- "Unable retrieve the form element. This is a coding error.",
- };
- },
+ return {
+ success: false,
+ message:
+ "Unable retrieve the form element. This is a coding error.",
+ };
},
},
- };
- });
- }, [form, registerModelConfigProvider]);
+ },
+ };
+ return registerModelConfigProvider(() => value);
+ }, [
+ form.control,
+ form.setValue,
+ form.getValues,
+ registerModelConfigProvider,
+ ]);
+
+ const renderFormFieldTool = props?.assistant?.tools?.set_form_field?.render;
+ useAssistantToolRenderer(
+ renderFormFieldTool
+ ? {
+ name: "set_form_field",
+ render: renderFormFieldTool,
+ }
+ : null,
+ );
+
+ const renderSubmitFormTool = props?.assistant?.tools?.submit_form?.render;
+ useAssistantToolRenderer(
+ renderSubmitFormTool
+ ? {
+ name: "submit_form",
+ render: renderSubmitFormTool,
+ }
+ : null,
+ );
return form;
};
diff --git a/packages/react/src/context/AssistantContext.ts b/packages/react/src/context/AssistantContext.ts
index 39a1da683..0fc15a7de 100644
--- a/packages/react/src/context/AssistantContext.ts
+++ b/packages/react/src/context/AssistantContext.ts
@@ -1,9 +1,11 @@
import { createContext, useContext } from "react";
import type { StoreApi, UseBoundStore } from "zustand";
import type { AssistantModelConfigState } from "./stores/AssistantModelConfig";
+import type { AssistantToolRenderersState } from "./stores/AssistantToolRenderers";
export type AssistantContextValue = {
useModelConfig: UseBoundStore>;
+ useToolRenderers: UseBoundStore>;
};
export const AssistantContext = createContext(
diff --git a/packages/react/src/context/providers/AssistantProvider.tsx b/packages/react/src/context/providers/AssistantProvider.tsx
index 4bd86092a..561efb61e 100644
--- a/packages/react/src/context/providers/AssistantProvider.tsx
+++ b/packages/react/src/context/providers/AssistantProvider.tsx
@@ -3,6 +3,7 @@ import { useEffect, useInsertionEffect, useRef, useState } from "react";
import type { AssistantRuntime } from "../../runtime";
import { AssistantContext } from "../AssistantContext";
import { makeAssistantModelConfigStore } from "../stores/AssistantModelConfig";
+import { makeAssistantToolRenderersStore } from "../stores/AssistantToolRenderers";
import { ThreadProvider } from "./ThreadProvider";
type AssistantProviderProps = {
@@ -19,8 +20,9 @@ export const AssistantProvider: FC<
const [context] = useState(() => {
const useModelConfig = makeAssistantModelConfigStore();
+ const useToolRenderers = makeAssistantToolRenderersStore();
- return { useModelConfig };
+ return { useModelConfig, useToolRenderers };
});
const getModelCOnfig = context.useModelConfig((c) => c.getModelConfig);
diff --git a/packages/react/src/context/stores/AssistantToolRenderers.ts b/packages/react/src/context/stores/AssistantToolRenderers.ts
new file mode 100644
index 000000000..42fc261c8
--- /dev/null
+++ b/packages/react/src/context/stores/AssistantToolRenderers.ts
@@ -0,0 +1,46 @@
+"use client";
+
+import { create } from "zustand";
+import type { ToolRenderComponent } from "../../model-config/ToolRenderComponent";
+
+export type AssistantToolRenderersState = {
+ // biome-ignore lint/suspicious/noExplicitAny: intentional any
+ getToolRenderer: (name: string) => ToolRenderComponent | null;
+ setToolRenderer: (
+ name: string,
+ // biome-ignore lint/suspicious/noExplicitAny: intentional any
+ render: ToolRenderComponent,
+ ) => () => void;
+};
+
+export const makeAssistantToolRenderersStore = () =>
+ create((set) => {
+ // biome-ignore lint/suspicious/noExplicitAny: intentional any
+ const renderers = new Map[]>();
+
+ return {
+ getToolRenderer: (name) => {
+ const arr = renderers.get(name);
+ const last = arr?.at(-1);
+ if (last) return last;
+ return null;
+ },
+ setToolRenderer: (name, render) => {
+ let arr = renderers.get(name);
+ if (!arr) {
+ arr = [];
+ renderers.set(name, arr);
+ }
+ arr.push(render);
+ set({}); // notify the store listeners
+
+ return () => {
+ const index = arr.indexOf(render);
+ if (index !== -1) {
+ arr.splice(index, 1);
+ }
+ set({}); // notify the store listeners
+ };
+ },
+ } satisfies AssistantToolRenderersState;
+ });
diff --git a/packages/react/src/experimental.ts b/packages/react/src/experimental.ts
index 0e9dc53e9..5293ad129 100644
--- a/packages/react/src/experimental.ts
+++ b/packages/react/src/experimental.ts
@@ -4,8 +4,12 @@ export type {
UIContentPart,
} from "./utils/AssistantTypes";
-export type { ModelConfigProvider } from "./utils/ModelConfigTypes";
+export type {
+ ModelConfigProvider,
+ ModelConfig,
+} from "./utils/ModelConfigTypes";
export * from "./context";
export { useAssistantInstructions } from "./model-config/useAssistantInstructions";
export { useAssistantTool } from "./model-config/useAssistantTool";
+export { useAssistantToolRenderer } from "./model-config/useAssistantToolRenderer";
diff --git a/packages/react/src/model-config/ToolRenderComponent.tsx b/packages/react/src/model-config/ToolRenderComponent.tsx
new file mode 100644
index 000000000..1754b015b
--- /dev/null
+++ b/packages/react/src/model-config/ToolRenderComponent.tsx
@@ -0,0 +1,8 @@
+"use client";
+import type { ComponentType } from "react";
+import type { ToolCallContentPart } from "../experimental";
+
+export type ToolRenderComponent = ComponentType<{
+ part: ToolCallContentPart;
+ status: "done" | "in_progress" | "error";
+}>;
diff --git a/packages/react/src/model-config/useAssistantInstructions.tsx b/packages/react/src/model-config/useAssistantInstructions.tsx
index fdc279b30..91743936d 100644
--- a/packages/react/src/model-config/useAssistantInstructions.tsx
+++ b/packages/react/src/model-config/useAssistantInstructions.tsx
@@ -8,13 +8,10 @@ export const useAssistantInstructions = (instruction: string) => {
const registerModelConfigProvider = useModelConfig(
(s) => s.registerModelConfigProvider,
);
- useEffect(
- () =>
- registerModelConfigProvider(() => {
- return {
- system: instruction,
- };
- }),
- [registerModelConfigProvider, instruction],
- );
+ useEffect(() => {
+ const config = {
+ system: instruction,
+ };
+ return registerModelConfigProvider(() => config);
+ }, [registerModelConfigProvider, instruction]);
};
diff --git a/packages/react/src/model-config/useAssistantTool.tsx b/packages/react/src/model-config/useAssistantTool.tsx
index 2ff4cdfdb..317c5ac16 100644
--- a/packages/react/src/model-config/useAssistantTool.tsx
+++ b/packages/react/src/model-config/useAssistantTool.tsx
@@ -2,22 +2,34 @@
import { useEffect } from "react";
import { useAssistantContext } from "../context/AssistantContext";
-import type { ToolWithName } from "../utils/ModelConfigTypes";
+import type { Tool } from "../utils/ModelConfigTypes";
+import type { ToolRenderComponent } from "./ToolRenderComponent";
-export const useAssistantTool = (tool: ToolWithName) => {
- const { useModelConfig } = useAssistantContext();
+export type UseAssistantTool = Tool & {
+ name: string;
+ render?: ToolRenderComponent;
+};
+
+export const useAssistantTool = (
+ tool: UseAssistantTool,
+) => {
+ const { useModelConfig, useToolRenderers } = useAssistantContext();
const registerModelConfigProvider = useModelConfig(
(s) => s.registerModelConfigProvider,
);
- useEffect(
- () =>
- registerModelConfigProvider(() => {
- return {
- tools: {
- [tool.name]: tool,
- },
- };
- }),
- [registerModelConfigProvider, tool],
- );
+ const setToolRenderer = useToolRenderers((s) => s.setToolRenderer);
+ useEffect(() => {
+ const { name, render, ...rest } = tool;
+ const config = {
+ tools: {
+ [tool.name]: rest,
+ },
+ };
+ const unsub1 = registerModelConfigProvider(() => config);
+ const unsub2 = render ? setToolRenderer(name, render) : undefined;
+ return () => {
+ unsub1();
+ unsub2?.();
+ };
+ }, [registerModelConfigProvider, setToolRenderer, tool]);
};
diff --git a/packages/react/src/model-config/useAssistantToolRenderer.tsx b/packages/react/src/model-config/useAssistantToolRenderer.tsx
new file mode 100644
index 000000000..2cb081aa2
--- /dev/null
+++ b/packages/react/src/model-config/useAssistantToolRenderer.tsx
@@ -0,0 +1,22 @@
+"use client";
+import { useEffect } from "react";
+import { useAssistantContext } from "../context/AssistantContext";
+import type { ToolRenderComponent } from "./ToolRenderComponent";
+
+type UseAssistantToolRenderer = {
+ name: string;
+ render: ToolRenderComponent;
+};
+
+export const useAssistantToolRenderer = (
+ // biome-ignore lint/suspicious/noExplicitAny: intentional any
+ tool: UseAssistantToolRenderer | null,
+) => {
+ const { useToolRenderers } = useAssistantContext();
+ const setToolRenderer = useToolRenderers((s) => s.setToolRenderer);
+ useEffect(() => {
+ if (!tool) return;
+ const { name, render } = tool;
+ return setToolRenderer(name, render);
+ }, [setToolRenderer, tool]);
+};
diff --git a/packages/react/src/primitives/contentPart/ContentPartDisplay.tsx b/packages/react/src/primitives/contentPart/ContentPartDisplay.tsx
index db023ecf8..71c413e71 100644
--- a/packages/react/src/primitives/contentPart/ContentPartDisplay.tsx
+++ b/packages/react/src/primitives/contentPart/ContentPartDisplay.tsx
@@ -5,9 +5,9 @@ export const ContentPartDisplay: FC = () => {
const { useContentPart } = useContentPartContext();
const display = useContentPart((c) => {
- if (c.part.type !== "ui" && c.part.type !== "tool-call")
+ if (c.part.type !== "ui")
throw new Error(
- "ContentPartDisplay can only be used inside tool-call or ui content parts.",
+ "ContentPartDisplay can only be used inside ui content parts.",
);
return c.part.display;
diff --git a/packages/react/src/primitives/message/MessageContent.tsx b/packages/react/src/primitives/message/MessageContent.tsx
index 1f438af60..d3671b9bb 100644
--- a/packages/react/src/primitives/message/MessageContent.tsx
+++ b/packages/react/src/primitives/message/MessageContent.tsx
@@ -1,6 +1,7 @@
"use client";
-import { type ComponentType, type FC, type ReactNode, memo } from "react";
+import { type ComponentType, type FC, memo } from "react";
+import { useAssistantContext, useContentPartContext } from "../../context";
import { useMessageContext } from "../../context/MessageContext";
import { ContentPartProvider } from "../../context/providers/ContentPartProvider";
import type {
@@ -15,12 +16,30 @@ import { ContentPartText } from "../contentPart/ContentPartText";
type MessageContentProps = {
components?: {
- Text?: ComponentType<{ part: TextContentPart }>;
- Image?: ComponentType<{ part: ImageContentPart }>;
- UI?: ComponentType<{ part: UIContentPart }>;
+ Text?: ComponentType<{
+ part: TextContentPart;
+ status: "done" | "in_progress" | "error";
+ }>;
+ Image?: ComponentType<{
+ part: ImageContentPart;
+ status: "done" | "in_progress" | "error";
+ }>;
+ UI?: ComponentType<{
+ part: UIContentPart;
+ status: "done" | "in_progress" | "error";
+ }>;
tools?: {
- by_name?: Record>;
- Fallback?: ComponentType<{ part: ToolCallContentPart }>;
+ by_name?: Record<
+ string,
+ ComponentType<{
+ part: ToolCallContentPart;
+ status: "done" | "in_progress" | "error";
+ }>
+ >;
+ Fallback?: ComponentType<{
+ part: ToolCallContentPart;
+ status: "done" | "in_progress" | "error";
+ }>;
};
};
};
@@ -35,17 +54,22 @@ const defaultComponents = {
Image: () => null,
UI: () => ,
tools: {
- Fallback: () => ,
+ Fallback: (props) => {
+ const { useToolRenderers } = useAssistantContext();
+ const Render = useToolRenderers((s) =>
+ s.getToolRenderer(props.part.toolName),
+ );
+ if (!Render) return null;
+ return ;
+ },
},
} satisfies MessageContentProps["components"];
-type MessageContentPartProps = {
- partIndex: number;
+type MessageContentPartComponentProps = {
components: MessageContentProps["components"];
};
-const MessageContentPartImpl: FC = ({
- partIndex,
+const MessageContentPartComponent: FC = ({
components: {
Text = defaultComponents.Text,
Image = defaultComponents.Image,
@@ -53,35 +77,41 @@ const MessageContentPartImpl: FC = ({
tools: { by_name = {}, Fallback = defaultComponents.tools.Fallback } = {},
} = {},
}) => {
- const { useMessage } = useMessageContext();
-
- const part = useMessage((s) => s.message.content[partIndex]!);
+ const { useContentPart } = useContentPartContext();
+ const { part, status } = useContentPart();
const type = part.type;
- let component: ReactNode | null = null;
switch (type) {
case "text":
- component = ;
- break;
+ return ;
case "image":
- component = ;
- break;
+ return ;
+
case "ui":
- component = ;
- break;
+ return ;
+
case "tool-call": {
const Tool = by_name[part.toolName] || Fallback;
- component = ;
- break;
+ return ;
}
default:
throw new Error(`Unknown content part type: ${type}`);
}
+};
+type MessageContentPartProps = {
+ partIndex: number;
+ components: MessageContentProps["components"];
+};
+
+const MessageContentPartImpl: FC = ({
+ partIndex,
+ components,
+}) => {
return (
-
- {component}
+
+
);
};
diff --git a/packages/react/src/utils/AssistantTypes.ts b/packages/react/src/utils/AssistantTypes.ts
index 0ea3f0fab..438e19ee5 100644
--- a/packages/react/src/utils/AssistantTypes.ts
+++ b/packages/react/src/utils/AssistantTypes.ts
@@ -15,13 +15,12 @@ export type UIContentPart = {
display: ReactNode;
};
-export type ToolCallContentPart = {
+export type ToolCallContentPart = {
type: "tool-call";
toolCallId: string;
toolName: string;
- args: object;
- result?: object;
- display?: ReactNode;
+ args: TArgs;
+ result?: TResult;
};
export type UserContentPart =
diff --git a/packages/react/src/utils/ModelConfigTypes.ts b/packages/react/src/utils/ModelConfigTypes.ts
index 4097b9204..adb4b7272 100644
--- a/packages/react/src/utils/ModelConfigTypes.ts
+++ b/packages/react/src/utils/ModelConfigTypes.ts
@@ -1,21 +1,21 @@
"use client";
import type { z } from "zod";
-export type Tool = {
- description: string;
- parameters: z.ZodSchema;
- execute: (args: TArgs) => Promise; // TODO return type
-};
+type ToolExecuteFunction = (
+ args: TArgs,
+) => TResult | Promise;
-export type ToolWithName = Tool & {
- name: string;
+export type Tool = {
+ description?: string;
+ parameters: z.ZodSchema;
+ execute: ToolExecuteFunction;
};
export type ModelConfig = {
priority?: number;
system?: string;
- // biome-ignore lint/suspicious/noExplicitAny: TODO
- tools?: Record>;
+ // biome-ignore lint/suspicious/noExplicitAny: intentional any
+ tools?: Record>;
};
export type ModelConfigProvider = () => ModelConfig;