diff --git a/.changeset/silly-chefs-rhyme.md b/.changeset/silly-chefs-rhyme.md new file mode 100644 index 000000000..54cd17364 --- /dev/null +++ b/.changeset/silly-chefs-rhyme.md @@ -0,0 +1,6 @@ +--- +"@assistant-ui/react-hook-form": patch +"@assistant-ui/react": patch +--- + +feat: Tool Render functions diff --git a/apps/www/pages/docs/primitives/ContentPart.mdx b/apps/www/pages/docs/primitives/ContentPart.mdx index 99dd50276..6f12398b9 100644 --- a/apps/www/pages/docs/primitives/ContentPart.mdx +++ b/apps/www/pages/docs/primitives/ContentPart.mdx @@ -23,12 +23,6 @@ const ImageContentPart = () => { ); }; -const ToolCallContentPart = () => { - return ( - - ); -}; - const UIContentPart = () => { return ( @@ -52,4 +46,4 @@ Renders the image content of an image content part as an `` element. ### Display -Renders the display content of a tool call or UI content part. This feature allows for colocation of tool call definition and corresponding UI elements. +Renders the display content of a UI content part. This feature is used by the Vercel RSC runtime. diff --git a/examples/with-react-hook-form/app/page.tsx b/examples/with-react-hook-form/app/page.tsx index cde60523f..fefa086c4 100644 --- a/examples/with-react-hook-form/app/page.tsx +++ b/examples/with-react-hook-form/app/page.tsx @@ -7,6 +7,22 @@ import { useAssistantForm } from "@assistant-ui/react-hook-form"; import { useAssistantInstructions } from "@assistant-ui/react/experimental"; import Link from "next/link"; +const SetFormFieldTool = () => { + return ( +

+ set_form_field(...) +

+ ); +}; + +const SubmitFormTool = () => { + return ( +

+ submit_form(...) +

+ ); +}; + export default function Home() { useAssistantInstructions("Help users sign up for Simon's hackathon."); const form = useAssistantForm({ @@ -18,6 +34,16 @@ export default function Home() { projectIdea: "", proficientTechnologies: "", }, + assistant: { + tools: { + set_form_field: { + render: SetFormFieldTool, + }, + submit_form: { + render: SubmitFormTool, + }, + }, + }, }); return ( diff --git a/examples/with-react-hook-form/components/SignupForm.tsx b/examples/with-react-hook-form/components/SignupForm.tsx index a25a4153f..086840d4e 100644 --- a/examples/with-react-hook-form/components/SignupForm.tsx +++ b/examples/with-react-hook-form/components/SignupForm.tsx @@ -22,7 +22,6 @@ export const SignupForm: FC = () => { const onSubmit = async (values: object) => { try { setIsSubmitting(true); - console.log(values); await submitSignup(values); setIsSubmitted(true); } finally { diff --git a/examples/with-react-hook-form/components/ui/assistant-ui/thread.tsx b/examples/with-react-hook-form/components/ui/assistant-ui/thread.tsx index cd96ab158..50f1cc997 100644 --- a/examples/with-react-hook-form/components/ui/assistant-ui/thread.tsx +++ b/examples/with-react-hook-form/components/ui/assistant-ui/thread.tsx @@ -17,7 +17,6 @@ import { TooltipTrigger, } from "@/components/ui/tooltip"; import { cn } from "@/lib/utils"; -import type { ToolCallContentPart } from "@assistant-ui/react/experimental"; import { TooltipProvider } from "@radix-ui/react-tooltip"; import { ArrowDownIcon, @@ -178,16 +177,7 @@ const AssistantMessage: FC = () => {
- +
@@ -221,22 +211,6 @@ const AssistantMessage: FC = () => { ); }; -const SetFormFieldTool: FC<{ part: ToolCallContentPart }> = () => { - return ( -

- set_form_field(...) -

- ); -}; - -const SubmitFormTool: FC<{ part: ToolCallContentPart }> = () => { - return ( -

- submit_form(...) -

- ); -}; - const BranchPicker: FC = () => { return ( = UseFormProps & { + assistant?: { + tools?: { + set_form_field?: { + render?: ToolRenderComponent< + z.ZodType, + unknown + >; + }; + submit_form?: { + render?: ToolRenderComponent< + z.ZodType, + unknown + >; + }; + }; + }; +}; + export const useAssistantForm = < TFieldValues extends FieldValues = FieldValues, // biome-ignore lint/suspicious/noExplicitAny: TContext = any, TTransformedValues extends FieldValues | undefined = undefined, >( - props?: UseFormProps, + props?: UseAssistantFormProps, ): UseFormReturn => { const form = useForm(props); @@ -26,53 +54,77 @@ export const useAssistantForm = < ); useEffect(() => { - return registerModelConfigProvider(() => { - return { - system: `Form State:\n${JSON.stringify(form.getValues())}`, + const value: ModelConfig = { + system: `Form State:\n${JSON.stringify(form.getValues())}`, - tools: { - set_form_field: { - ...formTools.set_form_field, - execute: async (args) => { - // biome-ignore lint/suspicious/noExplicitAny: TODO - form.setValue(args.name as any, args.value as any); + tools: { + set_form_field: { + ...formTools.set_form_field, + execute: async (args) => { + // biome-ignore lint/suspicious/noExplicitAny: TODO + form.setValue(args.name as any, args.value as any); - return { success: true }; - }, + return { success: true }; }, - submit_form: { - ...formTools.submit_form, - execute: async () => { - const { _names, _fields } = form.control; - for (const name of _names.mount) { - const field = _fields[name]; - if (field?._f) { - const fieldReference = Array.isArray(field._f.refs) - ? field._f.refs[0] - : field._f.ref; + }, + submit_form: { + ...formTools.submit_form, + execute: async () => { + const { _names, _fields } = form.control; + for (const name of _names.mount) { + const field = _fields[name]; + if (field?._f) { + const fieldReference = Array.isArray(field._f.refs) + ? field._f.refs[0] + : field._f.ref; - if (fieldReference instanceof HTMLElement) { - const form = fieldReference.closest("form"); - if (form) { - form.requestSubmit(); + if (fieldReference instanceof HTMLElement) { + const form = fieldReference.closest("form"); + if (form) { + form.requestSubmit(); - return { success: true }; - } + return { success: true }; } } } + } - return { - success: false, - message: - "Unable retrieve the form element. This is a coding error.", - }; - }, + return { + success: false, + message: + "Unable retrieve the form element. This is a coding error.", + }; }, }, - }; - }); - }, [form, registerModelConfigProvider]); + }, + }; + return registerModelConfigProvider(() => value); + }, [ + form.control, + form.setValue, + form.getValues, + registerModelConfigProvider, + ]); + + const renderFormFieldTool = props?.assistant?.tools?.set_form_field?.render; + useAssistantToolRenderer( + renderFormFieldTool + ? { + name: "set_form_field", + render: renderFormFieldTool, + } + : null, + ); + + const renderSubmitFormTool = props?.assistant?.tools?.submit_form?.render; + useAssistantToolRenderer( + renderSubmitFormTool + ? { + name: "submit_form", + render: renderSubmitFormTool, + } + : null, + ); return form; }; diff --git a/packages/react/src/context/AssistantContext.ts b/packages/react/src/context/AssistantContext.ts index 39a1da683..0fc15a7de 100644 --- a/packages/react/src/context/AssistantContext.ts +++ b/packages/react/src/context/AssistantContext.ts @@ -1,9 +1,11 @@ import { createContext, useContext } from "react"; import type { StoreApi, UseBoundStore } from "zustand"; import type { AssistantModelConfigState } from "./stores/AssistantModelConfig"; +import type { AssistantToolRenderersState } from "./stores/AssistantToolRenderers"; export type AssistantContextValue = { useModelConfig: UseBoundStore>; + useToolRenderers: UseBoundStore>; }; export const AssistantContext = createContext( diff --git a/packages/react/src/context/providers/AssistantProvider.tsx b/packages/react/src/context/providers/AssistantProvider.tsx index 4bd86092a..561efb61e 100644 --- a/packages/react/src/context/providers/AssistantProvider.tsx +++ b/packages/react/src/context/providers/AssistantProvider.tsx @@ -3,6 +3,7 @@ import { useEffect, useInsertionEffect, useRef, useState } from "react"; import type { AssistantRuntime } from "../../runtime"; import { AssistantContext } from "../AssistantContext"; import { makeAssistantModelConfigStore } from "../stores/AssistantModelConfig"; +import { makeAssistantToolRenderersStore } from "../stores/AssistantToolRenderers"; import { ThreadProvider } from "./ThreadProvider"; type AssistantProviderProps = { @@ -19,8 +20,9 @@ export const AssistantProvider: FC< const [context] = useState(() => { const useModelConfig = makeAssistantModelConfigStore(); + const useToolRenderers = makeAssistantToolRenderersStore(); - return { useModelConfig }; + return { useModelConfig, useToolRenderers }; }); const getModelCOnfig = context.useModelConfig((c) => c.getModelConfig); diff --git a/packages/react/src/context/stores/AssistantToolRenderers.ts b/packages/react/src/context/stores/AssistantToolRenderers.ts new file mode 100644 index 000000000..42fc261c8 --- /dev/null +++ b/packages/react/src/context/stores/AssistantToolRenderers.ts @@ -0,0 +1,46 @@ +"use client"; + +import { create } from "zustand"; +import type { ToolRenderComponent } from "../../model-config/ToolRenderComponent"; + +export type AssistantToolRenderersState = { + // biome-ignore lint/suspicious/noExplicitAny: intentional any + getToolRenderer: (name: string) => ToolRenderComponent | null; + setToolRenderer: ( + name: string, + // biome-ignore lint/suspicious/noExplicitAny: intentional any + render: ToolRenderComponent, + ) => () => void; +}; + +export const makeAssistantToolRenderersStore = () => + create((set) => { + // biome-ignore lint/suspicious/noExplicitAny: intentional any + const renderers = new Map[]>(); + + return { + getToolRenderer: (name) => { + const arr = renderers.get(name); + const last = arr?.at(-1); + if (last) return last; + return null; + }, + setToolRenderer: (name, render) => { + let arr = renderers.get(name); + if (!arr) { + arr = []; + renderers.set(name, arr); + } + arr.push(render); + set({}); // notify the store listeners + + return () => { + const index = arr.indexOf(render); + if (index !== -1) { + arr.splice(index, 1); + } + set({}); // notify the store listeners + }; + }, + } satisfies AssistantToolRenderersState; + }); diff --git a/packages/react/src/experimental.ts b/packages/react/src/experimental.ts index 0e9dc53e9..5293ad129 100644 --- a/packages/react/src/experimental.ts +++ b/packages/react/src/experimental.ts @@ -4,8 +4,12 @@ export type { UIContentPart, } from "./utils/AssistantTypes"; -export type { ModelConfigProvider } from "./utils/ModelConfigTypes"; +export type { + ModelConfigProvider, + ModelConfig, +} from "./utils/ModelConfigTypes"; export * from "./context"; export { useAssistantInstructions } from "./model-config/useAssistantInstructions"; export { useAssistantTool } from "./model-config/useAssistantTool"; +export { useAssistantToolRenderer } from "./model-config/useAssistantToolRenderer"; diff --git a/packages/react/src/model-config/ToolRenderComponent.tsx b/packages/react/src/model-config/ToolRenderComponent.tsx new file mode 100644 index 000000000..1754b015b --- /dev/null +++ b/packages/react/src/model-config/ToolRenderComponent.tsx @@ -0,0 +1,8 @@ +"use client"; +import type { ComponentType } from "react"; +import type { ToolCallContentPart } from "../experimental"; + +export type ToolRenderComponent = ComponentType<{ + part: ToolCallContentPart; + status: "done" | "in_progress" | "error"; +}>; diff --git a/packages/react/src/model-config/useAssistantInstructions.tsx b/packages/react/src/model-config/useAssistantInstructions.tsx index fdc279b30..91743936d 100644 --- a/packages/react/src/model-config/useAssistantInstructions.tsx +++ b/packages/react/src/model-config/useAssistantInstructions.tsx @@ -8,13 +8,10 @@ export const useAssistantInstructions = (instruction: string) => { const registerModelConfigProvider = useModelConfig( (s) => s.registerModelConfigProvider, ); - useEffect( - () => - registerModelConfigProvider(() => { - return { - system: instruction, - }; - }), - [registerModelConfigProvider, instruction], - ); + useEffect(() => { + const config = { + system: instruction, + }; + return registerModelConfigProvider(() => config); + }, [registerModelConfigProvider, instruction]); }; diff --git a/packages/react/src/model-config/useAssistantTool.tsx b/packages/react/src/model-config/useAssistantTool.tsx index 2ff4cdfdb..317c5ac16 100644 --- a/packages/react/src/model-config/useAssistantTool.tsx +++ b/packages/react/src/model-config/useAssistantTool.tsx @@ -2,22 +2,34 @@ import { useEffect } from "react"; import { useAssistantContext } from "../context/AssistantContext"; -import type { ToolWithName } from "../utils/ModelConfigTypes"; +import type { Tool } from "../utils/ModelConfigTypes"; +import type { ToolRenderComponent } from "./ToolRenderComponent"; -export const useAssistantTool = (tool: ToolWithName) => { - const { useModelConfig } = useAssistantContext(); +export type UseAssistantTool = Tool & { + name: string; + render?: ToolRenderComponent; +}; + +export const useAssistantTool = ( + tool: UseAssistantTool, +) => { + const { useModelConfig, useToolRenderers } = useAssistantContext(); const registerModelConfigProvider = useModelConfig( (s) => s.registerModelConfigProvider, ); - useEffect( - () => - registerModelConfigProvider(() => { - return { - tools: { - [tool.name]: tool, - }, - }; - }), - [registerModelConfigProvider, tool], - ); + const setToolRenderer = useToolRenderers((s) => s.setToolRenderer); + useEffect(() => { + const { name, render, ...rest } = tool; + const config = { + tools: { + [tool.name]: rest, + }, + }; + const unsub1 = registerModelConfigProvider(() => config); + const unsub2 = render ? setToolRenderer(name, render) : undefined; + return () => { + unsub1(); + unsub2?.(); + }; + }, [registerModelConfigProvider, setToolRenderer, tool]); }; diff --git a/packages/react/src/model-config/useAssistantToolRenderer.tsx b/packages/react/src/model-config/useAssistantToolRenderer.tsx new file mode 100644 index 000000000..2cb081aa2 --- /dev/null +++ b/packages/react/src/model-config/useAssistantToolRenderer.tsx @@ -0,0 +1,22 @@ +"use client"; +import { useEffect } from "react"; +import { useAssistantContext } from "../context/AssistantContext"; +import type { ToolRenderComponent } from "./ToolRenderComponent"; + +type UseAssistantToolRenderer = { + name: string; + render: ToolRenderComponent; +}; + +export const useAssistantToolRenderer = ( + // biome-ignore lint/suspicious/noExplicitAny: intentional any + tool: UseAssistantToolRenderer | null, +) => { + const { useToolRenderers } = useAssistantContext(); + const setToolRenderer = useToolRenderers((s) => s.setToolRenderer); + useEffect(() => { + if (!tool) return; + const { name, render } = tool; + return setToolRenderer(name, render); + }, [setToolRenderer, tool]); +}; diff --git a/packages/react/src/primitives/contentPart/ContentPartDisplay.tsx b/packages/react/src/primitives/contentPart/ContentPartDisplay.tsx index db023ecf8..71c413e71 100644 --- a/packages/react/src/primitives/contentPart/ContentPartDisplay.tsx +++ b/packages/react/src/primitives/contentPart/ContentPartDisplay.tsx @@ -5,9 +5,9 @@ export const ContentPartDisplay: FC = () => { const { useContentPart } = useContentPartContext(); const display = useContentPart((c) => { - if (c.part.type !== "ui" && c.part.type !== "tool-call") + if (c.part.type !== "ui") throw new Error( - "ContentPartDisplay can only be used inside tool-call or ui content parts.", + "ContentPartDisplay can only be used inside ui content parts.", ); return c.part.display; diff --git a/packages/react/src/primitives/message/MessageContent.tsx b/packages/react/src/primitives/message/MessageContent.tsx index 1f438af60..d3671b9bb 100644 --- a/packages/react/src/primitives/message/MessageContent.tsx +++ b/packages/react/src/primitives/message/MessageContent.tsx @@ -1,6 +1,7 @@ "use client"; -import { type ComponentType, type FC, type ReactNode, memo } from "react"; +import { type ComponentType, type FC, memo } from "react"; +import { useAssistantContext, useContentPartContext } from "../../context"; import { useMessageContext } from "../../context/MessageContext"; import { ContentPartProvider } from "../../context/providers/ContentPartProvider"; import type { @@ -15,12 +16,30 @@ import { ContentPartText } from "../contentPart/ContentPartText"; type MessageContentProps = { components?: { - Text?: ComponentType<{ part: TextContentPart }>; - Image?: ComponentType<{ part: ImageContentPart }>; - UI?: ComponentType<{ part: UIContentPart }>; + Text?: ComponentType<{ + part: TextContentPart; + status: "done" | "in_progress" | "error"; + }>; + Image?: ComponentType<{ + part: ImageContentPart; + status: "done" | "in_progress" | "error"; + }>; + UI?: ComponentType<{ + part: UIContentPart; + status: "done" | "in_progress" | "error"; + }>; tools?: { - by_name?: Record>; - Fallback?: ComponentType<{ part: ToolCallContentPart }>; + by_name?: Record< + string, + ComponentType<{ + part: ToolCallContentPart; + status: "done" | "in_progress" | "error"; + }> + >; + Fallback?: ComponentType<{ + part: ToolCallContentPart; + status: "done" | "in_progress" | "error"; + }>; }; }; }; @@ -35,17 +54,22 @@ const defaultComponents = { Image: () => null, UI: () => , tools: { - Fallback: () => , + Fallback: (props) => { + const { useToolRenderers } = useAssistantContext(); + const Render = useToolRenderers((s) => + s.getToolRenderer(props.part.toolName), + ); + if (!Render) return null; + return ; + }, }, } satisfies MessageContentProps["components"]; -type MessageContentPartProps = { - partIndex: number; +type MessageContentPartComponentProps = { components: MessageContentProps["components"]; }; -const MessageContentPartImpl: FC = ({ - partIndex, +const MessageContentPartComponent: FC = ({ components: { Text = defaultComponents.Text, Image = defaultComponents.Image, @@ -53,35 +77,41 @@ const MessageContentPartImpl: FC = ({ tools: { by_name = {}, Fallback = defaultComponents.tools.Fallback } = {}, } = {}, }) => { - const { useMessage } = useMessageContext(); - - const part = useMessage((s) => s.message.content[partIndex]!); + const { useContentPart } = useContentPartContext(); + const { part, status } = useContentPart(); const type = part.type; - let component: ReactNode | null = null; switch (type) { case "text": - component = ; - break; + return ; case "image": - component = ; - break; + return ; + case "ui": - component = ; - break; + return ; + case "tool-call": { const Tool = by_name[part.toolName] || Fallback; - component = ; - break; + return ; } default: throw new Error(`Unknown content part type: ${type}`); } +}; +type MessageContentPartProps = { + partIndex: number; + components: MessageContentProps["components"]; +}; + +const MessageContentPartImpl: FC = ({ + partIndex, + components, +}) => { return ( - - {component} + + ); }; diff --git a/packages/react/src/utils/AssistantTypes.ts b/packages/react/src/utils/AssistantTypes.ts index 0ea3f0fab..438e19ee5 100644 --- a/packages/react/src/utils/AssistantTypes.ts +++ b/packages/react/src/utils/AssistantTypes.ts @@ -15,13 +15,12 @@ export type UIContentPart = { display: ReactNode; }; -export type ToolCallContentPart = { +export type ToolCallContentPart = { type: "tool-call"; toolCallId: string; toolName: string; - args: object; - result?: object; - display?: ReactNode; + args: TArgs; + result?: TResult; }; export type UserContentPart = diff --git a/packages/react/src/utils/ModelConfigTypes.ts b/packages/react/src/utils/ModelConfigTypes.ts index 4097b9204..adb4b7272 100644 --- a/packages/react/src/utils/ModelConfigTypes.ts +++ b/packages/react/src/utils/ModelConfigTypes.ts @@ -1,21 +1,21 @@ "use client"; import type { z } from "zod"; -export type Tool = { - description: string; - parameters: z.ZodSchema; - execute: (args: TArgs) => Promise; // TODO return type -}; +type ToolExecuteFunction = ( + args: TArgs, +) => TResult | Promise; -export type ToolWithName = Tool & { - name: string; +export type Tool = { + description?: string; + parameters: z.ZodSchema; + execute: ToolExecuteFunction; }; export type ModelConfig = { priority?: number; system?: string; - // biome-ignore lint/suspicious/noExplicitAny: TODO - tools?: Record>; + // biome-ignore lint/suspicious/noExplicitAny: intentional any + tools?: Record>; }; export type ModelConfigProvider = () => ModelConfig;