Skip to content

Commit

Permalink
refactor: helper function to set zustand stores (#763)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonom authored Sep 7, 2024
1 parent e58d61b commit 9e00772
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 49 deletions.
5 changes: 5 additions & 0 deletions .changeset/olive-olives-learn.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@assistant-ui/react": patch
---

feat: add composer attachments state
4 changes: 4 additions & 0 deletions packages/react/src/context/ReadonlyStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@ import type { StoreApi, UseBoundStore } from "zustand";
export type ReadonlyStore<T> = UseBoundStore<
Omit<StoreApi<T>, "setState" | "destroy">
>;

export const writableStore = <T>(store: ReadonlyStore<T> | undefined) => {
return store as unknown as UseBoundStore<StoreApi<T>>;
};
12 changes: 3 additions & 9 deletions packages/react/src/context/providers/AssistantProvider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@ import { makeAssistantModelConfigStore } from "../stores/AssistantModelConfig";
import { makeAssistantToolUIsStore } from "../stores/AssistantToolUIs";
import { ThreadProvider } from "./ThreadProvider";
import { makeAssistantActionsStore } from "../stores/AssistantActions";
import {
AssistantRuntimeStore,
makeAssistantRuntimeStore,
} from "../stores/AssistantRuntime";
import { StoreApi } from "zustand";
import { makeAssistantRuntimeStore } from "../stores/AssistantRuntime";
import { writableStore } from "../ReadonlyStore";

type AssistantProviderProps = {
runtime: AssistantRuntime;
Expand Down Expand Up @@ -46,10 +43,7 @@ export const AssistantProvider: FC<
}, [runtime, getModelConfig]);

useEffect(
() =>
(
context.useAssistantRuntime as unknown as StoreApi<AssistantRuntimeStore>
).setState(runtime, true),
() => writableStore(context.useAssistantRuntime).setState(runtime, true),
[runtime, context],
);

Expand Down
7 changes: 3 additions & 4 deletions packages/react/src/context/providers/ContentPartProvider.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"use client";

import { type FC, type PropsWithChildren, useEffect, useState } from "react";
import { StoreApi, create } from "zustand";
import { create } from "zustand";
import { ContentPartContext } from "../react/ContentPartContext";
import type { ContentPartContextValue } from "../react/ContentPartContext";
import { useMessageContext } from "../react/MessageContext";
Expand All @@ -14,6 +14,7 @@ import {
ThreadUserContentPart,
ToolCallContentPartStatus,
} from "../../types/AssistantTypes";
import { writableStore } from "../ReadonlyStore";

type ContentPartProviderProps = PropsWithChildren<{
partIndex: number;
Expand Down Expand Up @@ -99,9 +100,7 @@ const useContentPartContext = (partIndex: number) => {
partIndex,
);
if (!newState) return;
(
context.useContentPart as unknown as StoreApi<ContentPartState>
).setState(newState, true);
writableStore(context.useContentPart).setState(newState, true);
};

syncContentPart(useMessage.getState());
Expand Down
8 changes: 3 additions & 5 deletions packages/react/src/context/providers/MessageProvider.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"use client";

import { type FC, type PropsWithChildren, useEffect, useState } from "react";
import { StoreApi, create } from "zustand";
import { create } from "zustand";
import type {
CoreUserContentPart,
ThreadMessage,
Expand All @@ -14,6 +14,7 @@ import type { MessageState } from "../stores/Message";
import { makeEditComposerStore } from "../stores/EditComposer";
import { makeMessageUtilsStore } from "../stores/MessageUtils";
import { ThreadMessagesState } from "../stores/ThreadMessages";
import { writableStore } from "../ReadonlyStore";

type MessageProviderProps = PropsWithChildren<{
messageIndex: number;
Expand Down Expand Up @@ -113,10 +114,7 @@ const useMessageContext = (messageIndex: number) => {
messageIndex,
);
if (!newState) return;
(context.useMessage as unknown as StoreApi<MessageState>).setState(
newState,
true,
);
writableStore(context.useMessage).setState(newState, true);
};

syncMessage(useThreadMessages.getState());
Expand Down
38 changes: 12 additions & 26 deletions packages/react/src/context/providers/ThreadProvider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,15 @@ import { useEffect, useInsertionEffect, useState } from "react";
import type { ReactThreadRuntime } from "../../runtimes/core/ReactThreadRuntime";
import type { ThreadContextValue } from "../react/ThreadContext";
import { ThreadContext } from "../react/ThreadContext";
import { ComposerState, makeComposerStore } from "../stores/Composer";
import {
ThreadState,
getThreadStateFromRuntime,
makeThreadStore,
} from "../stores/Thread";
import { makeComposerStore } from "../stores/Composer";
import { getThreadStateFromRuntime, makeThreadStore } from "../stores/Thread";
import { makeThreadViewportStore } from "../stores/ThreadViewport";
import { makeThreadActionStore } from "../stores/ThreadActions";
import { StoreApi } from "zustand";
import {
ThreadMessagesState,
makeThreadMessagesStore,
} from "../stores/ThreadMessages";
import { makeThreadMessagesStore } from "../stores/ThreadMessages";
import { ThreadRuntimeWithSubscribe } from "../../runtimes/core/AssistantRuntime";
import {
makeThreadRuntimeStore,
ThreadRuntimeStore,
} from "../stores/ThreadRuntime";
import { makeThreadRuntimeStore } from "../stores/ThreadRuntime";
import { subscribeToMainThread } from "../../runtimes";
import { writableStore } from "../ReadonlyStore";

type ThreadProviderProps = {
provider: ThreadRuntimeWithSubscribe;
Expand Down Expand Up @@ -63,16 +53,14 @@ export const ThreadProvider: FC<PropsWithChildren<ThreadProviderProps>> = ({
// TODO ensure capabilities is memoized
oldState.capabilities !== state.capabilities
) {
(context.useThread as unknown as StoreApi<ThreadState>).setState(
state,
true,
);
writableStore(context.useThread).setState(state, true);
}

if (thread.messages !== context.useThreadMessages.getState()) {
(
context.useThreadMessages as unknown as StoreApi<ThreadMessagesState>
).setState(thread.messages, true);
writableStore(context.useThreadMessages).setState(
thread.messages,
true,
);
}

const composerState = context.useComposer.getState();
Expand All @@ -81,7 +69,7 @@ export const ThreadProvider: FC<PropsWithChildren<ThreadProviderProps>> = ({
thread.composer.attachments !== composerState.attachments ||
state.capabilities.cancel !== composerState.canCancel
) {
(context.useComposer as unknown as StoreApi<ComposerState>).setState({
writableStore(context.useComposer).setState({
text: thread.composer.text,
attachments: thread.composer.attachments,
canCancel: state.capabilities.cancel,
Expand All @@ -96,9 +84,7 @@ export const ThreadProvider: FC<PropsWithChildren<ThreadProviderProps>> = ({
useInsertionEffect(
() =>
provider.subscribe(() => {
(
context.useThreadRuntime as unknown as StoreApi<ThreadRuntimeStore>
).setState(provider.thread, true);
writableStore(context.useThreadRuntime).setState(provider.thread, true);
}),
[provider, context],
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ import { useRef } from "react";
import { useThreadContext } from "../../context/react/ThreadContext";
import { useOnResizeContent } from "../../utils/hooks/useOnResizeContent";
import { useOnScrollToBottom } from "../../utils/hooks/useOnScrollToBottom";
import { StoreApi } from "zustand";
import { ThreadViewportState } from "../../context";
import { useManagedRef } from "../../utils/hooks/useManagedRef";
import { writableStore } from "../../context/ReadonlyStore";

export type UseThreadViewportAutoScrollProps = {
autoScroll?: boolean | undefined;
Expand Down Expand Up @@ -49,9 +48,7 @@ export const useThreadViewportAutoScroll = <TElement extends HTMLElement>({
}

if (newIsAtBottom !== isAtBottom) {
(useViewport as unknown as StoreApi<ThreadViewportState>).setState({
isAtBottom: newIsAtBottom,
});
writableStore(useViewport).setState({ isAtBottom: newIsAtBottom });
}
}

Expand Down

0 comments on commit 9e00772

Please sign in to comment.