Skip to content

Commit

Permalink
feat: Tool render functions (#237)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonom authored Jun 19, 2024
1 parent 643fd71 commit 671dc86
Show file tree
Hide file tree
Showing 18 changed files with 311 additions and 138 deletions.
6 changes: 6 additions & 0 deletions .changeset/silly-chefs-rhyme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@assistant-ui/react-hook-form": patch
"@assistant-ui/react": patch
---

feat: Tool Render functions
8 changes: 1 addition & 7 deletions apps/www/pages/docs/primitives/ContentPart.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,6 @@ const ImageContentPart = () => {
);
};

const ToolCallContentPart = () => {
return (
<ContentPartPrimitive.Display />
);
};

const UIContentPart = () => {
return (
<ContentPartPrimitive.Display />
Expand All @@ -52,4 +46,4 @@ Renders the image content of an image content part as an `<img>` element.

### Display

Renders the display content of a tool call or UI content part. This feature allows for colocation of tool call definition and corresponding UI elements.
Renders the display content of a UI content part. This feature is used by the Vercel RSC runtime.
26 changes: 26 additions & 0 deletions examples/with-react-hook-form/app/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,22 @@ import { useAssistantForm } from "@assistant-ui/react-hook-form";
import { useAssistantInstructions } from "@assistant-ui/react/experimental";
import Link from "next/link";

const SetFormFieldTool = () => {
return (
<p className="text-center font-bold font-mono text-blue-500 text-sm">
set_form_field(...)
</p>
);
};

const SubmitFormTool = () => {
return (
<p className="text-center font-bold font-mono text-blue-500 text-sm">
submit_form(...)
</p>
);
};

export default function Home() {
useAssistantInstructions("Help users sign up for Simon's hackathon.");
const form = useAssistantForm({
Expand All @@ -18,6 +34,16 @@ export default function Home() {
projectIdea: "",
proficientTechnologies: "",
},
assistant: {
tools: {
set_form_field: {
render: SetFormFieldTool,
},
submit_form: {
render: SubmitFormTool,
},
},
},
});

return (
Expand Down
1 change: 0 additions & 1 deletion examples/with-react-hook-form/components/SignupForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ export const SignupForm: FC = () => {
const onSubmit = async (values: object) => {
try {
setIsSubmitting(true);
console.log(values);
await submitSignup(values);
setIsSubmitted(true);
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import {
TooltipTrigger,
} from "@/components/ui/tooltip";
import { cn } from "@/lib/utils";
import type { ToolCallContentPart } from "@assistant-ui/react/experimental";
import { TooltipProvider } from "@radix-ui/react-tooltip";
import {
ArrowDownIcon,
Expand Down Expand Up @@ -178,16 +177,7 @@ const AssistantMessage: FC = () => {

<MessagePrimitive.InProgress className="inline-block size-3 animate-pulse rounded-full bg-foreground" />
<div className="flex flex-col gap-3 whitespace-pre-line text-foreground">
<MessagePrimitive.Content
components={{
tools: {
by_name: {
set_form_field: SetFormFieldTool,
submit_form: SubmitFormTool,
},
},
}}
/>
<MessagePrimitive.Content />
</div>

<div className="flex pt-2">
Expand Down Expand Up @@ -221,22 +211,6 @@ const AssistantMessage: FC = () => {
);
};

const SetFormFieldTool: FC<{ part: ToolCallContentPart }> = () => {
return (
<p className="text-center font-bold font-mono text-blue-500 text-sm">
set_form_field(...)
</p>
);
};

const SubmitFormTool: FC<{ part: ToolCallContentPart }> = () => {
return (
<p className="text-center font-bold font-mono text-blue-500 text-sm">
submit_form(...)
</p>
);
};

const BranchPicker: FC = () => {
return (
<BranchPickerPrimitive.Root
Expand Down
128 changes: 90 additions & 38 deletions packages/react-hook-form/src/useAssistantForm.tsx
Original file line number Diff line number Diff line change
@@ -1,22 +1,50 @@
"use client";

import { useAssistantContext } from "@assistant-ui/react/experimental";
import {
type ModelConfig,
useAssistantContext,
useAssistantToolRenderer,
} from "@assistant-ui/react/experimental";
import { useEffect } from "react";
import {
type FieldValues,
type UseFormProps,
type UseFormReturn,
useForm,
} from "react-hook-form";
import type { z } from "zod";
import type { ToolRenderComponent } from "../../react/src/model-config/ToolRenderComponent";
import { formTools } from "./formTools";

type UseAssistantFormProps<
TFieldValues extends FieldValues,
TContext,
> = UseFormProps<TFieldValues, TContext> & {
assistant?: {
tools?: {
set_form_field?: {
render?: ToolRenderComponent<
z.ZodType<typeof formTools.set_form_field>,
unknown
>;
};
submit_form?: {
render?: ToolRenderComponent<
z.ZodType<typeof formTools.submit_form>,
unknown
>;
};
};
};
};

export const useAssistantForm = <
TFieldValues extends FieldValues = FieldValues,
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
TContext = any,
TTransformedValues extends FieldValues | undefined = undefined,
>(
props?: UseFormProps<TFieldValues, TContext>,
props?: UseAssistantFormProps<TFieldValues, TContext>,
): UseFormReturn<TFieldValues, TContext, TTransformedValues> => {
const form = useForm<TFieldValues, TContext, TTransformedValues>(props);

Expand All @@ -26,53 +54,77 @@ export const useAssistantForm = <
);

useEffect(() => {
return registerModelConfigProvider(() => {
return {
system: `Form State:\n${JSON.stringify(form.getValues())}`,
const value: ModelConfig = {
system: `Form State:\n${JSON.stringify(form.getValues())}`,

tools: {
set_form_field: {
...formTools.set_form_field,
execute: async (args) => {
// biome-ignore lint/suspicious/noExplicitAny: TODO
form.setValue(args.name as any, args.value as any);
tools: {
set_form_field: {
...formTools.set_form_field,
execute: async (args) => {
// biome-ignore lint/suspicious/noExplicitAny: TODO
form.setValue(args.name as any, args.value as any);

return { success: true };
},
return { success: true };
},
submit_form: {
...formTools.submit_form,
execute: async () => {
const { _names, _fields } = form.control;
for (const name of _names.mount) {
const field = _fields[name];
if (field?._f) {
const fieldReference = Array.isArray(field._f.refs)
? field._f.refs[0]
: field._f.ref;
},
submit_form: {
...formTools.submit_form,
execute: async () => {
const { _names, _fields } = form.control;
for (const name of _names.mount) {
const field = _fields[name];
if (field?._f) {
const fieldReference = Array.isArray(field._f.refs)
? field._f.refs[0]
: field._f.ref;

if (fieldReference instanceof HTMLElement) {
const form = fieldReference.closest("form");
if (form) {
form.requestSubmit();
if (fieldReference instanceof HTMLElement) {
const form = fieldReference.closest("form");
if (form) {
form.requestSubmit();

return { success: true };
}
return { success: true };
}
}
}
}

return {
success: false,
message:
"Unable retrieve the form element. This is a coding error.",
};
},
return {
success: false,
message:
"Unable retrieve the form element. This is a coding error.",
};
},
},
};
});
}, [form, registerModelConfigProvider]);
},
};
return registerModelConfigProvider(() => value);
}, [
form.control,
form.setValue,
form.getValues,
registerModelConfigProvider,
]);

const renderFormFieldTool = props?.assistant?.tools?.set_form_field?.render;
useAssistantToolRenderer(
renderFormFieldTool
? {
name: "set_form_field",
render: renderFormFieldTool,
}
: null,
);

const renderSubmitFormTool = props?.assistant?.tools?.submit_form?.render;
useAssistantToolRenderer(
renderSubmitFormTool
? {
name: "submit_form",
render: renderSubmitFormTool,
}
: null,
);

return form;
};
2 changes: 2 additions & 0 deletions packages/react/src/context/AssistantContext.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import { createContext, useContext } from "react";
import type { StoreApi, UseBoundStore } from "zustand";
import type { AssistantModelConfigState } from "./stores/AssistantModelConfig";
import type { AssistantToolRenderersState } from "./stores/AssistantToolRenderers";

export type AssistantContextValue = {
useModelConfig: UseBoundStore<StoreApi<AssistantModelConfigState>>;
useToolRenderers: UseBoundStore<StoreApi<AssistantToolRenderersState>>;
};

export const AssistantContext = createContext<AssistantContextValue | null>(
Expand Down
4 changes: 3 additions & 1 deletion packages/react/src/context/providers/AssistantProvider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { useEffect, useInsertionEffect, useRef, useState } from "react";
import type { AssistantRuntime } from "../../runtime";
import { AssistantContext } from "../AssistantContext";
import { makeAssistantModelConfigStore } from "../stores/AssistantModelConfig";
import { makeAssistantToolRenderersStore } from "../stores/AssistantToolRenderers";
import { ThreadProvider } from "./ThreadProvider";

type AssistantProviderProps = {
Expand All @@ -19,8 +20,9 @@ export const AssistantProvider: FC<

const [context] = useState(() => {
const useModelConfig = makeAssistantModelConfigStore();
const useToolRenderers = makeAssistantToolRenderersStore();

return { useModelConfig };
return { useModelConfig, useToolRenderers };
});

const getModelCOnfig = context.useModelConfig((c) => c.getModelConfig);
Expand Down
46 changes: 46 additions & 0 deletions packages/react/src/context/stores/AssistantToolRenderers.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"use client";

import { create } from "zustand";
import type { ToolRenderComponent } from "../../model-config/ToolRenderComponent";

export type AssistantToolRenderersState = {
// biome-ignore lint/suspicious/noExplicitAny: intentional any
getToolRenderer: (name: string) => ToolRenderComponent<any, any> | null;
setToolRenderer: (
name: string,
// biome-ignore lint/suspicious/noExplicitAny: intentional any
render: ToolRenderComponent<any, any>,
) => () => void;
};

export const makeAssistantToolRenderersStore = () =>
create<AssistantToolRenderersState>((set) => {
// biome-ignore lint/suspicious/noExplicitAny: intentional any
const renderers = new Map<string, ToolRenderComponent<any, any>[]>();

return {
getToolRenderer: (name) => {
const arr = renderers.get(name);
const last = arr?.at(-1);
if (last) return last;
return null;
},
setToolRenderer: (name, render) => {
let arr = renderers.get(name);
if (!arr) {
arr = [];
renderers.set(name, arr);
}
arr.push(render);
set({}); // notify the store listeners

return () => {
const index = arr.indexOf(render);
if (index !== -1) {
arr.splice(index, 1);
}
set({}); // notify the store listeners
};
},
} satisfies AssistantToolRenderersState;
});
6 changes: 5 additions & 1 deletion packages/react/src/experimental.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@ export type {
UIContentPart,
} from "./utils/AssistantTypes";

export type { ModelConfigProvider } from "./utils/ModelConfigTypes";
export type {
ModelConfigProvider,
ModelConfig,
} from "./utils/ModelConfigTypes";

export * from "./context";
export { useAssistantInstructions } from "./model-config/useAssistantInstructions";
export { useAssistantTool } from "./model-config/useAssistantTool";
export { useAssistantToolRenderer } from "./model-config/useAssistantToolRenderer";
8 changes: 8 additions & 0 deletions packages/react/src/model-config/ToolRenderComponent.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"use client";
import type { ComponentType } from "react";
import type { ToolCallContentPart } from "../experimental";

export type ToolRenderComponent<TArgs, TResult> = ComponentType<{
part: ToolCallContentPart<TArgs, TResult>;
status: "done" | "in_progress" | "error";
}>;
Loading

0 comments on commit 671dc86

Please sign in to comment.