diff --git a/apps/docs/components/claude/Claude.tsx b/apps/docs/components/claude/Claude.tsx index 1c7e40eac..0e90f7cbf 100644 --- a/apps/docs/components/claude/Claude.tsx +++ b/apps/docs/components/claude/Claude.tsx @@ -6,8 +6,8 @@ import { ComposerPrimitive, MessagePrimitive, ThreadPrimitive, + useMessage, } from "@assistant-ui/react"; -import { useMessageContext } from "@assistant-ui/react"; import * as Avatar from "@radix-ui/react-avatar"; import { ArrowUpIcon, ClipboardIcon, ReloadIcon } from "@radix-ui/react-icons"; import type { FC } from "react"; @@ -51,7 +51,6 @@ export const Claude: FC = () => { }; const ChatMessage: FC = () => { - const { useMessage } = useMessageContext(); const { message } = useMessage(); return ( diff --git a/apps/docs/components/docs-chat/DocsChat.tsx b/apps/docs/components/docs-chat/DocsChat.tsx index 8c53ab49a..808749fa9 100644 --- a/apps/docs/components/docs-chat/DocsChat.tsx +++ b/apps/docs/components/docs-chat/DocsChat.tsx @@ -15,8 +15,7 @@ import { Thread, type ThreadConfig } from "@assistant-ui/react"; import entelligenceLogoLight from "./entelligence-light.png"; import entelligenceLogoDark from "./entelligence-dark.png"; import Image from "next/image"; -import Composer from "@assistant-ui/react/ui/composer"; -import ThreadWelcome from "@assistant-ui/react/ui/thread-welcome"; +import { Composer, ThreadWelcome } from "@assistant-ui/react"; function asAsyncIterable(source: ReadableStream): AsyncIterable { return { diff --git a/apps/docs/content/docs/reference/context.mdx b/apps/docs/content/docs/reference/context.mdx index 962ea1687..29aef243c 100644 --- a/apps/docs/content/docs/reference/context.mdx +++ b/apps/docs/content/docs/reference/context.mdx @@ -13,11 +13,9 @@ The context is split into four hierarchies: import { ParametersTable } from "@/components/docs"; import { - AssistantContextValue, AssistantActionsState, AssistantModelConfigState, AssistantToolUIsState, - ThreadContextValue, ThreadState, ThreadMessagesState, ThreadActionsState, @@ -27,35 +25,19 @@ import { ContentPartState, MessageState, MessageUtilsState, - AttachmentContextValue, ComposerAttachmentState, MessageAttachmentState, } from "@/components/docs/parameters/context"; ## Assistant Context -The **Assistant Context** provides access for ModelConfigState and ToolUIsState. - -### `useAssistantContext` - -This hook providers access to the context stores. - -```tsx -import { useAssistantContext } from "@assistant-ui/react"; - -const { useAssistantActions, useModelConfig, useToolUIs } = - useAssistantContext(); -``` - -{" "} - #### `useAssistantActions` ```tsx import { useAssistantActions } from "@assistant-ui/react"; +const actions = useAssistantActions(); const switchToNewThread = useAssistantActions((m) => m.switchToThread); -const switchToNewThread = useAssistantActions.getState().switchToThread; ``` @@ -63,10 +45,10 @@ const switchToNewThread = useAssistantActions.getState().switchToThread; ### `useModelConfig` ```tsx -const { useModelConfig } = useAssistantContext(); +import { useModelConfig } from "@assistant-ui/react"; +const modelConfig = useModelConfig(); const getModelConfig = useModelConfig((m) => m.getModelConfig); -const getModelConfig = useModelConfig.getState().getModelConfig; ``` {" "} @@ -74,47 +56,23 @@ const getModelConfig = useModelConfig.getState().getModelConfig; ### `useToolUIs` ```tsx -const { useToolUIs } = useAssistantContext(); +import { useToolUIs } from "@assistant-ui/react"; +const toolUIs = useToolUIs(); const getToolUI = useToolUIs((m) => m.getToolUI); -const getToolUI = useToolUIs.getState().getToolUI; ``` ## Thread Context -The **Thread Context** provides access to ThreadState, ThreadActionsState, ComposerState and ThreadViewportState. - -### `useThreadContext` - -This hook provides access to the context stores. - -```tsx -import { useThreadContext } from "@assistant-ui/react"; - -const { - useThread, - useThreadMessages, - useThreadActions, - useThreadRuntime, - useComposer, - useViewport, -} = useThreadContext(); -``` - - - ### `useThread` ```tsx -const { useThread } = useThreadContext(); +import { useThread } from "@assistant-ui/react"; +const thread = useThread(); const isRunning = useThread((m) => m.isRunning); -const isRunning = useThread.getState().isRunning; - -const isDisabled = useThread((m) => m.isDisabled); -const isDisabled = useThread.getState().isDisabled; ``` @@ -122,10 +80,10 @@ const isDisabled = useThread.getState().isDisabled; ### `useThreadMessages` ```tsx -const { useThreadMessages } = useThreadContext(); +import { useThreadMessages } from "@assistant-ui/react"; +const messages = useThreadMessages(); const firstMessage = useThreadMessages((m) => m[0]); -const firstMessage = useThreadMessages.getState()[0]; ``` @@ -133,25 +91,10 @@ const firstMessage = useThreadMessages.getState()[0]; ### `useThreadActions` ```tsx -const { useThreadActions } = useThreadContext(); +import { useThreadActions } from "@assistant-ui/react"; +const actions = useThreadActions(); const getBranches = useThreadActions((m) => m.getBranches); -const getBranches = useThreadActions.getState().getBranches; - -const switchToBranch = useThreadActions((m) => m.switchToBranch); -const switchToBranch = useThreadActions.getState().switchToBranch; - -const append = useThreadActions((m) => m.append); -const append = useThreadActions.getState().append; - -const startRun = useThreadActions((m) => m.startRun); -const startRun = useThreadActions.getState().startRun; - -const cancelRun = useThreadActions((m) => m.cancelRun); -const cancelRun = useThreadActions.getState().cancelRun; - -const addToolResult = useThreadActions((m) => m.addToolResult); -const addToolResult = useThreadActions.getState().addToolResult; ``` @@ -159,84 +102,42 @@ const addToolResult = useThreadActions.getState().addToolResult; ### `useThreadRuntime` ```tsx -const { useThreadRuntime } = useThreadContext(); +import { useThreadRuntime } from "@assistant-ui/react"; const runtime = useThreadRuntime(); ``` -### `useComposer` +### `useThreadComposer` ```tsx -const { useComposer } = useThreadContext(); - -const text = useComposer((m) => m.text); -const text = useComposer.getState().text; - -const setText = useComposer((m) => m.setText); -const setText = useComposer.getState().setText; - -const attachments = useComposer((m) => m.attachments); -const attachments = useComposer.getState().attachments; - -const addAttachment = useComposer((m) => m.addAttachment); -const addAttachment = useComposer.getState().addAttachment; - -const removeAttachment = useComposer((m) => m.removeAttachment); -const removeAttachment = useComposer.getState().removeAttachment; +import { useThreadComposer } from "@assistant-ui/react"; -const reset = useComposer((m) => m.reset); -const reset = useComposer.getState().reset; +const composer = useThreadComposer(); +const text = useThreadComposer((m) => m.text); ``` -### `useViewport` +### `useThreadViewport` ```tsx -const { useViewport } = useThreadContext(); +import { useThreadViewport } from "@assistant-ui/react"; -const isAtBottom = useViewport((m) => m.isAtBottom); -const isAtBottom = useViewport.getState().isAtBottom; - -const scrollToBottom = useViewport((m) => m.scrollToBottom); -const scrollToBottom = useViewport.getState().scrollToBottom; - -const onScrollToBottom = useViewport((m) => m.onScrollToBottom); -const onScrollToBottom = useViewport.getState().onScrollToBottom; +const threadViewport = useThreadViewport(); +const isAtBottom = useThreadViewport((m) => m.isAtBottom); ``` ## Message Context -The **Message Context** provides access to MessageState, EditComposerState and MessageUtilsState. - -### `useMessageContext` - -This hook provides access to the context stores. - -```tsx -import { useMessageContext } from "@assistant-ui/react"; - -const { useMessage, useMessageUtils, useEditComposer } = useMessageContext(); -``` - ### `useMessage` ```tsx -const { useMessage } = useMessageContext(); +import { useMessage } from "@assistant-ui/react"; +const { message } = useMessage(); const message = useMessage((m) => m.message); -const message = useMessage.getState().message; - -const parentId = useMessage((m) => m.parentId); -const parentId = useMessage.getState().parentId; - -const branches = useMessage((m) => m.branches); -const branches = useMessage.getState().branches; - -const isLast = useMessage((m) => m.isLast); -const isLast = useMessage.getState().isLast; ``` @@ -244,28 +145,10 @@ const isLast = useMessage.getState().isLast; ### `useMessageUtils` ```tsx -const { useMessageUtils } = useMessageContext(); +import { useMessageUtils } from "@assistant-ui/react"; +const messageUtils = useMessageUtils(); const isCopied = useMessageUtils((m) => m.isCopied); -const isCopied = useMessageUtils.getState().isCopied; - -const setIsCopied = useMessageUtils((m) => m.setIsCopied); -const setIsCopied = useMessageUtils.getState().setIsCopied; - -const isHovering = useMessageUtils((m) => m.isHovering); -const isHovering = useMessageUtils.getState().isHovering; - -const setIsHovering = useMessageUtils((m) => m.setIsHovering); -const setIsHovering = useMessageUtils.getState().setIsHovering; - -const isSpeaking = useMessageUtils((m) => m.isSpeaking); -const isSpeaking = useMessageUtils.getState().isSpeaking; - -const stopSpeaking = useMessageUtils((m) => m.stopSpeaking); -const stopSpeaking = useMessageUtils.getState().stopSpeaking; - -const addUtterance = useMessageUtils((m) => m.addUtterance); -const addUtterance = useMessageUtils.getState().addUtterance; ``` @@ -273,35 +156,20 @@ const addUtterance = useMessageUtils.getState().addUtterance; ### `useEditComposer` ```tsx -const { useEditComposer } = useMessageContext(); +import { useEditComposer } from "@assistant-ui/react"; +const editComposer = useEditComposer(); const text = useEditComposer((m) => m.text); -const text = useEditComposer.getState().text; - -const setText = useEditComposer((m) => m.setText); -const setText = useEditComposer.getState().setText; ``` ## Content Part Context -The **Content Part Context** provides access to ContentPartState. - -### `useContentPartContext` - -This hook provides access to the context stores. - -```tsx -import { useContentPartContext } from "@assistant-ui/react"; - -const { useContentPart } = useContentPartContext(); -``` - ### `useContentPart` ```tsx -const { useContentPart } = useContentPartContext(); +import { useContentPart } from "@assistant-ui/react"; const part = useContentPart((m) => m.part); const part = useContentPart.getState().part; @@ -314,77 +182,48 @@ const status = useContentPart.getState().status; ## Composer Context -The **Composer Context** provides access to the nearest composer state (either the edit composer or the thread's new message composer). - -### `useComposerContext` - -This hook provides access to the context stores. - -```tsx -import { useComposerContext } from "@assistant-ui/react"; - -const { useComposer } = useComposerContext(); -``` - -", - required: true, - description: "The composer state.", - }, - { - name: "type", - type: "'edit' | 'new'", - required: true, - description: "The type of composer.", - }, - ]} -/> +Grabs the nearest composer state (either the edit composer or the thread's new message composer). ### `useComposer` ```tsx -const { useComposer } = useComposerContext(); +import { useComposer } from "@assistant-ui/react"; +const composer = useComposer(); const text = useComposer((m) => m.text); -const text = useComposer.getState().text; - -const setText = useComposer((m) => m.setText); -const setText = useComposer.getState().setText; ``` ## Attachment Context -Provides access to the current attachment state. +Grabs the attachment state (either the composer or message attachment). -### `useAttachmentContext` - -This hook provides access to the context stores. +### `useAttachment` ```tsx -import { useAttachmentContext } from "@assistant-ui/react"; +import { useAttachment } from "@assistant-ui/react"; -const { useAttachment } = useAttachmentContext(); +const { attachment } = useAttachment(); +const attachment = useAttachment((m) => m.attachment); ``` - - -### `useAttachment` +#### `useComposerAttachment` (Composer) ```tsx -const { useAttachment } = useAttachmentContext(); +import { useComposerAttachment } from "@assistant-ui/react"; -const attachment = useAttachment((m) => m.attachment); -const attachment = useAttachment.getState().attachment; +const { attachment } = useComposerAttachment(); +const attachment = useComposerAttachment((m) => m.attachment); ``` -#### `useAttachment` (Composer) - -#### `useAttachment` (Message) +#### `useMessageAttachment` (Message) + +```tsx +import { useMessageAttachment } from "@assistant-ui/react"; + +const { attachment } = useMessageAttachment(); +const attachment = useMessageAttachment((m) => m.attachment); +``` diff --git a/apps/docs/content/docs/runtimes/langgraph/tutorial/part-3.mdx b/apps/docs/content/docs/runtimes/langgraph/tutorial/part-3.mdx index 293203095..5b5a04d67 100644 --- a/apps/docs/content/docs/runtimes/langgraph/tutorial/part-3.mdx +++ b/apps/docs/content/docs/runtimes/langgraph/tutorial/part-3.mdx @@ -87,7 +87,7 @@ Then we use `makeAssistantToolUI` to define the tool UI: import { TransactionConfirmationPending } from "./transaction-confirmation-pending"; import { TransactionConfirmationFinal } from "./transaction-confirmation-final"; -import { makeAssistantToolUI, useThreadContext } from "@assistant-ui/react"; +import { makeAssistantToolUI } from "@assistant-ui/react"; import { updateState } from "@/lib/chatApi"; export const PurchaseStockTool = makeAssistantToolUI( @@ -312,7 +312,7 @@ We will import the new `` component and use it i import { TransactionConfirmationPending } from "./transaction-confirmation-pending"; import { TransactionConfirmationFinal } from "./transaction-confirmation-final"; -import { makeAssistantToolUI, useThreadContext } from "@assistant-ui/react"; +import { makeAssistantToolUI } from "@assistant-ui/react"; import { updateState } from "@/lib/chatApi"; type PurchaseStockArgs = { diff --git a/examples/with-ffmpeg/app/page.tsx b/examples/with-ffmpeg/app/page.tsx index db2ea28af..2aaf443c6 100644 --- a/examples/with-ffmpeg/app/page.tsx +++ b/examples/with-ffmpeg/app/page.tsx @@ -4,7 +4,7 @@ import { Thread, useAssistantInstructions, useAssistantTool, - useThreadContext, + useThreadComposer, } from "@assistant-ui/react"; import { z } from "zod"; import { FFmpeg } from "@ffmpeg/ffmpeg"; @@ -152,8 +152,7 @@ const FfmpegTool: FC<{ file: File }> = ({ file }) => { export default function Home() { const [lastFile, setLastFile] = useState(null); - const { useComposer } = useThreadContext(); - const attachments = useComposer((c) => c.attachments); + const attachments = useThreadComposer((c) => c.attachments); useEffect(() => { const lastAttachment = attachments[attachments.length - 1]; if (!lastAttachment) return; diff --git a/examples/with-inline-suggestions/components/ui/assistant-ui/LastMessageHook.tsx b/examples/with-inline-suggestions/components/ui/assistant-ui/LastMessageHook.tsx index 771a7fd7a..30d8be2b0 100644 --- a/examples/with-inline-suggestions/components/ui/assistant-ui/LastMessageHook.tsx +++ b/examples/with-inline-suggestions/components/ui/assistant-ui/LastMessageHook.tsx @@ -1,8 +1,6 @@ -import { ThreadAssistantMessage } from "@assistant-ui/react"; -import { useThreadContext } from "@assistant-ui/react"; +import { ThreadAssistantMessage, useThreadMessages } from "@assistant-ui/react"; export const useLastAssistantMessage = () => { - const { useThreadMessages } = useThreadContext(); return useThreadMessages((messages) => { for (let i = messages.length - 1; i >= 0; i--) { if (messages[i]?.role === "assistant") { diff --git a/examples/with-langgraph/components/tools/ToolFallback.tsx b/examples/with-langgraph/components/tools/ToolFallback.tsx index 20a446b33..b52087f04 100644 --- a/examples/with-langgraph/components/tools/ToolFallback.tsx +++ b/examples/with-langgraph/components/tools/ToolFallback.tsx @@ -1,8 +1,10 @@ import { ToolCallContentPartComponent } from "@assistant-ui/react"; -import { TooltipIconButton } from "@assistant-ui/react/internal"; +import { INTERNAL } from "@assistant-ui/react"; import { CheckIcon, ChevronDownIcon, ChevronUpIcon } from "lucide-react"; import { useState } from "react"; +const { TooltipIconButton } = INTERNAL; + export const ToolFallback: ToolCallContentPartComponent = ({ part }) => { const [isCollapsed, setIsCollapsed] = useState(true); return ( diff --git a/examples/with-langgraph/components/tools/purchase-stock/PurchaseStockTool.tsx b/examples/with-langgraph/components/tools/purchase-stock/PurchaseStockTool.tsx index 9c9da7ab7..4b127a2b9 100644 --- a/examples/with-langgraph/components/tools/purchase-stock/PurchaseStockTool.tsx +++ b/examples/with-langgraph/components/tools/purchase-stock/PurchaseStockTool.tsx @@ -2,7 +2,7 @@ import { TransactionConfirmationPending } from "./transaction-confirmation-pending"; import { TransactionConfirmationFinal } from "./transaction-confirmation-final"; -import { makeAssistantToolUI, useThreadContext } from "@assistant-ui/react"; +import { makeAssistantToolUI, useThreadStore } from "@assistant-ui/react"; import { updateState } from "@/lib/chatApi"; type PurchaseStockArgs = { @@ -30,10 +30,9 @@ export const PurchaseStockTool = makeAssistantToolUI( ? (JSON.parse(result) as { transactionId: string }) : undefined; - const { useThread } = useThreadContext(); - + const threadStore = useThreadStore(); const handleConfirm = async () => { - await updateState(useThread.getState().threadId, { + await updateState(threadStore.getState().threadId, { newState: CONFIRM_PURCHASE, asNode: PREPARE_PURCHASE_DETAILS_NODE, }); diff --git a/examples/with-react-hook-form/components/ui/form.tsx b/examples/with-react-hook-form/components/ui/form.tsx index fe4b9bfe0..848bbc512 100644 --- a/examples/with-react-hook-form/components/ui/form.tsx +++ b/examples/with-react-hook-form/components/ui/form.tsx @@ -1,6 +1,5 @@ import type * as LabelPrimitive from "@radix-ui/react-label"; import { Slot } from "@radix-ui/react-slot"; -import * as React from "react"; import { Controller, type ControllerProps, @@ -12,6 +11,15 @@ import { import { Label } from "@/components/ui/label"; import { cn } from "@/lib/utils"; +import { + ComponentPropsWithoutRef, + createContext, + ElementRef, + forwardRef, + HTMLAttributes, + useContext, + useId, +} from "react"; const Form = FormProvider; @@ -22,7 +30,7 @@ type FormFieldContextValue< name: TName; }; -const FormFieldContext = React.createContext( +const FormFieldContext = createContext( {} as FormFieldContextValue, ); @@ -40,8 +48,8 @@ const FormField = < }; const useFormField = () => { - const fieldContext = React.useContext(FormFieldContext); - const itemContext = React.useContext(FormItemContext); + const fieldContext = useContext(FormFieldContext); + const itemContext = useContext(FormItemContext); const { getFieldState, formState } = useFormContext(); const fieldState = getFieldState(fieldContext.name, formState); @@ -66,27 +74,26 @@ type FormItemContextValue = { id: string; }; -const FormItemContext = React.createContext( +const FormItemContext = createContext( {} as FormItemContextValue, ); -const FormItem = React.forwardRef< - HTMLDivElement, - React.HTMLAttributes ->(({ className, ...props }, ref) => { - const id = React.useId(); +const FormItem = forwardRef>( + ({ className, ...props }, ref) => { + const id = useId(); - return ( - -
- - ); -}); + return ( + +
+ + ); + }, +); FormItem.displayName = "FormItem"; -const FormLabel = React.forwardRef< - React.ElementRef, - React.ComponentPropsWithoutRef +const FormLabel = forwardRef< + ElementRef, + ComponentPropsWithoutRef >(({ className, ...props }, ref) => { const { error, formItemId } = useFormField(); @@ -101,9 +108,9 @@ const FormLabel = React.forwardRef< }); FormLabel.displayName = "FormLabel"; -const FormControl = React.forwardRef< - React.ElementRef, - React.ComponentPropsWithoutRef +const FormControl = forwardRef< + ElementRef, + ComponentPropsWithoutRef >(({ ...props }, ref) => { const { error, formItemId, formDescriptionId, formMessageId } = useFormField(); @@ -124,9 +131,9 @@ const FormControl = React.forwardRef< }); FormControl.displayName = "FormControl"; -const FormDescription = React.forwardRef< +const FormDescription = forwardRef< HTMLParagraphElement, - React.HTMLAttributes + HTMLAttributes >(({ className, ...props }, ref) => { const { formDescriptionId } = useFormField(); @@ -141,9 +148,9 @@ const FormDescription = React.forwardRef< }); FormDescription.displayName = "FormDescription"; -const FormMessage = React.forwardRef< +const FormMessage = forwardRef< HTMLParagraphElement, - React.HTMLAttributes + HTMLAttributes >(({ className, children, ...props }, ref) => { const { error, formMessageId } = useFormField(); const body = error ? String(error?.message) : children; diff --git a/packages/react-hook-form/src/useAssistantForm.tsx b/packages/react-hook-form/src/useAssistantForm.tsx index ecfe4e69f..87b4ceb77 100644 --- a/packages/react-hook-form/src/useAssistantForm.tsx +++ b/packages/react-hook-form/src/useAssistantForm.tsx @@ -3,8 +3,8 @@ import { type ModelConfig, type ToolCallContentPartComponent, - useAssistantContext, useAssistantToolUI, + useModelConfigStore, } from "@assistant-ui/react"; import { useEffect } from "react"; import { @@ -62,11 +62,7 @@ export const useAssistantForm = < const form = useForm(props); const { control, getValues, setValue } = form; - const { useModelConfig } = useAssistantContext(); - const registerModelConfigProvider = useModelConfig( - (c) => c.registerModelConfigProvider, - ); - + const modelConfigStore = useModelConfigStore(); useEffect(() => { const value: ModelConfig = { system: `Form State:\n${JSON.stringify(getValues())}`, @@ -111,8 +107,10 @@ export const useAssistantForm = < }, }, }; - return registerModelConfigProvider({ getModelConfig: () => value }); - }, [control, setValue, getValues, registerModelConfigProvider]); + return modelConfigStore.getState().registerModelConfigProvider({ + getModelConfig: () => value, + }); + }, [control, setValue, getValues, modelConfigStore]); const renderFormFieldTool = props?.assistant?.tools?.set_form_field?.render; useAssistantToolUI( diff --git a/packages/react-playground/src/components/ui/assistant-ui/assistant-playground.tsx b/packages/react-playground/src/components/ui/assistant-ui/assistant-playground.tsx index 6ed8b2295..82e23a413 100644 --- a/packages/react-playground/src/components/ui/assistant-ui/assistant-playground.tsx +++ b/packages/react-playground/src/components/ui/assistant-ui/assistant-playground.tsx @@ -1,7 +1,7 @@ "use client"; import { ChangeEvent, FC, PropsWithChildren, useState } from "react"; -import { Tool, useAssistantContext } from "@assistant-ui/react"; +import { Tool, useAssistantActionsStore } from "@assistant-ui/react"; import { PayloadEditorButton } from "../../payload-editor-button"; import { Thread } from "./thread"; import { Button } from "../button"; @@ -322,9 +322,9 @@ const Sidebar: FC = ({ modelSelector, apiKey = true, }) => { - const { useAssistantActions } = useAssistantContext(); + const assistantActionsStore = useAssistantActionsStore(); const handleReset = () => { - useAssistantActions.getState().switchToThread(null); + assistantActionsStore.getState().switchToThread(null); }; return ( diff --git a/packages/react-playground/src/components/ui/assistant-ui/remove-content-part.tsx b/packages/react-playground/src/components/ui/assistant-ui/remove-content-part.tsx index 2a4e26f31..802ac2d4f 100644 --- a/packages/react-playground/src/components/ui/assistant-ui/remove-content-part.tsx +++ b/packages/react-playground/src/components/ui/assistant-ui/remove-content-part.tsx @@ -1,4 +1,4 @@ -import { useContentPartContext, useMessageContext } from "@assistant-ui/react"; +import { useContentPartStore, useMessageStore } from "@assistant-ui/react"; import { TooltipIconButton } from "./tooltip-icon-button"; import { CircleXIcon } from "lucide-react"; import { FC } from "react"; @@ -6,12 +6,13 @@ import { useGetPlaygroundRuntime } from "../../../lib/usePlaygroundRuntime"; export const RemoveContentPartButton: FC = () => { const getPlaygroundRuntime = useGetPlaygroundRuntime(); - const { useMessage } = useMessageContext(); - const { useContentPart } = useContentPartContext(); + + const messageStore = useMessageStore(); + const contentPartStore = useContentPartStore(); const handleRemove = () => { getPlaygroundRuntime().deleteContentPart( - useMessage.getState().message.id, - useContentPart.getState().part, + messageStore.getState().message.id, + contentPartStore.getState().part, ); }; diff --git a/packages/react-playground/src/components/ui/assistant-ui/text.tsx b/packages/react-playground/src/components/ui/assistant-ui/text.tsx index 7c05270eb..25695ffa3 100644 --- a/packages/react-playground/src/components/ui/assistant-ui/text.tsx +++ b/packages/react-playground/src/components/ui/assistant-ui/text.tsx @@ -1,17 +1,14 @@ -import { - TextContentPartComponent, - useMessageContext, -} from "@assistant-ui/react"; +import { TextContentPartComponent, useMessageStore } from "@assistant-ui/react"; import TextareaAutosize from "react-textarea-autosize"; import { useGetPlaygroundRuntime } from "../../../lib/usePlaygroundRuntime"; export const Text: TextContentPartComponent = ({ part }) => { const getPlaygroundRuntime = useGetPlaygroundRuntime(); - const { useMessage } = useMessageContext(); + const messageStore = useMessageStore(); const handleChange = (e: React.ChangeEvent) => { try { getPlaygroundRuntime().setMessageText({ - messageId: useMessage.getState().message.id, + messageId: messageStore.getState().message.id, contentPart: part, text: e.target.value, }); diff --git a/packages/react-playground/src/components/ui/assistant-ui/thread.tsx b/packages/react-playground/src/components/ui/assistant-ui/thread.tsx index 8e3528ac1..0cbe7fc1c 100644 --- a/packages/react-playground/src/components/ui/assistant-ui/thread.tsx +++ b/packages/react-playground/src/components/ui/assistant-ui/thread.tsx @@ -5,8 +5,12 @@ import { ComposerPrimitive, MessagePrimitive, ThreadPrimitive, - useMessageContext, - useThreadContext, + useComposer, + useComposerStore, + useMessage, + useMessageStore, + useThread, + useThreadActionsStore, } from "@assistant-ui/react"; import { useState, type FC, type KeyboardEvent, type MouseEvent } from "react"; import { @@ -101,13 +105,15 @@ const RoleSelect: FC = ({ role, setRole, children }) => { const Composer: FC = () => { const [role, setRole] = useState<"user" | "assistant" | "system">("user"); - const { useThread, useThreadActions, useComposer } = useThreadContext(); const isRunning = useThread((t) => t.isRunning); const hasText = useComposer((c) => c.text.length > 0); + const threadActionsStore = useThreadActionsStore(); + const composerStore = useComposerStore(); + const performAdd = () => { - const composer = useComposer.getState(); + const composer = composerStore.getState(); composer.send(); setRole("user"); @@ -115,7 +121,7 @@ const Composer: FC = () => { const performSubmit = () => { performAdd(); - useThreadActions.getState().startRun(null); + threadActionsStore.getState().startRun(null); }; const handleAdd = (e: MouseEvent) => { @@ -180,8 +186,8 @@ const Composer: FC = () => { }; const AddToolCallButton = () => { + const messageStore = useMessageStore(); const runtime = usePlaygroundRuntime(); - const { useMessage } = useMessageContext(); const toolNames = runtime.useModelConfig((c) => Object.keys(c.tools ?? {})); return ( @@ -201,7 +207,7 @@ const AddToolCallButton = () => { className="gap-2" onClick={() => { runtime.addTool({ - messageId: useMessage.getState().message.id, + messageId: messageStore.getState().message.id, toolName, }); }} @@ -217,14 +223,14 @@ const AddToolCallButton = () => { const AddImageButton = () => { const getPlaygroundRuntime = useGetPlaygroundRuntime(); - const { useMessage } = useMessageContext(); + const messageStore = useMessageStore(); const [isOpen, setIsOpen] = useState(false); const [imageUrl, setImageUrl] = useState(""); const handleAddImage = () => { getPlaygroundRuntime().addImage({ image: new URL(imageUrl).href, - messageId: useMessage.getState().message.id, + messageId: messageStore.getState().message.id, }); setIsOpen(false); }; @@ -259,19 +265,19 @@ const AddImageButton = () => { const Message: FC = () => { const getPlaygroundRuntime = useGetPlaygroundRuntime(); - const { useMessage } = useMessageContext(); + const messageStore = useMessageStore(); const role = useMessage((m) => m.message.role); const status = useMessage((m) => m.message.role === "assistant" ? m.message.status : null, ); const handleDelete = () => { - getPlaygroundRuntime().deleteMessage(useMessage.getState().message.id); + getPlaygroundRuntime().deleteMessage(messageStore.getState().message.id); }; const setRole = (role: "system" | "assistant" | "user") => { getPlaygroundRuntime().setRole({ - messageId: useMessage.getState().message.id, + messageId: messageStore.getState().message.id, role, }); }; diff --git a/packages/react-playground/src/components/ui/assistant-ui/tool-ui.tsx b/packages/react-playground/src/components/ui/assistant-ui/tool-ui.tsx index ca481a69c..7d407de03 100644 --- a/packages/react-playground/src/components/ui/assistant-ui/tool-ui.tsx +++ b/packages/react-playground/src/components/ui/assistant-ui/tool-ui.tsx @@ -1,7 +1,7 @@ import { ToolCallContentPart, ToolCallContentPartComponent, - useContentPartContext, + useContentPart, } from "@assistant-ui/react"; import { FC, useEffect, useState } from "react"; import { CornerDownRightIcon } from "lucide-react"; @@ -30,7 +30,6 @@ export const ToolUI: ToolCallContentPartComponent = ({ part }) => { }; const useContentPartTool = () => { - const { useContentPart } = useContentPartContext(); const part = useContentPart((c) => c.part); return part as ToolCallContentPart; }; diff --git a/packages/react-playground/src/lib/usePlaygroundRuntime.tsx b/packages/react-playground/src/lib/usePlaygroundRuntime.tsx index e8b8a2c42..b4292171f 100644 --- a/packages/react-playground/src/lib/usePlaygroundRuntime.tsx +++ b/packages/react-playground/src/lib/usePlaygroundRuntime.tsx @@ -1,13 +1,11 @@ "use client"; -import { useThreadContext } from "@assistant-ui/react"; +import { useThreadRuntime, useThreadRuntimeStore } from "@assistant-ui/react"; import { PlaygroundThreadRuntime } from "./playground-runtime"; export const useGetPlaygroundRuntime = () => { - const { useThreadRuntime } = useThreadContext(); - return useThreadRuntime.getState as () => PlaygroundThreadRuntime; + return useThreadRuntimeStore().getState as () => PlaygroundThreadRuntime; }; export const usePlaygroundRuntime = () => { - const { useThreadRuntime } = useThreadContext(); return useThreadRuntime() as PlaygroundThreadRuntime; }; diff --git a/packages/react/src/context/ReadonlyStore.ts b/packages/react/src/context/ReadonlyStore.ts index 2d763f202..1a824f6b2 100644 --- a/packages/react/src/context/ReadonlyStore.ts +++ b/packages/react/src/context/ReadonlyStore.ts @@ -1,9 +1,7 @@ -import type { StoreApi, UseBoundStore } from "zustand"; +import type { StoreApi } from "zustand"; -export type ReadonlyStore = UseBoundStore< - Omit, "setState" | "destroy"> ->; +export type ReadonlyStore = Omit, "setState" | "destroy">; export const writableStore = (store: ReadonlyStore | undefined) => { - return store as unknown as UseBoundStore>; + return store as unknown as StoreApi; }; diff --git a/packages/react/src/context/providers/AssistantProvider.tsx b/packages/react/src/context/providers/AssistantProvider.tsx index ca0a7baa3..b1bb4c636 100644 --- a/packages/react/src/context/providers/AssistantProvider.tsx +++ b/packages/react/src/context/providers/AssistantProvider.tsx @@ -37,10 +37,10 @@ export const AssistantProvider: FC< }; }); - const getModelConfig = context.useModelConfig(); + const modelConfigProvider = context.useModelConfig(); useEffect(() => { - return runtime.registerModelConfigProvider(getModelConfig); - }, [runtime, getModelConfig]); + return runtime.registerModelConfigProvider(modelConfigProvider); + }, [runtime, modelConfigProvider]); useEffect( () => writableStore(context.useAssistantRuntime).setState(runtime, true), diff --git a/packages/react/src/context/providers/ComposerAttachmentProvider.tsx b/packages/react/src/context/providers/ComposerAttachmentProvider.tsx index 7bd476bfc..edde1b553 100644 --- a/packages/react/src/context/providers/ComposerAttachmentProvider.tsx +++ b/packages/react/src/context/providers/ComposerAttachmentProvider.tsx @@ -2,21 +2,21 @@ import { type FC, type PropsWithChildren, useEffect, useState } from "react"; import { create } from "zustand"; -import type { ComposerState } from "../stores"; -import { useThreadContext } from "../react"; +import type { ThreadComposerState } from "../stores"; import { ComposerAttachmentState } from "../stores/Attachment"; import { AttachmentContext, AttachmentContextValue, } from "../react/AttachmentContext"; import { writableStore } from "../ReadonlyStore"; +import { useThreadComposerStore } from "../react/ThreadContext"; type ComposerAttachmentProviderProps = PropsWithChildren<{ attachmentIndex: number; }>; const getAttachment = ( - { attachments }: ComposerState, + { attachments }: ThreadComposerState, useAttachment: AttachmentContextValue["useAttachment"] | undefined, partIndex: number, ) => { @@ -31,11 +31,12 @@ const getAttachment = ( }; const useComposerAttachmentContext = (partIndex: number) => { - const { useComposer } = useThreadContext(); + const threadComposerStore = useThreadComposerStore(); const [context] = useState( () => { const useAttachment = create( - () => getAttachment(useComposer.getState(), undefined, partIndex)!, + () => + getAttachment(threadComposerStore.getState(), undefined, partIndex)!, ); return { type: "composer", useAttachment }; @@ -43,7 +44,7 @@ const useComposerAttachmentContext = (partIndex: number) => { ); useEffect(() => { - const syncAttachment = (composer: ComposerState) => { + const syncAttachment = (composer: ThreadComposerState) => { const newState = getAttachment( composer, context.useAttachment, @@ -53,9 +54,9 @@ const useComposerAttachmentContext = (partIndex: number) => { writableStore(context.useAttachment).setState(newState, true); }; - syncAttachment(useComposer.getState()); - return useComposer.subscribe(syncAttachment); - }, [context, useComposer, partIndex]); + syncAttachment(threadComposerStore.getState()); + return threadComposerStore.subscribe(syncAttachment); + }, [context, threadComposerStore, partIndex]); return context; }; diff --git a/packages/react/src/context/providers/ContentPartProvider.tsx b/packages/react/src/context/providers/ContentPartProvider.tsx index 4fe1817bc..22a889078 100644 --- a/packages/react/src/context/providers/ContentPartProvider.tsx +++ b/packages/react/src/context/providers/ContentPartProvider.tsx @@ -4,7 +4,7 @@ import { type FC, type PropsWithChildren, useEffect, useState } from "react"; import { create } from "zustand"; import { ContentPartContext } from "../react/ContentPartContext"; import type { ContentPartContextValue } from "../react/ContentPartContext"; -import { useMessageContext } from "../react/MessageContext"; +import { useMessageStore } from "../react/MessageContext"; import type { MessageState } from "../stores"; import type { ContentPartState } from "../stores/ContentPart"; import { @@ -90,10 +90,10 @@ const getContentPartState = ( }; const useContentPartContext = (partIndex: number) => { - const { useMessage } = useMessageContext(); + const messageStore = useMessageStore(); const [context] = useState(() => { const useContentPart = create( - () => getContentPartState(useMessage.getState(), undefined, partIndex)!, + () => getContentPartState(messageStore.getState(), undefined, partIndex)!, ); return { useContentPart }; @@ -110,9 +110,9 @@ const useContentPartContext = (partIndex: number) => { writableStore(context.useContentPart).setState(newState, true); }; - syncContentPart(useMessage.getState()); - return useMessage.subscribe(syncContentPart); - }, [context, useMessage, partIndex]); + syncContentPart(messageStore.getState()); + return messageStore.subscribe(syncContentPart); + }, [context, messageStore, partIndex]); return context; }; diff --git a/packages/react/src/context/providers/MessageAttachmentProvider.tsx b/packages/react/src/context/providers/MessageAttachmentProvider.tsx index 929b2e56e..9defd4cb2 100644 --- a/packages/react/src/context/providers/MessageAttachmentProvider.tsx +++ b/packages/react/src/context/providers/MessageAttachmentProvider.tsx @@ -3,7 +3,7 @@ import { type FC, type PropsWithChildren, useEffect, useState } from "react"; import { create } from "zustand"; import type { MessageState } from "../stores"; -import { useMessageContext } from "../react"; +import { useMessageStore } from "../react"; import { MessageAttachmentState } from "../stores/Attachment"; import { AttachmentContext, @@ -34,11 +34,11 @@ const getAttachment = ( }; const useMessageAttachmentContext = (partIndex: number) => { - const { useMessage } = useMessageContext(); + const messageStore = useMessageStore(); const [context] = useState( () => { const useAttachment = create( - () => getAttachment(useMessage.getState(), undefined, partIndex)!, + () => getAttachment(messageStore.getState(), undefined, partIndex)!, ); return { type: "message", useAttachment }; @@ -56,9 +56,9 @@ const useMessageAttachmentContext = (partIndex: number) => { writableStore(context.useAttachment).setState(newState, true); }; - syncAttachment(useMessage.getState()); - return useMessage.subscribe(syncAttachment); - }, [context, useMessage, partIndex]); + syncAttachment(messageStore.getState()); + return messageStore.subscribe(syncAttachment); + }, [context, messageStore, partIndex]); return context; }; diff --git a/packages/react/src/context/providers/MessageProvider.tsx b/packages/react/src/context/providers/MessageProvider.tsx index b969c9b10..f33572cb2 100644 --- a/packages/react/src/context/providers/MessageProvider.tsx +++ b/packages/react/src/context/providers/MessageProvider.tsx @@ -9,7 +9,10 @@ import type { import { getThreadMessageText } from "../../utils/getThreadMessageText"; import { MessageContext } from "../react/MessageContext"; import type { MessageContextValue } from "../react/MessageContext"; -import { useThreadContext } from "../react/ThreadContext"; +import { + useThreadActionsStore, + useThreadMessagesStore, +} from "../react/ThreadContext"; import type { MessageState } from "../stores/Message"; import { makeEditComposerStore } from "../stores/EditComposer"; import { makeMessageUtilsStore } from "../stores/MessageUtils"; @@ -57,14 +60,14 @@ const getMessageState = ( }; const useMessageContext = (messageIndex: number) => { - const { useThreadMessages, useThreadActions } = useThreadContext(); - + const threadMessagesStore = useThreadMessagesStore(); + const threadActionsStore = useThreadActionsStore(); const [context] = useState(() => { const useMessage = create( () => getMessageState( - useThreadMessages.getState(), - useThreadActions.getState().getBranches, + threadMessagesStore.getState(), + threadActionsStore.getState().getBranches, undefined, messageIndex, )!, @@ -88,7 +91,7 @@ const useMessageContext = (messageIndex: number) => { ); // TODO fix types here - useThreadActions.getState().append({ + threadActionsStore.getState().append({ parentId, role: message.role, content: [{ type: "text", text }, ...nonTextParts] as any, @@ -104,7 +107,7 @@ const useMessageContext = (messageIndex: number) => { const syncMessage = (thread: ThreadMessagesState) => { const newState = getMessageState( thread, - useThreadActions.getState().getBranches, + threadActionsStore.getState().getBranches, context.useMessage, messageIndex, ); @@ -112,10 +115,10 @@ const useMessageContext = (messageIndex: number) => { writableStore(context.useMessage).setState(newState, true); }; - syncMessage(useThreadMessages.getState()); + syncMessage(threadMessagesStore.getState()); - return useThreadMessages.subscribe(syncMessage); - }, [useThreadMessages, useThreadActions, context, messageIndex]); + return threadMessagesStore.subscribe(syncMessage); + }, [threadMessagesStore, threadActionsStore, context, messageIndex]); return context; }; diff --git a/packages/react/src/context/providers/ThreadProvider.tsx b/packages/react/src/context/providers/ThreadProvider.tsx index 2a5d54eed..076847afa 100644 --- a/packages/react/src/context/providers/ThreadProvider.tsx +++ b/packages/react/src/context/providers/ThreadProvider.tsx @@ -3,7 +3,7 @@ import { useEffect, useInsertionEffect, useState } from "react"; import type { ReactThreadRuntime } from "../../runtimes/core/ReactThreadRuntime"; import type { ThreadContextValue } from "../react/ThreadContext"; import { ThreadContext } from "../react/ThreadContext"; -import { makeComposerStore } from "../stores/Composer"; +import { makeThreadComposerStore } from "../stores/ThreadComposer"; import { getThreadStateFromRuntime, makeThreadStore } from "../stores/Thread"; import { makeThreadViewportStore } from "../stores/ThreadViewport"; import { makeThreadActionStore } from "../stores/ThreadActions"; @@ -27,7 +27,7 @@ export const ThreadProvider: FC> = ({ const useThreadMessages = makeThreadMessagesStore(useThreadRuntime); const useThreadActions = makeThreadActionStore(useThreadRuntime); const useViewport = makeThreadViewportStore(); - const useComposer = makeComposerStore(useThreadRuntime); + const useComposer = makeThreadComposerStore(useThreadRuntime); return { useThread, diff --git a/packages/react/src/context/react/AssistantContext.ts b/packages/react/src/context/react/AssistantContext.ts index 72d816226..fc66a7124 100644 --- a/packages/react/src/context/react/AssistantContext.ts +++ b/packages/react/src/context/react/AssistantContext.ts @@ -1,32 +1,43 @@ "use client"; -import { createContext, useContext } from "react"; +import { createContext } from "react"; import type { AssistantModelConfigState } from "../stores/AssistantModelConfig"; import type { AssistantToolUIsState } from "../stores/AssistantToolUIs"; import { ReadonlyStore } from "../ReadonlyStore"; import { AssistantActionsState } from "../stores/AssistantActions"; import { AssistantRuntime } from "../../runtimes"; +import { createContextHook } from "./utils/createContextHook"; +import { createContextStoreHook } from "./utils/createContextStoreHook"; +import { UseBoundStore } from "zustand"; export type AssistantContextValue = { - useModelConfig: ReadonlyStore; - useToolUIs: ReadonlyStore; - useAssistantRuntime: ReadonlyStore; - useAssistantActions: ReadonlyStore; + useModelConfig: UseBoundStore>; + useToolUIs: UseBoundStore>; + useAssistantRuntime: UseBoundStore>; + useAssistantActions: UseBoundStore>; }; export const AssistantContext = createContext( null, ); -export function useAssistantContext(): AssistantContextValue; -export function useAssistantContext(options: { - optional: true; -}): AssistantContextValue | null; -export function useAssistantContext(options?: { optional: true }) { - const context = useContext(AssistantContext); - if (!options?.optional && !context) - throw new Error( - "This component must be used within an AssistantRuntimeProvider.", - ); - return context; -} +export const useAssistantContext = createContextHook( + AssistantContext, + "AssistantRuntimeProvider", +); + +export const { useAssistantRuntime, useAssistantRuntimeStore } = + createContextStoreHook(useAssistantContext, "useAssistantRuntime"); + +export const { useModelConfig, useModelConfigStore } = createContextStoreHook( + useAssistantContext, + "useModelConfig", +); + +export const { useToolUIs, useToolUIsStore } = createContextStoreHook( + useAssistantContext, + "useToolUIs", +); + +export const { useAssistantActions, useAssistantActionsStore } = + createContextStoreHook(useAssistantContext, "useAssistantActions"); diff --git a/packages/react/src/context/react/AttachmentContext.ts b/packages/react/src/context/react/AttachmentContext.ts index 93c68a8ad..a56f64360 100644 --- a/packages/react/src/context/react/AttachmentContext.ts +++ b/packages/react/src/context/react/AttachmentContext.ts @@ -6,6 +6,7 @@ import { ComposerAttachmentState, MessageAttachmentState, } from "../stores/Attachment"; +import { createContextStoreHook } from "./utils/createContextStoreHook"; export type AttachmentContextValue = { type: "composer" | "message"; @@ -28,33 +29,63 @@ export const AttachmentContext = createContext( null, ); -export function useAttachmentContext(): AttachmentContextValue; -export function useAttachmentContext(options: { - type: "composer"; -}): ComposerAttachmentContextValue; -export function useAttachmentContext(options: { - type: "message"; -}): MessageAttachmentContextValue; -export function useAttachmentContext(options: { - optional: true; +export function useAttachmentContext(options?: { + optional?: false | undefined; +}): AttachmentContextValue; +export function useAttachmentContext(options?: { + optional?: boolean | undefined; }): AttachmentContextValue | null; export function useAttachmentContext(options?: { - type?: AttachmentContextValue["type"]; - optional?: true; + optional?: boolean | undefined; }) { const context = useContext(AttachmentContext); - if (options?.type === "composer" && context?.type !== "composer") + if (!options?.optional && !context) throw new Error( - "This component must be used within a ComposerPrimitive.Attachments component.", + "This component must be used within a ComposerPrimitive.Attachments or MessagePrimitive.Attachments component.", ); - if (options?.type === "message" && context?.type !== "message") + + return context; +} + +function useComposerAttachmentContext(): ComposerAttachmentContextValue; +function useComposerAttachmentContext(options: { + optional: true; +}): ComposerAttachmentContextValue | null; +function useComposerAttachmentContext(options?: { optional?: true }) { + const context = useAttachmentContext(options); + if (!context) return null; + if (context.type !== "composer") throw new Error( - "This component must be used within a MessagePrimitive.Attachments component.", + "This component must be used within a ComposerPrimitive.Attachments component.", ); - if (!options?.optional && !context) + return context; +} + +function useMessageAttachmentContext(): MessageAttachmentContextValue; +function useMessageAttachmentContext(options: { + optional: true; +}): MessageAttachmentContextValue | null; +function useMessageAttachmentContext(options?: { optional?: true }) { + const context = useAttachmentContext(options); + if (!context) return null; + if (context.type !== "message") throw new Error( - "This component must be used within a ComposerPrimitive.Attachments or MessagePrimitive.Attachments component.", + "This component must be used within a MessagePrimitive.Attachments component.", ); - return context; } + +export const { useAttachment, useAttachmentStore } = createContextStoreHook( + useAttachmentContext, + "useAttachment", +); + +export const { + useAttachment: useComposerAttachment, + useAttachmentStore: useComposerAttachmentStore, +} = createContextStoreHook(useComposerAttachmentContext, "useAttachment"); + +export const { + useAttachment: useMessageAttachment, + useAttachmentStore: useMessageAttachmentStore, +} = createContextStoreHook(useMessageAttachmentContext, "useAttachment"); diff --git a/packages/react/src/context/react/ComposerContext.ts b/packages/react/src/context/react/ComposerContext.ts index 070a97dea..e2757582a 100644 --- a/packages/react/src/context/react/ComposerContext.ts +++ b/packages/react/src/context/react/ComposerContext.ts @@ -1,12 +1,13 @@ import { useMemo } from "react"; import { useMessageContext } from "./MessageContext"; import { useThreadContext } from "./ThreadContext"; -import type { ComposerState } from "../stores/Composer"; +import type { ThreadComposerState } from "../stores/ThreadComposer"; import type { EditComposerState } from "../stores/EditComposer"; import { ReadonlyStore } from "../ReadonlyStore"; +import { createContextStoreHook } from "./utils/createContextStoreHook"; export type ComposerContextValue = { - useComposer: ReadonlyStore; + useComposer: ReadonlyStore; type: "edit" | "new"; }; @@ -16,10 +17,15 @@ export const useComposerContext = (): ComposerContextValue => { return useMemo( () => ({ useComposer: (useEditComposer ?? useComposer) as ReadonlyStore< - EditComposerState | ComposerState + EditComposerState | ThreadComposerState >, type: useEditComposer ? ("edit" as const) : ("new" as const), }), [useEditComposer, useComposer], ); }; + +export const { useComposer, useComposerStore } = createContextStoreHook( + useComposerContext, + "useComposer", +); diff --git a/packages/react/src/context/react/ContentPartContext.ts b/packages/react/src/context/react/ContentPartContext.ts index c6232d745..d1504d592 100644 --- a/packages/react/src/context/react/ContentPartContext.ts +++ b/packages/react/src/context/react/ContentPartContext.ts @@ -1,26 +1,26 @@ "use client"; -import { createContext, useContext } from "react"; +import { createContext } from "react"; import type { ContentPartState } from "../stores/ContentPart"; import { ReadonlyStore } from "../ReadonlyStore"; +import { createContextStoreHook } from "./utils/createContextStoreHook"; +import { createContextHook } from "./utils/createContextHook"; +import { UseBoundStore } from "zustand"; export type ContentPartContextValue = { - useContentPart: ReadonlyStore; + useContentPart: UseBoundStore>; }; export const ContentPartContext = createContext( null, ); -export function useContentPartContext(): ContentPartContextValue; -export function useContentPartContext(options: { - optional: true; -}): ContentPartContextValue | null; -export function useContentPartContext(options?: { optional: true }) { - const context = useContext(ContentPartContext); - if (!options?.optional && !context) - throw new Error( - "This component can only be used inside a component passed to .", - ); - return context; -} +export const useContentPartContext = createContextHook( + ContentPartContext, + "a component passed to ", +); + +export const { useContentPart, useContentPartStore } = createContextStoreHook( + useContentPartContext, + "useContentPart", +); diff --git a/packages/react/src/context/react/MessageContext.ts b/packages/react/src/context/react/MessageContext.ts index 52f2b5836..3f1a822bd 100644 --- a/packages/react/src/context/react/MessageContext.ts +++ b/packages/react/src/context/react/MessageContext.ts @@ -1,28 +1,39 @@ "use client"; -import { createContext, useContext } from "react"; +import { createContext } from "react"; import type { MessageState } from "../stores/Message"; import type { EditComposerState } from "../stores/EditComposer"; import { ReadonlyStore } from "../ReadonlyStore"; import { MessageUtilsState } from "../stores/MessageUtils"; +import { createContextHook } from "./utils/createContextHook"; +import { createContextStoreHook } from "./utils/createContextStoreHook"; +import { UseBoundStore } from "zustand"; export type MessageContextValue = { - useMessage: ReadonlyStore; - useMessageUtils: ReadonlyStore; - useEditComposer: ReadonlyStore; + useMessage: UseBoundStore>; + useMessageUtils: UseBoundStore>; + useEditComposer: UseBoundStore>; }; export const MessageContext = createContext(null); -export function useMessageContext(): MessageContextValue; -export function useMessageContext(options: { - optional: true; -}): MessageContextValue | null; -export function useMessageContext(options?: { optional: true }) { - const context = useContext(MessageContext); - if (!options?.optional && !context) - throw new Error( - "This component can only be used inside a component passed to .", - ); - return context; -} +export const useMessageContext = createContextHook( + MessageContext, + "a component passed to ", +); + +// TODO make this only return the message itself? +export const { useMessage, useMessageStore } = createContextStoreHook( + useMessageContext, + "useMessage", +); + +export const { useMessageUtils, useMessageUtilsStore } = createContextStoreHook( + useMessageContext, + "useMessageUtils", +); + +export const { useEditComposer, useEditComposerStore } = createContextStoreHook( + useMessageContext, + "useEditComposer", +); diff --git a/packages/react/src/context/react/ThreadContext.ts b/packages/react/src/context/react/ThreadContext.ts index b5e787456..b444f971d 100644 --- a/packages/react/src/context/react/ThreadContext.ts +++ b/packages/react/src/context/react/ThreadContext.ts @@ -1,34 +1,53 @@ "use client"; -import { createContext, useContext } from "react"; -import type { ComposerState } from "../stores/Composer"; +import { createContext } from "react"; +import type { ThreadComposerState } from "../stores/ThreadComposer"; import type { ThreadState } from "../stores/Thread"; import type { ThreadViewportState } from "../stores/ThreadViewport"; import { ThreadActionsState } from "../stores/ThreadActions"; import { ReadonlyStore } from "../ReadonlyStore"; import { ThreadMessagesState } from "../stores/ThreadMessages"; import { ThreadRuntimeStore } from "../stores/ThreadRuntime"; +import { UseBoundStore } from "zustand"; +import { createContextHook } from "./utils/createContextHook"; +import { createContextStoreHook } from "./utils/createContextStoreHook"; export type ThreadContextValue = { - useThread: ReadonlyStore; - useThreadRuntime: ReadonlyStore; - useThreadMessages: ReadonlyStore; - useThreadActions: ReadonlyStore; - useComposer: ReadonlyStore; - useViewport: ReadonlyStore; + useThread: UseBoundStore>; + useThreadRuntime: UseBoundStore>; + useThreadMessages: UseBoundStore>; + useThreadActions: UseBoundStore>; + useComposer: UseBoundStore>; + useViewport: UseBoundStore>; }; export const ThreadContext = createContext(null); -export function useThreadContext(): ThreadContextValue; -export function useThreadContext(options: { - optional: true; -}): ThreadContextValue | null; -export function useThreadContext(options?: { optional: true }) { - const context = useContext(ThreadContext); - if (!options?.optional && !context) - throw new Error( - "This component must be used within an AssistantRuntimeProvider.", - ); - return context; -} +export const useThreadContext = createContextHook( + ThreadContext, + "AssistantRuntimeProvider", +); + +export const { useThreadRuntime, useThreadRuntimeStore } = + createContextStoreHook(useThreadContext, "useThreadRuntime"); + +export const { useThread, useThreadStore } = createContextStoreHook( + useThreadContext, + "useThread", +); + +export const { useThreadMessages, useThreadMessagesStore } = + createContextStoreHook(useThreadContext, "useThreadMessages"); + +export const { useThreadActions, useThreadActionsStore } = + createContextStoreHook(useThreadContext, "useThreadActions"); + +export const { + useComposer: useThreadComposer, + useComposerStore: useThreadComposerStore, +} = createContextStoreHook(useThreadContext, "useComposer"); + +export const { + useViewport: useThreadViewport, + useViewportStore: useThreadViewportStore, +} = createContextStoreHook(useThreadContext, "useViewport"); diff --git a/packages/react/src/context/react/index.ts b/packages/react/src/context/react/index.ts index e4ec61a13..3bc5155b7 100644 --- a/packages/react/src/context/react/index.ts +++ b/packages/react/src/context/react/index.ts @@ -1,14 +1,85 @@ export { + useAssistantActions, + useAssistantActionsStore, + useAssistantRuntime, + useAssistantRuntimeStore, + useModelConfig, + useModelConfigStore, + useToolUIs, + useToolUIsStore, + + /** + * @deprecated You can import the hooks directly, e.g. `import { useAssistantRuntime } from "@assistant-ui/react"`. This will be removed in 0.7.0. + */ type AssistantContextValue, + /** + * @deprecated You can import the hooks directly, e.g. `import { useAssistantRuntime } from "@assistant-ui/react"`. This will be removed in 0.7.0. + */ useAssistantContext, } from "./AssistantContext"; -export { type ThreadContextValue, useThreadContext } from "./ThreadContext"; export { - type ComposerContextValue, - useComposerContext, -} from "./ComposerContext"; -export { type MessageContextValue, useMessageContext } from "./MessageContext"; + useThread, + useThreadStore, + useThreadMessages, + useThreadMessagesStore, + useThreadActions, + useThreadActionsStore, + useThreadRuntime, + useThreadRuntimeStore, + useThreadViewport, + useThreadViewportStore, + useThreadComposer, + useThreadComposerStore, + + /** + * @deprecated You can import the hooks directly, e.g. `import { useThread } from "@assistant-ui/react"`. This will be removed in 0.7.0. + */ + type ThreadContextValue, + /** + * @deprecated You can import the hooks directly, e.g. `import { useThread } from "@assistant-ui/react"`. This will be removed in 0.7.0. + */ + useThreadContext, +} from "./ThreadContext"; +export { + useMessage, + useMessageStore, + useMessageUtils, + useMessageUtilsStore, + useEditComposer, + useEditComposerStore, + + /** + * @deprecated You can import the hooks directly, e.g. `import { useMessage } from "@assistant-ui/react"`. This will be removed in 0.7.0. + */ + type MessageContextValue, + /** + * @deprecated You can import the hooks directly, e.g. `import { useMessage } from "@assistant-ui/react"`. This will be removed in 0.7.0. + */ + useMessageContext, +} from "./MessageContext"; export { + useContentPart, + useContentPartStore, + + /** + * @deprecated You can import the hooks directly, e.g. `import { useContentPart } from "@assistant-ui/react"`. This will be removed in 0.7.0. + */ type ContentPartContextValue, + /** + * @deprecated You can import the hooks directly, e.g. `import { useContentPart } from "@assistant-ui/react"`. This will be removed in 0.7.0. + */ useContentPartContext, } from "./ContentPartContext"; +export { + useComposer, + useComposerStore, + + /** + * @deprecated You can import the hooks directly, e.g. `import { useComposer } from "@assistant-ui/react"`. This will be removed in 0.7.0. + */ + type ComposerContextValue, + /** + * @deprecated You can import the hooks directly, e.g. `import { useComposer } from "@assistant-ui/react"`. This will be removed in 0.7.0. + */ + useComposerContext, +} from "./ComposerContext"; diff --git a/packages/react/src/context/react/utils/createContextHook.ts b/packages/react/src/context/react/utils/createContextHook.ts new file mode 100644 index 000000000..40ca2fc93 --- /dev/null +++ b/packages/react/src/context/react/utils/createContextHook.ts @@ -0,0 +1,26 @@ +import { useContext, Context } from "react"; + +/** + * Creates a context hook with optional support. + * @param context - The React context to consume. + * @param providerName - The name of the provider for error messages. + * @returns A hook function that provides the context value. + */ +export function createContextHook( + context: Context, + providerName: string, +) { + function useContextHook(options?: { optional?: false | undefined }): T; + function useContextHook(options?: { + optional?: boolean | undefined; + }): T | null; + function useContextHook(options?: { optional?: boolean | undefined }) { + const contextValue = useContext(context); + if (!options?.optional && !contextValue) { + throw new Error(`This component must be used within ${providerName}.`); + } + return contextValue; + } + + return useContextHook; +} diff --git a/packages/react/src/context/react/utils/createContextStoreHook.ts b/packages/react/src/context/react/utils/createContextStoreHook.ts new file mode 100644 index 000000000..384d22921 --- /dev/null +++ b/packages/react/src/context/react/utils/createContextStoreHook.ts @@ -0,0 +1,79 @@ +import { UseBoundStore } from "zustand"; +import { ReadonlyStore } from "../../ReadonlyStore"; + +/** + * Creates hooks for accessing a store within a context. + * @param contextHook - The hook to access the context. + * @param contextKey - The key of the store in the context. + * @returns An object containing the hooks: `use...` and `use...Store`. + */ +export function createContextStoreHook( + contextHook: (options?: { optional?: boolean }) => T | null, + contextKey: K, +) { + type StoreType = T[K]; + type StateType = StoreType extends ReadonlyStore ? S : never; + + // Define useStoreStoreHook with overloads + function useStoreStoreHook(): ReadonlyStore; + function useStoreStoreHook(options: { + optional: true; + }): ReadonlyStore | null; + function useStoreStoreHook(options?: { + optional?: boolean; + }): ReadonlyStore | null { + const context = contextHook(options); + if (!context) { + if (!options?.optional) { + throw new Error(`This component must be used within a ${contextKey}.`); + } + return null; + } + return context[contextKey] as ReadonlyStore; + } + + // Define useStoreHook with overloads + function useStoreHook(): StateType; + function useStoreHook( + selector: (state: StateType) => TSelected, + ): TSelected; + function useStoreHook(options: { optional: true }): StateType | null; + function useStoreHook(options: { + optional: true; + selector?: (state: StateType) => TSelected; + }): TSelected | null; + function useStoreHook( + param?: + | ((state: StateType) => TSelected) + | { + optional?: boolean; + selector?: (state: StateType) => TSelected; + }, + ): TSelected | StateType | null { + let optional = false; + let selector: ((state: StateType) => TSelected) | undefined; + + if (typeof param === "function") { + selector = param; + } else if (param && typeof param === "object") { + optional = !!param.optional; + selector = param.selector; + } + + const store = useStoreStoreHook({ + optional, + } as any) as UseBoundStore>; + if (!store) return null; + return selector ? store(selector) : store(); + } + + // Return an object with keys based on contextKey + return { + [contextKey]: useStoreHook, + [`${contextKey}Store`]: useStoreStoreHook, + } as { + [P in K]: typeof useStoreHook; + } & { + [P in `${K}Store`]: typeof useStoreStoreHook; + }; +} diff --git a/packages/react/src/context/stores/EditComposer.ts b/packages/react/src/context/stores/EditComposer.ts index ae3a88338..6183fc070 100644 --- a/packages/react/src/context/stores/EditComposer.ts +++ b/packages/react/src/context/stores/EditComposer.ts @@ -1,7 +1,9 @@ -import { create } from "zustand"; +import { create, UseBoundStore } from "zustand"; import { ReadonlyStore } from "../ReadonlyStore"; export type EditComposerState = Readonly<{ + type: "edit"; + // TODO /** @deprecated Use `text` instead. */ value: string; @@ -26,8 +28,10 @@ export const makeEditComposerStore = ({ }: { onEdit: () => string; onSend: (text: string) => void; -}): ReadonlyStore => +}): UseBoundStore> => create()((set, get) => ({ + type: "edit", + get value() { return get().text; }, diff --git a/packages/react/src/context/stores/Composer.ts b/packages/react/src/context/stores/ThreadComposer.ts similarity index 88% rename from packages/react/src/context/stores/Composer.ts rename to packages/react/src/context/stores/ThreadComposer.ts index fd2d3bc17..33269c8b2 100644 --- a/packages/react/src/context/stores/Composer.ts +++ b/packages/react/src/context/stores/ThreadComposer.ts @@ -1,10 +1,12 @@ -import { create } from "zustand"; +import { create, UseBoundStore } from "zustand"; import { ReadonlyStore } from "../ReadonlyStore"; import { Unsubscribe } from "../../types/Unsubscribe"; import { ThreadContextValue } from "../react"; import { ThreadComposerAttachment } from "./Attachment"; -export type ComposerState = Readonly<{ +export type ThreadComposerState = Readonly<{ + type: "thread"; + /** @deprecated Use `text` instead. */ value: string; /** @deprecated Use `setText` instead. */ @@ -29,13 +31,15 @@ export type ComposerState = Readonly<{ onFocus: (listener: () => void) => Unsubscribe; }>; -export const makeComposerStore = ( +export const makeThreadComposerStore = ( useThreadRuntime: ThreadContextValue["useThreadRuntime"], -): ReadonlyStore => { +): UseBoundStore> => { const focusListeners = new Set<() => void>(); - return create()((_, get) => { + return create()((_, get) => { const runtime = useThreadRuntime.getState(); return { + type: "thread", + get value() { return get().text; }, diff --git a/packages/react/src/context/stores/index.ts b/packages/react/src/context/stores/index.ts index 25aed6e9d..877a82fb5 100644 --- a/packages/react/src/context/stores/index.ts +++ b/packages/react/src/context/stores/index.ts @@ -1,7 +1,7 @@ export type { AssistantActionsState } from "./AssistantActions"; export type { AssistantModelConfigState } from "./AssistantModelConfig"; export type { AssistantToolUIsState } from "./AssistantToolUIs"; -export type { ComposerState } from "./Composer"; +export type { ThreadComposerState } from "./ThreadComposer"; export type { ContentPartState } from "./ContentPart"; export type { EditComposerState } from "./EditComposer"; export type { MessageState } from "./Message"; diff --git a/packages/react/src/hooks/useAppendMessage.tsx b/packages/react/src/hooks/useAppendMessage.tsx index 5edc648bd..ff724531b 100644 --- a/packages/react/src/hooks/useAppendMessage.tsx +++ b/packages/react/src/hooks/useAppendMessage.tsx @@ -1,6 +1,13 @@ import { useCallback } from "react"; -import { ThreadContextValue, useThreadContext } from "../context"; +import { + ThreadMessagesState, + useThreadActionsStore, + useThreadMessagesStore, + useThreadViewportStore, +} from "../context"; import { AppendMessage } from "../types"; +import { useThreadComposerStore } from "../context/react/ThreadContext"; +import { ReadonlyStore } from "../context/ReadonlyStore"; type CreateAppendMessage = | string @@ -12,7 +19,7 @@ type CreateAppendMessage = }; const toAppendMessage = ( - useThreadMessages: ThreadContextValue["useThreadMessages"], + useThreadMessages: ReadonlyStore, message: CreateAppendMessage, ): AppendMessage => { if (typeof message === "string") { @@ -34,18 +41,25 @@ const toAppendMessage = ( }; export const useAppendMessage = () => { - const { useThreadMessages, useThreadActions, useViewport, useComposer } = - useThreadContext(); + const threadMessagesStore = useThreadMessagesStore(); + const threadActionsStore = useThreadActionsStore(); + const threadViewportStore = useThreadViewportStore(); + const threadComposerStore = useThreadComposerStore(); const append = useCallback( (message: CreateAppendMessage) => { - const appendMessage = toAppendMessage(useThreadMessages, message); - useThreadActions.getState().append(appendMessage); + const appendMessage = toAppendMessage(threadMessagesStore, message); + threadActionsStore.getState().append(appendMessage); - useViewport.getState().scrollToBottom(); - useComposer.getState().focus(); + threadViewportStore.getState().scrollToBottom(); + threadComposerStore.getState().focus(); }, - [useThreadMessages, useThreadActions, useViewport, useComposer], + [ + threadMessagesStore, + threadActionsStore, + threadViewportStore, + threadComposerStore, + ], ); return append; diff --git a/packages/react/src/hooks/useSwitchToNewThread.tsx b/packages/react/src/hooks/useSwitchToNewThread.tsx index 67f522202..cb9381be1 100644 --- a/packages/react/src/hooks/useSwitchToNewThread.tsx +++ b/packages/react/src/hooks/useSwitchToNewThread.tsx @@ -1,14 +1,14 @@ import { useCallback } from "react"; -import { useAssistantContext, useThreadContext } from "../context"; +import { useAssistantActionsStore } from "../context"; +import { useThreadComposerStore } from "../context/react/ThreadContext"; export const useSwitchToNewThread = () => { - const { useAssistantActions } = useAssistantContext(); - const { useComposer } = useThreadContext(); - + const assistantActionsStore = useAssistantActionsStore(); + const threadComposerStore = useThreadComposerStore(); const switchToNewThread = useCallback(() => { - useAssistantActions.getState().switchToThread(null); - useComposer.getState().focus(); - }, [useAssistantActions, useComposer]); + assistantActionsStore.getState().switchToThread(null); + threadComposerStore.getState().focus(); + }, [assistantActionsStore, threadComposerStore]); return switchToNewThread; }; diff --git a/packages/react/src/model-config/useAssistantInstructions.tsx b/packages/react/src/model-config/useAssistantInstructions.tsx index 171a61186..f5d739b90 100644 --- a/packages/react/src/model-config/useAssistantInstructions.tsx +++ b/packages/react/src/model-config/useAssistantInstructions.tsx @@ -1,17 +1,16 @@ "use client"; import { useEffect } from "react"; -import { useAssistantContext } from "../context/react/AssistantContext"; +import { useModelConfigStore } from "../context"; export const useAssistantInstructions = (instruction: string) => { - const { useModelConfig } = useAssistantContext(); - const registerModelConfigProvider = useModelConfig( - (s) => s.registerModelConfigProvider, - ); + const modelConfigStore = useModelConfigStore(); useEffect(() => { const config = { system: instruction, }; - return registerModelConfigProvider({ getModelConfig: () => config }); - }, [registerModelConfigProvider, instruction]); + return modelConfigStore + .getState() + .registerModelConfigProvider({ getModelConfig: () => config }); + }, [modelConfigStore, instruction]); }; diff --git a/packages/react/src/model-config/useAssistantTool.tsx b/packages/react/src/model-config/useAssistantTool.tsx index e45840b9e..ba38e8ddc 100644 --- a/packages/react/src/model-config/useAssistantTool.tsx +++ b/packages/react/src/model-config/useAssistantTool.tsx @@ -1,7 +1,10 @@ "use client"; import { useEffect } from "react"; -import { useAssistantContext } from "../context/react/AssistantContext"; +import { + useModelConfigStore, + useToolUIsStore, +} from "../context/react/AssistantContext"; import type { ToolCallContentPartComponent } from "../types/ContentPartComponentTypes"; import type { Tool } from "../types/ModelConfigTypes"; @@ -19,11 +22,8 @@ export const useAssistantTool = < >( tool: AssistantToolProps, ) => { - const { useModelConfig, useToolUIs } = useAssistantContext(); - const registerModelConfigProvider = useModelConfig( - (s) => s.registerModelConfigProvider, - ); - const setToolUI = useToolUIs((s) => s.setToolUI); + const modelConfigStore = useModelConfigStore(); + const toolUIsStore = useToolUIsStore(); useEffect(() => { const { toolName, render, ...rest } = tool; const config = { @@ -31,13 +31,15 @@ export const useAssistantTool = < [tool.toolName]: rest, }, }; - const unsub1 = registerModelConfigProvider({ + const unsub1 = modelConfigStore.getState().registerModelConfigProvider({ getModelConfig: () => config, }); - const unsub2 = render ? setToolUI(toolName, render) : undefined; + const unsub2 = render + ? toolUIsStore.getState().setToolUI(toolName, render) + : undefined; return () => { unsub1(); unsub2?.(); }; - }, [registerModelConfigProvider, setToolUI, tool]); + }, [modelConfigStore, toolUIsStore, tool]); }; diff --git a/packages/react/src/model-config/useAssistantToolUI.tsx b/packages/react/src/model-config/useAssistantToolUI.tsx index ffef4e2d4..9906d6199 100644 --- a/packages/react/src/model-config/useAssistantToolUI.tsx +++ b/packages/react/src/model-config/useAssistantToolUI.tsx @@ -1,7 +1,7 @@ "use client"; import { useEffect } from "react"; -import { useAssistantContext } from "../context/react/AssistantContext"; +import { useToolUIsStore } from "../context/react/AssistantContext"; import type { ToolCallContentPartComponent } from "../types/ContentPartComponentTypes"; export type AssistantToolUIProps< @@ -15,11 +15,10 @@ export type AssistantToolUIProps< export const useAssistantToolUI = ( tool: AssistantToolUIProps | null, ) => { - const { useToolUIs } = useAssistantContext(); - const setToolUI = useToolUIs((s) => s.setToolUI); + const toolUIsStore = useToolUIsStore(); useEffect(() => { if (!tool) return; const { toolName, render } = tool; - return setToolUI(toolName, render); - }, [setToolUI, tool]); + return toolUIsStore.getState().setToolUI(toolName, render); + }, [toolUIsStore, tool]); }; diff --git a/packages/react/src/primitive-hooks/actionBar/useActionBarCopy.tsx b/packages/react/src/primitive-hooks/actionBar/useActionBarCopy.tsx index b21e74418..39c548521 100644 --- a/packages/react/src/primitive-hooks/actionBar/useActionBarCopy.tsx +++ b/packages/react/src/primitive-hooks/actionBar/useActionBarCopy.tsx @@ -1,5 +1,9 @@ import { useCallback } from "react"; -import { useMessageContext } from "../../context/react/MessageContext"; +import { + useEditComposerStore, + useMessageStore, + useMessageUtilsStore, +} from "../../context/react/MessageContext"; import { useCombinedStore } from "../../utils/combined/useCombinedStore"; import { getThreadMessageText } from "../../utils/getThreadMessageText"; @@ -10,10 +14,11 @@ export type UseActionBarCopyProps = { export const useActionBarCopy = ({ copiedDuration = 3000, }: UseActionBarCopyProps = {}) => { - const { useMessage, useMessageUtils, useEditComposer } = useMessageContext(); - + const messageStore = useMessageStore(); + const messageUtilsStore = useMessageUtilsStore(); + const editComposerStore = useEditComposerStore(); const hasCopyableContent = useCombinedStore( - [useMessage, useEditComposer], + [messageStore, editComposerStore], ({ message }, c) => { return ( !c.isEditing && @@ -24,9 +29,9 @@ export const useActionBarCopy = ({ ); const callback = useCallback(() => { - const { message } = useMessage.getState(); - const { setIsCopied } = useMessageUtils.getState(); - const { isEditing, text: composerValue } = useEditComposer.getState(); + const { message } = messageStore.getState(); + const { setIsCopied } = messageUtilsStore.getState(); + const { isEditing, text: composerValue } = editComposerStore.getState(); const valueToCopy = isEditing ? composerValue @@ -36,7 +41,7 @@ export const useActionBarCopy = ({ setIsCopied(true); setTimeout(() => setIsCopied(false), copiedDuration); }); - }, [useMessage, useMessageUtils, useEditComposer, copiedDuration]); + }, [messageStore, messageUtilsStore, editComposerStore, copiedDuration]); if (!hasCopyableContent) return null; return callback; diff --git a/packages/react/src/primitive-hooks/actionBar/useActionBarEdit.tsx b/packages/react/src/primitive-hooks/actionBar/useActionBarEdit.tsx index 734915f17..bbef50de6 100644 --- a/packages/react/src/primitive-hooks/actionBar/useActionBarEdit.tsx +++ b/packages/react/src/primitive-hooks/actionBar/useActionBarEdit.tsx @@ -1,13 +1,15 @@ import { useCallback } from "react"; -import { useMessageContext } from "../../context/react/MessageContext"; +import { + useEditComposer, + useEditComposerStore, +} from "../../context/react/MessageContext"; export const useActionBarEdit = () => { - const { useEditComposer } = useMessageContext(); - + const editComposerStore = useEditComposerStore(); const disabled = useEditComposer((c) => c.isEditing); const callback = useCallback(() => { - const { edit } = useEditComposer.getState(); + const { edit } = editComposerStore.getState(); edit(); }, [useEditComposer]); diff --git a/packages/react/src/primitive-hooks/actionBar/useActionBarReload.tsx b/packages/react/src/primitive-hooks/actionBar/useActionBarReload.tsx index f6580261b..604857eb8 100644 --- a/packages/react/src/primitive-hooks/actionBar/useActionBarReload.tsx +++ b/packages/react/src/primitive-hooks/actionBar/useActionBarReload.tsx @@ -1,24 +1,36 @@ import { useCallback } from "react"; -import { useMessageContext } from "../../context/react/MessageContext"; -import { useThreadContext } from "../../context/react/ThreadContext"; +import { useMessageStore } from "../../context/react/MessageContext"; +import { + useThreadActionsStore, + useThreadComposerStore, + useThreadStore, + useThreadViewportStore, +} from "../../context/react/ThreadContext"; import { useCombinedStore } from "../../utils/combined/useCombinedStore"; export const useActionBarReload = () => { - const { useThread, useThreadActions, useComposer, useViewport } = - useThreadContext(); - const { useMessage } = useMessageContext(); + const messageStore = useMessageStore(); + const threadStore = useThreadStore(); + const threadActionsStore = useThreadActionsStore(); + const threadComposerStore = useThreadComposerStore(); + const threadViewportStore = useThreadViewportStore(); const disabled = useCombinedStore( - [useThread, useMessage], + [threadStore, messageStore], (t, m) => t.isRunning || t.isDisabled || m.message.role !== "assistant", ); const callback = useCallback(() => { - const { parentId } = useMessage.getState(); - useThreadActions.getState().startRun(parentId); - useViewport.getState().scrollToBottom(); - useComposer.getState().focus(); - }, [useThreadActions, useComposer, useViewport, useMessage]); + const { parentId } = messageStore.getState(); + threadActionsStore.getState().startRun(parentId); + threadViewportStore.getState().scrollToBottom(); + threadComposerStore.getState().focus(); + }, [ + threadActionsStore, + threadComposerStore, + threadViewportStore, + messageStore, + ]); if (disabled) return null; return callback; diff --git a/packages/react/src/primitive-hooks/actionBar/useActionBarSpeak.tsx b/packages/react/src/primitive-hooks/actionBar/useActionBarSpeak.tsx index f07aa3f25..c414641ca 100644 --- a/packages/react/src/primitive-hooks/actionBar/useActionBarSpeak.tsx +++ b/packages/react/src/primitive-hooks/actionBar/useActionBarSpeak.tsx @@ -1,14 +1,21 @@ import { useCallback } from "react"; -import { useMessageContext } from "../../context/react/MessageContext"; -import { useThreadContext } from "../../context/react/ThreadContext"; + import { useCombinedStore } from "../../utils/combined/useCombinedStore"; +import { + useEditComposerStore, + useMessageStore, + useMessageUtilsStore, + useThreadActionsStore, +} from "../../context"; export const useActionBarSpeak = () => { - const { useThreadActions } = useThreadContext(); - const { useMessage, useEditComposer, useMessageUtils } = useMessageContext(); + const messageStore = useMessageStore(); + const editComposerStore = useEditComposerStore(); + const threadActionsStore = useThreadActionsStore(); + const messageUtilsStore = useMessageUtilsStore(); const hasSpeakableContent = useCombinedStore( - [useMessage, useEditComposer], + [messageStore, editComposerStore], ({ message }, c) => { return ( !c.isEditing && @@ -19,10 +26,10 @@ export const useActionBarSpeak = () => { ); const callback = useCallback(async () => { - const { message } = useMessage.getState(); - const utt = useThreadActions.getState().speak(message.id); - useMessageUtils.getState().addUtterance(utt); - }, [useThreadActions, useMessage, useMessageUtils]); + const { message } = messageStore.getState(); + const utt = threadActionsStore.getState().speak(message.id); + messageUtilsStore.getState().addUtterance(utt); + }, [threadActionsStore, messageStore, messageUtilsStore]); if (!hasSpeakableContent) return null; return callback; diff --git a/packages/react/src/primitive-hooks/actionBar/useActionBarStopSpeaking.tsx b/packages/react/src/primitive-hooks/actionBar/useActionBarStopSpeaking.tsx index baccd56f0..3dac96193 100644 --- a/packages/react/src/primitive-hooks/actionBar/useActionBarStopSpeaking.tsx +++ b/packages/react/src/primitive-hooks/actionBar/useActionBarStopSpeaking.tsx @@ -1,14 +1,16 @@ import { useCallback } from "react"; -import { useMessageContext } from "../../context/react/MessageContext"; +import { + useMessageUtils, + useMessageUtilsStore, +} from "../../context/react/MessageContext"; export const useActionBarStopSpeaking = () => { - const { useMessageUtils } = useMessageContext(); - + const messageUtilsStore = useMessageUtilsStore(); const isSpeaking = useMessageUtils((u) => u.isSpeaking); const callback = useCallback(async () => { - useMessageUtils.getState().stopSpeaking(); - }, [useMessageUtils]); + messageUtilsStore.getState().stopSpeaking(); + }, [messageUtilsStore]); if (!isSpeaking) return null; diff --git a/packages/react/src/primitive-hooks/branchPicker/useBranchPickerCount.tsx b/packages/react/src/primitive-hooks/branchPicker/useBranchPickerCount.tsx index 5c40c6861..9fbc57d39 100644 --- a/packages/react/src/primitive-hooks/branchPicker/useBranchPickerCount.tsx +++ b/packages/react/src/primitive-hooks/branchPicker/useBranchPickerCount.tsx @@ -1,8 +1,7 @@ "use client"; -import { useMessageContext } from "../../context/react/MessageContext"; +import { useMessage } from "../../context/react/MessageContext"; export const useBranchPickerCount = () => { - const { useMessage } = useMessageContext(); const branchCount = useMessage((s) => s.branches.length); return branchCount; }; diff --git a/packages/react/src/primitive-hooks/branchPicker/useBranchPickerNext.tsx b/packages/react/src/primitive-hooks/branchPicker/useBranchPickerNext.tsx index 5132195a4..ba67c7f0d 100644 --- a/packages/react/src/primitive-hooks/branchPicker/useBranchPickerNext.tsx +++ b/packages/react/src/primitive-hooks/branchPicker/useBranchPickerNext.tsx @@ -1,24 +1,27 @@ import { useCallback } from "react"; -import { useMessageContext } from "../../context/react/MessageContext"; -import { useThreadContext } from "../../context/react/ThreadContext"; import { useCombinedStore } from "../../utils/combined/useCombinedStore"; +import { + useEditComposerStore, + useMessageStore, + useThreadActionsStore, +} from "../../context"; export const useBranchPickerNext = () => { - const { useThreadActions } = useThreadContext(); - const { useMessage, useEditComposer } = useMessageContext(); - + const messageStore = useMessageStore(); + const editComposerStore = useEditComposerStore(); + const threadActionsStore = useThreadActionsStore(); const disabled = useCombinedStore( - [useMessage, useEditComposer], + [messageStore, editComposerStore], (m, c) => c.isEditing || m.branches.indexOf(m.message.id) + 1 >= m.branches.length, ); const callback = useCallback(() => { - const { message, branches } = useMessage.getState(); - useThreadActions + const { message, branches } = messageStore.getState(); + threadActionsStore .getState() .switchToBranch(branches[branches.indexOf(message.id) + 1]!); - }, [useThreadActions, useMessage]); + }, [threadActionsStore, messageStore]); if (disabled) return null; return callback; diff --git a/packages/react/src/primitive-hooks/branchPicker/useBranchPickerNumber.tsx b/packages/react/src/primitive-hooks/branchPicker/useBranchPickerNumber.tsx index cd1e37c77..b435b3ec0 100644 --- a/packages/react/src/primitive-hooks/branchPicker/useBranchPickerNumber.tsx +++ b/packages/react/src/primitive-hooks/branchPicker/useBranchPickerNumber.tsx @@ -1,8 +1,8 @@ "use client"; -import { useMessageContext } from "../../context/react/MessageContext"; + +import { useMessage } from "../../context/react/MessageContext"; export const useBranchPickerNumber = () => { - const { useMessage } = useMessageContext(); const branchIdx = useMessage((s) => s.branches.indexOf(s.message.id)); return branchIdx + 1; }; diff --git a/packages/react/src/primitive-hooks/branchPicker/useBranchPickerPrevious.tsx b/packages/react/src/primitive-hooks/branchPicker/useBranchPickerPrevious.tsx index a1e75e169..f07667d15 100644 --- a/packages/react/src/primitive-hooks/branchPicker/useBranchPickerPrevious.tsx +++ b/packages/react/src/primitive-hooks/branchPicker/useBranchPickerPrevious.tsx @@ -1,23 +1,27 @@ import { useCallback } from "react"; -import { useMessageContext } from "../../context/react/MessageContext"; -import { useThreadContext } from "../../context/react/ThreadContext"; +import { + useEditComposerStore, + useMessageStore, +} from "../../context/react/MessageContext"; +import { useThreadActionsStore } from "../../context/react/ThreadContext"; import { useCombinedStore } from "../../utils/combined/useCombinedStore"; export const useBranchPickerPrevious = () => { - const { useThreadActions } = useThreadContext(); - const { useMessage, useEditComposer } = useMessageContext(); + const messageStore = useMessageStore(); + const editComposerStore = useEditComposerStore(); + const threadActionsStore = useThreadActionsStore(); const disabled = useCombinedStore( - [useMessage, useEditComposer], + [messageStore, editComposerStore], (m, c) => c.isEditing || m.branches.indexOf(m.message.id) <= 0, ); const callback = useCallback(() => { - const { message, branches } = useMessage.getState(); - useThreadActions + const { message, branches } = messageStore.getState(); + threadActionsStore .getState() .switchToBranch(branches[branches.indexOf(message.id) - 1]!); - }, [useThreadActions, useMessage]); + }, [threadActionsStore, messageStore]); if (disabled) return null; return callback; diff --git a/packages/react/src/primitive-hooks/composer/useComposerAddAttachment.tsx b/packages/react/src/primitive-hooks/composer/useComposerAddAttachment.tsx index 1ae738da8..327803601 100644 --- a/packages/react/src/primitive-hooks/composer/useComposerAddAttachment.tsx +++ b/packages/react/src/primitive-hooks/composer/useComposerAddAttachment.tsx @@ -1,14 +1,15 @@ import { useCallback } from "react"; -import { useThreadContext } from "../../context"; +import { useComposer, useThreadRuntimeStore } from "../../context"; +import { useThreadComposerStore } from "../../context/react/ThreadContext"; export const useComposerAddAttachment = () => { - const { useComposer, useThreadRuntime } = useThreadContext(); - const disabled = useComposer((c) => !c.isEditing); + const threadComposerStore = useThreadComposerStore(); + const threadRuntimeStore = useThreadRuntimeStore(); const callback = useCallback(() => { - const { addAttachment } = useComposer.getState(); - const { attachmentAccept } = useThreadRuntime.getState().composer; + const { addAttachment } = threadComposerStore.getState(); + const { attachmentAccept } = threadRuntimeStore.getState().composer; const input = document.createElement("input"); input.type = "file"; @@ -23,7 +24,7 @@ export const useComposerAddAttachment = () => { }; input.click(); - }, [useComposer, useThreadRuntime]); + }, [threadComposerStore, threadRuntimeStore]); if (disabled) return null; return callback; diff --git a/packages/react/src/primitive-hooks/composer/useComposerCancel.tsx b/packages/react/src/primitive-hooks/composer/useComposerCancel.tsx index ff3da3113..8dcc731ae 100644 --- a/packages/react/src/primitive-hooks/composer/useComposerCancel.tsx +++ b/packages/react/src/primitive-hooks/composer/useComposerCancel.tsx @@ -1,15 +1,14 @@ import { useCallback } from "react"; -import { useComposerContext } from "../../context"; +import { useComposer, useComposerStore } from "../../context"; export const useComposerCancel = () => { - const { useComposer } = useComposerContext(); - + const composerStore = useComposerStore(); const disabled = useComposer((c) => !c.canCancel); const callback = useCallback(() => { - const { cancel } = useComposer.getState(); + const { cancel } = composerStore.getState(); cancel(); - }, [useComposer]); + }, [composerStore]); if (disabled) return null; return callback; diff --git a/packages/react/src/primitive-hooks/composer/useComposerIf.tsx b/packages/react/src/primitive-hooks/composer/useComposerIf.tsx index b82781247..442d500fa 100644 --- a/packages/react/src/primitive-hooks/composer/useComposerIf.tsx +++ b/packages/react/src/primitive-hooks/composer/useComposerIf.tsx @@ -1,5 +1,6 @@ "use client"; -import { useComposerContext } from "../../context/react/ComposerContext"; + +import { useComposer } from "../../context/react/ComposerContext"; import type { RequireAtLeastOne } from "../../utils/RequireAtLeastOne"; type ComposerIfFilters = { @@ -9,7 +10,6 @@ type ComposerIfFilters = { export type UseComposerIfProps = RequireAtLeastOne; export const useComposerIf = (props: UseComposerIfProps) => { - const { useComposer } = useComposerContext(); return useComposer((composer) => { if (props.editing === true && !composer.isEditing) return false; if (props.editing === false && composer.isEditing) return false; diff --git a/packages/react/src/primitive-hooks/composer/useComposerSend.tsx b/packages/react/src/primitive-hooks/composer/useComposerSend.tsx index 8d7730685..e05482ae8 100644 --- a/packages/react/src/primitive-hooks/composer/useComposerSend.tsx +++ b/packages/react/src/primitive-hooks/composer/useComposerSend.tsx @@ -1,29 +1,32 @@ import { useCallback } from "react"; -import { useComposerContext, useThreadContext } from "../../context"; +import { useComposerStore } from "../../context"; import { useCombinedStore } from "../../utils/combined/useCombinedStore"; +import { + useThreadComposerStore, + useThreadStore, + useThreadViewportStore, +} from "../../context/react/ThreadContext"; export const useComposerSend = () => { - const { - useThread, - useViewport, - useComposer: useNewComposer, - } = useThreadContext(); - const { useComposer } = useComposerContext(); + const threadStore = useThreadStore(); + const threadViewportStore = useThreadViewportStore(); + const composerStore = useComposerStore(); + const threadComposerStore = useThreadComposerStore(); const disabled = useCombinedStore( - [useThread, useComposer], + [threadStore, composerStore], (t, c) => t.isRunning || !c.isEditing || c.isEmpty, ); const callback = useCallback(() => { - const composerState = useComposer.getState(); + const composerState = composerStore.getState(); if (!composerState.isEditing) return; composerState.send(); - useViewport.getState().scrollToBottom(); - useNewComposer.getState().focus(); - }, [useNewComposer, useComposer, useViewport]); + threadViewportStore.getState().scrollToBottom(); + threadComposerStore.getState().focus(); + }, [threadComposerStore, composerStore, threadViewportStore]); if (disabled) return null; return callback; diff --git a/packages/react/src/primitive-hooks/contentPart/useContentPartDisplay.tsx b/packages/react/src/primitive-hooks/contentPart/useContentPartDisplay.tsx index 39161dd11..00f2a58d5 100644 --- a/packages/react/src/primitive-hooks/contentPart/useContentPartDisplay.tsx +++ b/packages/react/src/primitive-hooks/contentPart/useContentPartDisplay.tsx @@ -1,9 +1,7 @@ -import { useContentPartContext } from "../../context/react/ContentPartContext"; +import { useContentPart } from "../../context/react/ContentPartContext"; import { UIContentPartState } from "../../context/stores/ContentPart"; export const useContentPartDisplay = () => { - const { useContentPart } = useContentPartContext(); - const display = useContentPart((c) => { if (c.part.type !== "ui") throw new Error( diff --git a/packages/react/src/primitive-hooks/contentPart/useContentPartImage.tsx b/packages/react/src/primitive-hooks/contentPart/useContentPartImage.tsx index 24bfee471..39ad2395e 100644 --- a/packages/react/src/primitive-hooks/contentPart/useContentPartImage.tsx +++ b/packages/react/src/primitive-hooks/contentPart/useContentPartImage.tsx @@ -1,9 +1,7 @@ -import { useContentPartContext } from "../../context/react/ContentPartContext"; +import { useContentPart } from "../../context/react/ContentPartContext"; import { ImageContentPartState } from "../../context/stores/ContentPart"; export const useContentPartImage = () => { - const { useContentPart } = useContentPartContext(); - const image = useContentPart((c) => { if (c.part.type !== "image") throw new Error( diff --git a/packages/react/src/primitive-hooks/contentPart/useContentPartText.tsx b/packages/react/src/primitive-hooks/contentPart/useContentPartText.tsx index a310488cf..b06e265b4 100644 --- a/packages/react/src/primitive-hooks/contentPart/useContentPartText.tsx +++ b/packages/react/src/primitive-hooks/contentPart/useContentPartText.tsx @@ -1,9 +1,7 @@ -import { useContentPartContext } from "../../context/react/ContentPartContext"; +import { useContentPart } from "../../context/react/ContentPartContext"; import { TextContentPartState } from "../../context/stores/ContentPart"; export const useContentPartText = () => { - const { useContentPart } = useContentPartContext(); - const text = useContentPart((c) => { if (c.part.type !== "text") throw new Error( diff --git a/packages/react/src/primitive-hooks/message/useMessageIf.tsx b/packages/react/src/primitive-hooks/message/useMessageIf.tsx index 39206e257..254e50742 100644 --- a/packages/react/src/primitive-hooks/message/useMessageIf.tsx +++ b/packages/react/src/primitive-hooks/message/useMessageIf.tsx @@ -1,5 +1,8 @@ "use client"; -import { useMessageContext } from "../../context/react/MessageContext"; +import { + useMessageStore, + useMessageUtilsStore, +} from "../../context/react/MessageContext"; import type { RequireAtLeastOne } from "../../utils/RequireAtLeastOne"; import { useCombinedStore } from "../../utils/combined/useCombinedStore"; @@ -16,10 +19,11 @@ type MessageIfFilters = { export type UseMessageIfProps = RequireAtLeastOne; export const useMessageIf = (props: UseMessageIfProps) => { - const { useMessage, useMessageUtils } = useMessageContext(); + const messageStore = useMessageStore(); + const messageUtilsStore = useMessageUtilsStore(); return useCombinedStore( - [useMessage, useMessageUtils], + [messageStore, messageUtilsStore], ({ message, branches, isLast }, { isCopied, isHovering, isSpeaking }) => { if (props.hasBranches === true && branches.length < 2) return false; diff --git a/packages/react/src/primitive-hooks/thread/useThreadIf.tsx b/packages/react/src/primitive-hooks/thread/useThreadIf.tsx index 988707932..5c6b00f18 100644 --- a/packages/react/src/primitive-hooks/thread/useThreadIf.tsx +++ b/packages/react/src/primitive-hooks/thread/useThreadIf.tsx @@ -1,6 +1,6 @@ "use client"; -import { useThreadContext } from "../../context/react/ThreadContext"; +import { useThreadMessagesStore, useThreadStore } from "../../context"; import type { RequireAtLeastOne } from "../../utils/RequireAtLeastOne"; import { useCombinedStore } from "../../utils/combined/useCombinedStore"; @@ -13,9 +13,10 @@ type ThreadIfFilters = { export type UseThreadIfProps = RequireAtLeastOne; export const useThreadIf = (props: UseThreadIfProps) => { - const { useThread, useThreadMessages } = useThreadContext(); + const threadStore = useThreadStore(); + const threadMessagesStore = useThreadMessagesStore(); return useCombinedStore( - [useThread, useThreadMessages], + [threadStore, threadMessagesStore], (thread, messages) => { if (props.empty === true && messages.length !== 0) return false; if (props.empty === false && messages.length === 0) return false; diff --git a/packages/react/src/primitive-hooks/thread/useThreadScrollToBottom.tsx b/packages/react/src/primitive-hooks/thread/useThreadScrollToBottom.tsx index afbc1714e..2197ea8c4 100644 --- a/packages/react/src/primitive-hooks/thread/useThreadScrollToBottom.tsx +++ b/packages/react/src/primitive-hooks/thread/useThreadScrollToBottom.tsx @@ -1,15 +1,20 @@ import { useCallback } from "react"; -import { useThreadContext } from "../../context"; +import { useThreadViewport } from "../../context"; +import { + useThreadComposerStore, + useThreadViewportStore, +} from "../../context/react/ThreadContext"; export const useThreadScrollToBottom = () => { - const { useComposer, useViewport } = useThreadContext(); + const isAtBottom = useThreadViewport((s) => s.isAtBottom); - const isAtBottom = useViewport((s) => s.isAtBottom); + const threadViewportStore = useThreadViewportStore(); + const threadComposerStore = useThreadComposerStore(); const handleScrollToBottom = useCallback(() => { - useViewport.getState().scrollToBottom(); - useComposer.getState().focus(); - }, [useViewport, useComposer]); + threadViewportStore.getState().scrollToBottom(); + threadComposerStore.getState().focus(); + }, [threadViewportStore, threadComposerStore]); if (isAtBottom) return null; return handleScrollToBottom; diff --git a/packages/react/src/primitive-hooks/thread/useThreadSuggestion.tsx b/packages/react/src/primitive-hooks/thread/useThreadSuggestion.tsx index 46856e989..fde61526b 100644 --- a/packages/react/src/primitive-hooks/thread/useThreadSuggestion.tsx +++ b/packages/react/src/primitive-hooks/thread/useThreadSuggestion.tsx @@ -1,6 +1,7 @@ import { useCallback } from "react"; -import { useThreadContext } from "../../context"; +import { useThread, useThreadStore } from "../../context"; import { useAppendMessage } from "../../hooks"; +import { useThreadComposerStore } from "../../context/react/ThreadContext"; export type UseApplyThreadSuggestionProps = { prompt: string; @@ -12,20 +13,21 @@ export const useThreadSuggestion = ({ prompt, autoSend, }: UseApplyThreadSuggestionProps) => { - const { useThread, useComposer } = useThreadContext(); + const threadStore = useThreadStore(); + const composerStore = useThreadComposerStore(); const append = useAppendMessage(); const disabled = useThread((t) => t.isDisabled); const callback = useCallback(() => { - const thread = useThread.getState(); - const composer = useComposer.getState(); + const thread = threadStore.getState(); + const composer = composerStore.getState(); if (autoSend && !thread.isRunning) { append(prompt); composer.setText(""); } else { composer.setText(prompt); } - }, [useThread, useComposer, autoSend, append, prompt]); + }, [threadStore, composerStore, autoSend, append, prompt]); if (disabled) return null; return callback; diff --git a/packages/react/src/primitive-hooks/thread/useThreadViewportAutoScroll.tsx b/packages/react/src/primitive-hooks/thread/useThreadViewportAutoScroll.tsx index 4379797ca..fce6f94e1 100644 --- a/packages/react/src/primitive-hooks/thread/useThreadViewportAutoScroll.tsx +++ b/packages/react/src/primitive-hooks/thread/useThreadViewportAutoScroll.tsx @@ -1,7 +1,7 @@ "use client"; import { useComposedRefs } from "@radix-ui/react-compose-refs"; import { useRef } from "react"; -import { useThreadContext } from "../../context/react/ThreadContext"; +import { useThreadViewportStore } from "../../context/react/ThreadContext"; import { useOnResizeContent } from "../../utils/hooks/useOnResizeContent"; import { useOnScrollToBottom } from "../../utils/hooks/useOnScrollToBottom"; import { useManagedRef } from "../../utils/hooks/useManagedRef"; @@ -16,7 +16,7 @@ export const useThreadViewportAutoScroll = ({ }: UseThreadViewportAutoScrollProps) => { const divRef = useRef(null); - const { useViewport } = useThreadContext(); + const threadViewportStore = useThreadViewportStore(); const lastScrollTop = useRef(0); @@ -36,7 +36,7 @@ export const useThreadViewportAutoScroll = ({ const div = divRef.current; if (!div) return; - const isAtBottom = useViewport.getState().isAtBottom; + const isAtBottom = threadViewportStore.getState().isAtBottom; const newIsAtBottom = div.scrollHeight - div.scrollTop <= div.clientHeight + 1; // TODO figure out why +1 is needed @@ -48,7 +48,9 @@ export const useThreadViewportAutoScroll = ({ } if (newIsAtBottom !== isAtBottom) { - writableStore(useViewport).setState({ isAtBottom: newIsAtBottom }); + writableStore(threadViewportStore).setState({ + isAtBottom: newIsAtBottom, + }); } } @@ -56,7 +58,10 @@ export const useThreadViewportAutoScroll = ({ }; const resizeRef = useOnResizeContent(() => { - if (isScrollingToBottomRef.current || useViewport.getState().isAtBottom) { + if ( + isScrollingToBottomRef.current || + threadViewportStore.getState().isAtBottom + ) { scrollToBottom("instant"); } diff --git a/packages/react/src/primitives/actionBar/useActionBarFloatStatus.tsx b/packages/react/src/primitives/actionBar/useActionBarFloatStatus.tsx index cbff25b10..8f8cbe9fa 100644 --- a/packages/react/src/primitives/actionBar/useActionBarFloatStatus.tsx +++ b/packages/react/src/primitives/actionBar/useActionBarFloatStatus.tsx @@ -1,6 +1,9 @@ "use client"; -import { useMessageContext } from "../../context/react/MessageContext"; -import { useThreadContext } from "../../context/react/ThreadContext"; +import { + useMessageStore, + useMessageUtilsStore, +} from "../../context/react/MessageContext"; +import { useThreadStore } from "../../context/react/ThreadContext"; import { useCombinedStore } from "../../utils/combined/useCombinedStore"; export enum HideAndFloatStatus { @@ -20,11 +23,12 @@ export const useActionBarFloatStatus = ({ autohide, autohideFloat, }: UseActionBarFloatStatusProps) => { - const { useThread } = useThreadContext(); - const { useMessage, useMessageUtils } = useMessageContext(); + const threadStore = useThreadStore(); + const messageStore = useMessageStore(); + const messageUtilsStore = useMessageUtilsStore(); return useCombinedStore( - [useThread, useMessage, useMessageUtils], + [threadStore, messageStore, messageUtilsStore], (t, m, mu) => { if (hideWhenRunning && t.isRunning) return HideAndFloatStatus.Hidden; diff --git a/packages/react/src/primitives/composer/ComposerAttachments.tsx b/packages/react/src/primitives/composer/ComposerAttachments.tsx index 243f8334d..41d4365df 100644 --- a/packages/react/src/primitives/composer/ComposerAttachments.tsx +++ b/packages/react/src/primitives/composer/ComposerAttachments.tsx @@ -1,10 +1,10 @@ "use client"; import { ComponentType, type FC, memo } from "react"; -import { useThreadContext } from "../../context"; -import { useAttachmentContext } from "../../context/react/AttachmentContext"; +import { useComposerAttachment } from "../../context/react/AttachmentContext"; import { ComposerAttachmentProvider } from "../../context/providers/ComposerAttachmentProvider"; import type { ThreadComposerAttachment } from "../../context/stores/Attachment"; +import { useThreadComposer } from "../../context/react/ThreadContext"; export type ComposerPrimitiveAttachmentsProps = { components: @@ -38,8 +38,7 @@ const getComponent = ( const AttachmentComponent: FC<{ components: ComposerPrimitiveAttachmentsProps["components"]; }> = ({ components }) => { - const { useAttachment } = useAttachmentContext({ type: "composer" }); - const Component = useAttachment((a) => + const Component = useComposerAttachment((a) => getComponent(components, a.attachment), ); @@ -70,8 +69,7 @@ const ComposerAttachment = memo( export const ComposerPrimitiveAttachments: FC< ComposerPrimitiveAttachmentsProps > = ({ components }) => { - const { useComposer } = useThreadContext(); - const attachmentsCount = useComposer((s) => s.attachments.length); + const attachmentsCount = useThreadComposer((s) => s.attachments.length); return Array.from({ length: attachmentsCount }, (_, index) => ( { - const { useThread } = useThreadContext(); - const { useComposer, type } = useComposerContext(); + const threadStore = useThreadStore(); + const composerStore = useComposerStore(); const value = useComposer((c) => { if (!c.isEditing) return ""; @@ -52,7 +55,7 @@ export const ComposerPrimitiveInput = forwardRef< const ref = useComposedRefs(forwardedRef, textareaRef); useEscapeKeydown((e) => { - const composer = useComposer.getState(); + const composer = composerStore.getState(); if (composer.canCancel) { composer.cancel(); e.preventDefault(); @@ -66,7 +69,7 @@ export const ComposerPrimitiveInput = forwardRef< if (e.nativeEvent.isComposing) return; if (e.key === "Enter" && e.shiftKey === false) { - const { isRunning } = useThread.getState(); + const { isRunning } = threadStore.getState(); if (!isRunning) { e.preventDefault(); @@ -91,7 +94,7 @@ export const ComposerPrimitiveInput = forwardRef< useEffect(() => focus(), [focus]); useOnComposerFocus(() => { - if (type === "new") { + if (composerStore.getState().type === "thread") { focus(); } }); @@ -104,7 +107,7 @@ export const ComposerPrimitiveInput = forwardRef< ref={ref} disabled={isDisabled} onChange={composeEventHandlers(onChange, (e) => { - const composerState = useComposer.getState(); + const composerState = composerStore.getState(); if (!composerState.isEditing) return; return composerState.setText(e.target.value); })} diff --git a/packages/react/src/primitives/contentPart/ContentPartInProgress.tsx b/packages/react/src/primitives/contentPart/ContentPartInProgress.tsx index 08be84df0..63f5c5310 100644 --- a/packages/react/src/primitives/contentPart/ContentPartInProgress.tsx +++ b/packages/react/src/primitives/contentPart/ContentPartInProgress.tsx @@ -1,5 +1,5 @@ import { FC, PropsWithChildren } from "react"; -import { useContentPartContext } from "../../context"; +import { useContentPart } from "../../context"; export type ContentPartPrimitiveInProgressProps = PropsWithChildren; @@ -7,7 +7,6 @@ export type ContentPartPrimitiveInProgressProps = PropsWithChildren; export const ContentPartPrimitiveInProgress: FC< ContentPartPrimitiveInProgressProps > = ({ children }) => { - const { useContentPart } = useContentPartContext(); const isInProgress = useContentPart((c) => c.status.type === "running"); return isInProgress ? children : null; diff --git a/packages/react/src/primitives/message/MessageAttachments.tsx b/packages/react/src/primitives/message/MessageAttachments.tsx index efd68f899..1bff8cfdb 100644 --- a/packages/react/src/primitives/message/MessageAttachments.tsx +++ b/packages/react/src/primitives/message/MessageAttachments.tsx @@ -1,8 +1,8 @@ "use client"; import { ComponentType, type FC, memo } from "react"; -import { useMessageContext } from "../../context"; -import { useAttachmentContext } from "../../context/react/AttachmentContext"; +import { useMessage } from "../../context"; +import { useMessageAttachment } from "../../context/react/AttachmentContext"; import { MessageAttachmentProvider } from "../../context/providers/MessageAttachmentProvider"; import type { MessageAttachment } from "../../context/stores/Attachment"; @@ -38,8 +38,7 @@ const getComponent = ( const AttachmentComponent: FC<{ components: MessagePrimitiveAttachmentsProps["components"]; }> = ({ components }) => { - const { useAttachment } = useAttachmentContext({ type: "message" }); - const Component = useAttachment((a) => + const Component = useMessageAttachment((a) => getComponent(components, a.attachment), ); @@ -70,7 +69,6 @@ const MessageAttachment = memo( export const MessagePrimitiveAttachments: FC< MessagePrimitiveAttachmentsProps > = ({ components }) => { - const { useMessage } = useMessageContext(); const attachmentsCount = useMessage(({ message }) => { if (message.role !== "user") return 0; return message.attachments.length; diff --git a/packages/react/src/primitives/message/MessageContent.tsx b/packages/react/src/primitives/message/MessageContent.tsx index 52406d4d6..e708e3c17 100644 --- a/packages/react/src/primitives/message/MessageContent.tsx +++ b/packages/react/src/primitives/message/MessageContent.tsx @@ -2,11 +2,14 @@ import { type ComponentType, type FC, memo } from "react"; import { - useAssistantContext, - useContentPartContext, - useThreadContext, + useContentPart, + useThreadActionsStore, + useToolUIs, } from "../../context"; -import { useMessageContext } from "../../context/react/MessageContext"; +import { + useMessage, + useMessageStore, +} from "../../context/react/MessageContext"; import { ContentPartProvider, EMPTY_CONTENT, @@ -49,7 +52,6 @@ const ToolUIDisplay = ({ }: { UI: ToolCallContentPartComponent | undefined; } & ToolCallContentPartProps) => { - const { useToolUIs } = useAssistantContext(); const Render = useToolUIs((s) => s.getToolUI(props.part.toolName)) ?? UI; if (!Render) return null; return ; @@ -81,11 +83,9 @@ const MessageContentPartComponent: FC = ({ tools: { by_name = {}, Fallback = undefined } = {}, } = {}, }) => { - const { useThreadActions } = useThreadContext(); - const { useMessage } = useMessageContext(); - const addToolResult = useThreadActions((t) => t.addToolResult); + const messageStore = useMessageStore(); + const threadActionsStore = useThreadActionsStore(); - const { useContentPart } = useContentPartContext(); const { part, status } = useContentPart(); const type = part.type; @@ -113,8 +113,8 @@ const MessageContentPartComponent: FC = ({ case "tool-call": { const Tool = by_name[part.toolName] || Fallback; const addResult = (result: any) => - addToolResult({ - messageId: useMessage.getState().message.id, + threadActionsStore.getState().addToolResult({ + messageId: messageStore.getState().message.id, toolName: part.toolName, toolCallId: part.toolCallId, result, @@ -163,8 +163,6 @@ const MessageContentPart = memo( export const MessagePrimitiveContent: FC = ({ components, }) => { - const { useMessage } = useMessageContext(); - const contentLength = useMessage((s) => s.message.content.length) || 1; return Array.from({ length: contentLength }, (_, index) => ( diff --git a/packages/react/src/primitives/message/MessageInProgress.tsx b/packages/react/src/primitives/message/MessageInProgress.tsx index f1e33e47d..a73fa0232 100644 --- a/packages/react/src/primitives/message/MessageInProgress.tsx +++ b/packages/react/src/primitives/message/MessageInProgress.tsx @@ -8,7 +8,7 @@ type PrimitiveSpanProps = ComponentPropsWithoutRef; export type MessagePrimitiveInProgressProps = PrimitiveSpanProps; /** - * @deprecated Define a custom Text renderer via ContentPartPrimitiveInProgress instead. + * @deprecated Define a custom Text renderer via ContentPartPrimitiveInProgress instead. This will be removed in 0.6. */ export const MessagePrimitiveInProgress: FC< MessagePrimitiveInProgressProps diff --git a/packages/react/src/primitives/message/MessageRoot.tsx b/packages/react/src/primitives/message/MessageRoot.tsx index a4f73529f..a1473e779 100644 --- a/packages/react/src/primitives/message/MessageRoot.tsx +++ b/packages/react/src/primitives/message/MessageRoot.tsx @@ -7,7 +7,7 @@ import { ComponentPropsWithoutRef, useCallback, } from "react"; -import { useMessageContext } from "../../context/react/MessageContext"; +import { useMessageUtilsStore } from "../../context/react/MessageContext"; import { useManagedRef } from "../../utils/hooks/useManagedRef"; import { useComposedRefs } from "@radix-ui/react-compose-refs"; @@ -15,11 +15,10 @@ type MessagePrimitiveRootElement = ElementRef; type PrimitiveDivProps = ComponentPropsWithoutRef; const useIsHoveringRef = () => { - const { useMessageUtils } = useMessageContext(); - + const messageUtilsStore = useMessageUtilsStore(); const callbackRef = useCallback( (el: HTMLElement) => { - const setIsHovering = useMessageUtils.getState().setIsHovering; + const setIsHovering = messageUtilsStore.getState().setIsHovering; const handleMouseEnter = () => { setIsHovering(true); @@ -37,7 +36,7 @@ const useIsHoveringRef = () => { setIsHovering(false); }; }, - [useMessageUtils], + [messageUtilsStore], ); return useManagedRef(callbackRef); diff --git a/packages/react/src/primitives/thread/ThreadMessages.tsx b/packages/react/src/primitives/thread/ThreadMessages.tsx index 06f14c778..bc7838e19 100644 --- a/packages/react/src/primitives/thread/ThreadMessages.tsx +++ b/packages/react/src/primitives/thread/ThreadMessages.tsx @@ -1,9 +1,9 @@ "use client"; import { type ComponentType, type FC, memo } from "react"; -import { useThreadContext } from "../../context/react/ThreadContext"; +import { useThreadMessages } from "../../context/react/ThreadContext"; import { MessageProvider } from "../../context/providers/MessageProvider"; -import { useMessageContext } from "../../context"; +import { useEditComposer, useMessage } from "../../context"; import { ThreadMessage as ThreadMessageType } from "../../types"; export type ThreadPrimitiveMessagesProps = { @@ -102,7 +102,6 @@ type ThreadMessageComponentProps = { const ThreadMessageComponent: FC = ({ components, }) => { - const { useMessage, useEditComposer } = useMessageContext(); const role = useMessage((m) => m.message.role); const isEditing = useEditComposer((c) => c.isEditing); const Component = getComponent(components, role, isEditing); @@ -136,8 +135,6 @@ const ThreadMessage = memo( export const ThreadPrimitiveMessagesImpl: FC = ({ components, }) => { - const { useThreadMessages } = useThreadContext(); - const messagesLength = useThreadMessages((t) => t.length); if (messagesLength === 0) return null; diff --git a/packages/react/src/styles/tailwindcss/thread.css b/packages/react/src/styles/tailwindcss/thread.css index fb4546dfb..54294cd01 100644 --- a/packages/react/src/styles/tailwindcss/thread.css +++ b/packages/react/src/styles/tailwindcss/thread.css @@ -42,7 +42,9 @@ @apply line-clamp-2 text-ellipsis text-sm font-semibold; } -/* composer */ +/* TODO rename classes to .aui-thread-composer-root ? */ +/* rename composer to thread composer everywhere */ +/* thread composer */ .aui-composer-root { @apply focus-within:border-aui-ring/20 flex w-full flex-wrap items-end rounded-lg border px-2.5 shadow-sm transition-colors ease-in; diff --git a/packages/react/src/ui/assistant-action-bar.tsx b/packages/react/src/ui/assistant-action-bar.tsx index 1b518cc60..51f147f24 100644 --- a/packages/react/src/ui/assistant-action-bar.tsx +++ b/packages/react/src/ui/assistant-action-bar.tsx @@ -15,25 +15,22 @@ import { } from "./base/tooltip-icon-button"; import { withDefaults } from "./utils/withDefaults"; import { useThreadConfig } from "./thread-config"; -import { useThreadContext } from "../context"; +import { useThread } from "../context"; const useAllowCopy = (ensureCapability = false) => { const { assistantMessage: { allowCopy = true } = {} } = useThreadConfig(); - const { useThread } = useThreadContext(); const copySupported = useThread((t) => t.capabilities.unstable_copy); return allowCopy && (!ensureCapability || copySupported); }; const useAllowSpeak = (ensureCapability = false) => { const { assistantMessage: { allowSpeak = true } = {} } = useThreadConfig(); - const { useThread } = useThreadContext(); const speakSupported = useThread((t) => t.capabilities.speak); return allowSpeak && (!ensureCapability || speakSupported); }; const useAllowReload = (ensureCapability = false) => { const { assistantMessage: { allowReload = true } = {} } = useThreadConfig(); - const { useThread } = useThreadContext(); const reloadSupported = useThread((t) => t.capabilities.reload); return allowReload && (!ensureCapability || reloadSupported); }; diff --git a/packages/react/src/ui/branch-picker.tsx b/packages/react/src/ui/branch-picker.tsx index b6bb466b3..82d0a8f6c 100644 --- a/packages/react/src/ui/branch-picker.tsx +++ b/packages/react/src/ui/branch-picker.tsx @@ -10,11 +10,10 @@ import { import { withDefaults } from "./utils/withDefaults"; import { useThreadConfig } from "./thread-config"; import { BranchPickerPrimitive } from "../primitives"; -import { useThreadContext } from "../context"; +import { useThread } from "../context"; const useAllowBranchPicker = (ensureCapability = false) => { const { branchPicker: { allowBranchPicker = true } = {} } = useThreadConfig(); - const { useThread } = useThreadContext(); const branchPickerSupported = useThread((t) => t.capabilities.edit); return allowBranchPicker && (!ensureCapability || branchPickerSupported); }; diff --git a/packages/react/src/ui/composer-attachment.tsx b/packages/react/src/ui/composer-attachment.tsx index d41d41096..5783e73b3 100644 --- a/packages/react/src/ui/composer-attachment.tsx +++ b/packages/react/src/ui/composer-attachment.tsx @@ -9,8 +9,11 @@ import { TooltipIconButton, TooltipIconButtonProps, } from "./base/tooltip-icon-button"; -import { useThreadContext } from "../context/react/ThreadContext"; -import { useAttachmentContext } from "../context/react/AttachmentContext"; +import { useThreadComposerStore } from "../context/react/ThreadContext"; +import { + useAttachmentStore, + useComposerAttachment, +} from "../context/react/AttachmentContext"; const ComposerAttachmentRoot = withDefaults("div", { className: "aui-composer-attachment-root", @@ -19,8 +22,7 @@ const ComposerAttachmentRoot = withDefaults("div", { ComposerAttachmentRoot.displayName = "ComposerAttachmentRoot"; const ComposerAttachment: FC = () => { - const { useAttachment } = useAttachmentContext({ type: "composer" }); - const attachment = useAttachment((a) => a.attachment); + const attachment = useComposerAttachment((a) => a.attachment); return ( @@ -42,12 +44,12 @@ const ComposerAttachmentRemove = forwardRef< } = {}, } = useThreadConfig(); - const { useComposer } = useThreadContext(); - const { useAttachment } = useAttachmentContext(); + const composerStore = useThreadComposerStore(); + const attachmentStore = useAttachmentStore(); const handleRemoveAttachment = () => { - useComposer + composerStore .getState() - .removeAttachment(useAttachment.getState().attachment.id); + .removeAttachment(attachmentStore.getState().attachment.id); }; return ( diff --git a/packages/react/src/ui/composer.tsx b/packages/react/src/ui/composer.tsx index ec021f29f..58267eeaa 100644 --- a/packages/react/src/ui/composer.tsx +++ b/packages/react/src/ui/composer.tsx @@ -11,13 +11,12 @@ import { } from "./base/tooltip-icon-button"; import { CircleStopIcon } from "./base/CircleStopIcon"; import { ComposerPrimitive, ThreadPrimitive } from "../primitives"; -import { useThreadContext } from "../context/react/ThreadContext"; +import { useThread } from "../context/react/ThreadContext"; import ComposerAttachment from "./composer-attachment"; import { ComposerPrimitiveAttachmentsProps } from "../primitives/composer/ComposerAttachments"; const useAllowAttachments = (ensureCapability = false) => { const { composer: { allowAttachments = true } = {} } = useThreadConfig(); - const { useThread } = useThreadContext(); const attachmentsSupported = useThread((t) => t.capabilities.attachments); return allowAttachments && (!ensureCapability || attachmentsSupported); }; @@ -117,7 +116,6 @@ const ComposerAddAttachment = forwardRef< ComposerAddAttachment.displayName = "ComposerAddAttachment"; const useAllowCancel = () => { - const { useThread } = useThreadContext(); const cancelSupported = useThread((t) => t.capabilities.cancel); return cancelSupported; }; diff --git a/packages/react/src/ui/thread-config.tsx b/packages/react/src/ui/thread-config.tsx index 3211e622c..79f3a453d 100644 --- a/packages/react/src/ui/thread-config.tsx +++ b/packages/react/src/ui/thread-config.tsx @@ -12,8 +12,9 @@ import { import { AvatarProps } from "./base/avatar"; import { TextContentPartComponent, ToolCallContentPartProps } from "../types"; import { AssistantRuntime } from "../runtimes"; -import { AssistantRuntimeProvider, useAssistantContext } from "../context"; +import { AssistantRuntimeProvider } from "../context"; import { AssistantToolUI } from "../model-config"; +import { useAssistantRuntimeStore } from "../context/react/AssistantContext"; export type SuggestionConfig = { text?: ReactNode; @@ -166,7 +167,7 @@ export const ThreadConfigProvider: FC = ({ children, config, }) => { - const assistant = useAssistantContext({ optional: true }); + const hasAssistant = !!useAssistantRuntimeStore({ optional: true }); const configProvider = config && Object.keys(config ?? {}).length > 0 ? ( @@ -178,7 +179,7 @@ export const ThreadConfigProvider: FC = ({ ); if (!config?.runtime) return configProvider; - if (assistant) { + if (hasAssistant) { throw new Error( "You provided a runtime to while simulataneously using . This is not allowed.", ); diff --git a/packages/react/src/ui/user-action-bar.tsx b/packages/react/src/ui/user-action-bar.tsx index 7078704bf..3ce444b9d 100644 --- a/packages/react/src/ui/user-action-bar.tsx +++ b/packages/react/src/ui/user-action-bar.tsx @@ -9,12 +9,11 @@ import { } from "./base/tooltip-icon-button"; import { withDefaults } from "./utils/withDefaults"; import { useThreadConfig } from "./thread-config"; -import { useThreadContext } from "../context"; +import { useThread } from "../context"; import { ActionBarPrimitive } from "../primitives"; const useAllowEdit = (ensureCapability = false) => { const { userMessage: { allowEdit = true } = {} } = useThreadConfig(); - const { useThread } = useThreadContext(); const editSupported = useThread((t) => t.capabilities.edit); return allowEdit && (!ensureCapability || editSupported); }; diff --git a/packages/react/src/ui/user-message-attachment.tsx b/packages/react/src/ui/user-message-attachment.tsx index 4a7aca54c..e88902f36 100644 --- a/packages/react/src/ui/user-message-attachment.tsx +++ b/packages/react/src/ui/user-message-attachment.tsx @@ -3,7 +3,7 @@ import { type FC } from "react"; import { withDefaults } from "./utils/withDefaults"; -import { useAttachmentContext } from "../context/react/AttachmentContext"; +import { useAttachment } from "../context/react/AttachmentContext"; const UserMessageAttachmentRoot = withDefaults("div", { className: "aui-user-message-attachment-root", @@ -12,7 +12,6 @@ const UserMessageAttachmentRoot = withDefaults("div", { UserMessageAttachmentRoot.displayName = "UserMessageAttachmentRoot"; const UserMessageAttachment: FC = () => { - const { useAttachment } = useAttachmentContext(); const attachment = useAttachment((a) => a.attachment); return ( diff --git a/packages/react/src/utils/hooks/useOnComposerFocus.tsx b/packages/react/src/utils/hooks/useOnComposerFocus.tsx index 5cb034506..59b0a3374 100644 --- a/packages/react/src/utils/hooks/useOnComposerFocus.tsx +++ b/packages/react/src/utils/hooks/useOnComposerFocus.tsx @@ -1,14 +1,14 @@ import { useCallbackRef } from "@radix-ui/react-use-callback-ref"; import { useEffect } from "react"; -import { useThreadContext } from "../../context/react/ThreadContext"; +import { useThreadComposerStore } from "../../context/react/ThreadContext"; export const useOnComposerFocus = (callback: () => void) => { const callbackRef = useCallbackRef(callback); - const { useComposer } = useThreadContext(); + const threadComposerStore = useThreadComposerStore(); useEffect(() => { - return useComposer.getState().onFocus(() => { + return threadComposerStore.getState().onFocus(() => { callbackRef(); }); - }, [useComposer, callbackRef]); + }, [threadComposerStore, callbackRef]); }; diff --git a/packages/react/src/utils/hooks/useOnScrollToBottom.tsx b/packages/react/src/utils/hooks/useOnScrollToBottom.tsx index 966d9bdf5..45ec7cc7f 100644 --- a/packages/react/src/utils/hooks/useOnScrollToBottom.tsx +++ b/packages/react/src/utils/hooks/useOnScrollToBottom.tsx @@ -1,14 +1,14 @@ import { useCallbackRef } from "@radix-ui/react-use-callback-ref"; import { useEffect } from "react"; -import { useThreadContext } from "../../context/react/ThreadContext"; +import { useThreadViewportStore } from "../../context/react/ThreadContext"; export const useOnScrollToBottom = (callback: () => void) => { const callbackRef = useCallbackRef(callback); + const threadViewportStore = useThreadViewportStore(); - const { useViewport } = useThreadContext(); useEffect(() => { - return useViewport.getState().onScrollToBottom(() => { + return threadViewportStore.getState().onScrollToBottom(() => { callbackRef(); }); - }, [useViewport, callbackRef]); + }, [threadViewportStore, callbackRef]); }; diff --git a/packages/react/src/utils/smooth/SmoothContext.tsx b/packages/react/src/utils/smooth/SmoothContext.tsx index 16c94f183..9ac05f1eb 100644 --- a/packages/react/src/utils/smooth/SmoothContext.tsx +++ b/packages/react/src/utils/smooth/SmoothContext.tsx @@ -6,16 +6,19 @@ import { useContext, useState, } from "react"; -import { useContentPartContext } from "../../context"; import { ReadonlyStore } from "../../context/ReadonlyStore"; -import { create } from "zustand"; +import { create, UseBoundStore } from "zustand"; import { ContentPartStatus, ToolCallContentPartStatus, } from "../../types/AssistantTypes"; +import { useContentPartStore } from "../../context/react/ContentPartContext"; +import { createContextStoreHook } from "../../context/react/utils/createContextStoreHook"; type SmoothContextValue = { - useSmoothStatus: ReadonlyStore; + useSmoothStatus: UseBoundStore< + ReadonlyStore + >; }; const SmoothContext = createContext(null); @@ -29,10 +32,10 @@ const makeSmoothContext = ( export const SmoothContextProvider: FC = ({ children }) => { const outer = useSmoothContext({ optional: true }); - const { useContentPart } = useContentPartContext(); + const contentPartStore = useContentPartStore(); const [context] = useState(() => - makeSmoothContext(useContentPart.getState().status), + makeSmoothContext(contentPartStore.getState().status), ); // do not wrap if there is an outer SmoothContextProvider @@ -57,11 +60,13 @@ export const withSmoothContextProvider = >( return Wrapped as any; }; -export function useSmoothContext(): SmoothContextValue; -export function useSmoothContext(options: { - optional: true; +function useSmoothContext(options?: { + optional?: false | undefined; +}): SmoothContextValue; +function useSmoothContext(options?: { + optional?: boolean | undefined; }): SmoothContextValue | null; -export function useSmoothContext(options?: { optional: true }) { +function useSmoothContext(options?: { optional?: boolean | undefined }) { const context = useContext(SmoothContext); if (!options?.optional && !context) throw new Error( @@ -70,7 +75,7 @@ export function useSmoothContext(options?: { optional: true }) { return context; } -export const useSmoothStatus = () => { - const { useSmoothStatus } = useSmoothContext(); - return useSmoothStatus(); -}; +export const { useSmoothStatus, useSmoothStatusStore } = createContextStoreHook( + useSmoothContext, + "useSmoothStatus", +); diff --git a/packages/react/src/utils/smooth/useSmooth.tsx b/packages/react/src/utils/smooth/useSmooth.tsx index fcea48987..7e1608eed 100644 --- a/packages/react/src/utils/smooth/useSmooth.tsx +++ b/packages/react/src/utils/smooth/useSmooth.tsx @@ -1,15 +1,12 @@ "use client"; import { useEffect, useMemo, useRef, useState } from "react"; -import { useMessageContext } from "../../context"; -import { - ContentPartStatus, - ToolCallContentPartStatus, -} from "../../types/AssistantTypes"; +import { useMessage } from "../../context"; +import { ContentPartStatus } from "../../types/AssistantTypes"; import { TextContentPartState } from "../../context/stores/ContentPart"; -import { useSmoothContext } from "./SmoothContext"; -import { StoreApi } from "zustand"; import { useCallbackRef } from "@radix-ui/react-use-callback-ref"; +import { useSmoothStatus, useSmoothStatusStore } from "./SmoothContext"; +import { writableStore } from "../../context/ReadonlyStore"; class TextStreamAnimator { private animationFrameId: number | null = null; @@ -73,34 +70,31 @@ export const useSmooth = ( state: TextContentPartState, smooth: boolean = false, ): TextContentPartState => { - const { useSmoothStatus } = useSmoothContext({ optional: true }) ?? {}; - const { part: { text }, } = state; - const { useMessage } = useMessageContext(); const id = useMessage((m) => m.message.id); const idRef = useRef(id); const [displayedText, setDisplayedText] = useState(text); + const smoothStatusStore = useSmoothStatusStore({ optional: true }); const setText = useCallbackRef((text: string) => { setDisplayedText(text); - ( - useSmoothStatus as unknown as - | StoreApi - | undefined - )?.setState(text !== state.part.text ? SMOOTH_STATUS : state.status); + if (smoothStatusStore) { + writableStore(smoothStatusStore).setState( + text !== state.part.text ? SMOOTH_STATUS : state.status, + ); + } }); // TODO this is hacky useEffect(() => { - // TODO add a helper function so we don't have to override the types - ( - useSmoothStatus as unknown as - | StoreApi - | undefined - )?.setState(text !== displayedText ? SMOOTH_STATUS : state.status); + if (smoothStatusStore) { + writableStore(smoothStatusStore).setState( + text !== state.part.text ? SMOOTH_STATUS : state.status, + ); + } }, [useSmoothStatus, text, displayedText, state.status]); const [animatorRef] = useState(