Skip to content

Commit

Permalink
refactor: use new runtime apis in hooks (#987)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonom authored Oct 13, 2024
1 parent ebdd76b commit f7c583b
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import { TransactionConfirmationPending } from "./transaction-confirmation-pending";
import { TransactionConfirmationFinal } from "./transaction-confirmation-final";
import { makeAssistantToolUI, useThreadStore } from "@assistant-ui/react";
import { makeAssistantToolUI, useThreadRuntime } from "@assistant-ui/react";
import { updateState } from "@/lib/chatApi";

type PurchaseStockArgs = {
Expand Down Expand Up @@ -32,9 +32,9 @@ export const PurchaseStockTool = makeAssistantToolUI<PurchaseStockArgs, string>(
? (JSON.parse(result) as { transactionId: string })
: undefined;

const threadStore = useThreadStore();
const threadRuntime = useThreadRuntime();
const handleConfirm = async () => {
await updateState(threadStore.getState().threadId, {
await updateState(threadRuntime.getState().threadId, {
newState: CONFIRM_PURCHASE,
asNode: PREPARE_PURCHASE_DETAILS_NODE,
});
Expand Down
4 changes: 2 additions & 2 deletions packages/react/src/api/ThreadRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,12 @@ export type ThreadRuntime = {
submitFeedback: (feedback: SubmitFeedbackOptions) => void;

/**
* @deprecated Use `getMesssageById(id).getMessageByIndex(idx).composer` instead. This will be removed in 0.6.0.
* @deprecated Use `getMesssageById(id).composer` instead. This will be removed in 0.6.0.
*/
getEditComposer: (messageId: string) => ComposerRuntimeCore | undefined;

/**
* @deprecated Use `getMesssageById(id).getMessageByIndex(idx).composer.beginEdit()` instead. This will be removed in 0.6.0.
* @deprecated Use `getMesssageById(id).composer.beginEdit()` instead. This will be removed in 0.6.0.
*/
beginEdit: (messageId: string) => void;
};
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import { useCallback } from "react";
import {
useEditComposerStore,
useMessageStore,
useMessageRuntime,
useMessageUtilsStore,
} from "../../context/react/MessageContext";
import { useCombinedStore } from "../../utils/combined/useCombinedStore";
import { getThreadMessageText } from "../../utils/getThreadMessageText";
import { useComposerRuntime } from "../../context";

export type UseActionBarCopyProps = {
copiedDuration?: number | undefined;
Expand All @@ -14,11 +14,11 @@ export type UseActionBarCopyProps = {
export const useActionBarCopy = ({
copiedDuration = 3000,
}: UseActionBarCopyProps = {}) => {
const messageStore = useMessageStore();
const messageRuntime = useMessageRuntime();
const composerRuntime = useComposerRuntime();
const messageUtilsStore = useMessageUtilsStore();
const editComposerStore = useEditComposerStore();
const hasCopyableContent = useCombinedStore(
[messageStore, editComposerStore],
[messageRuntime, composerRuntime],
(message, c) => {
return (
!c.isEditing &&
Expand All @@ -29,9 +29,9 @@ export const useActionBarCopy = ({
);

const callback = useCallback(() => {
const message = messageStore.getState();
const message = messageRuntime.getState();
const { setIsCopied } = messageUtilsStore.getState();
const { isEditing, text: composerValue } = editComposerStore.getState();
const { isEditing, text: composerValue } = composerRuntime.getState();

const valueToCopy = isEditing
? composerValue
Expand All @@ -41,7 +41,7 @@ export const useActionBarCopy = ({
setIsCopied(true);
setTimeout(() => setIsCopied(false), copiedDuration);
});
}, [messageStore, messageUtilsStore, editComposerStore, copiedDuration]);
}, [messageRuntime, messageUtilsStore, composerRuntime, copiedDuration]);

if (!hasCopyableContent) return null;
return callback;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import { useCallback } from "react";
import { useComposer, useComposerStore } from "../../context";
import { useComposer, useComposerRuntime } from "../../context";

export const useComposerCancel = () => {
const composerStore = useComposerStore();
const composerRuntime = useComposerRuntime();
const disabled = useComposer((c) => !c.canCancel);

const callback = useCallback(() => {
const { cancel } = composerStore.getState();
cancel();
}, [composerStore]);
composerRuntime.cancel();
}, [composerRuntime]);

if (disabled) return null;
return callback;
Expand Down
6 changes: 3 additions & 3 deletions packages/react/src/primitive-hooks/message/useMessageIf.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"use client";
import {
useMessageStore,
useMessageRuntime,
useMessageUtilsStore,
} from "../../context/react/MessageContext";
import type { RequireAtLeastOne } from "../../utils/RequireAtLeastOne";
Expand All @@ -21,11 +21,11 @@ type MessageIfFilters = {
export type UseMessageIfProps = RequireAtLeastOne<MessageIfFilters>;

export const useMessageIf = (props: UseMessageIfProps) => {
const messageStore = useMessageStore();
const messageRuntime = useMessageRuntime();
const messageUtilsStore = useMessageUtilsStore();

return useCombinedStore(
[messageStore, messageUtilsStore],
[messageRuntime, messageUtilsStore],
(
{
role,
Expand Down
21 changes: 8 additions & 13 deletions packages/react/src/primitive-hooks/thread/useThreadSuggestion.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { useCallback } from "react";
import { useThread, useThreadStore } from "../../context";
import { useAppendMessage } from "../../hooks";
import { useThreadComposerStore } from "../../context/react/ThreadContext";
import { useThread } from "../../context";
import { useThreadRuntime } from "../../context/react/ThreadContext";

export type UseApplyThreadSuggestionProps = {
prompt: string;
Expand All @@ -13,21 +12,17 @@ export const useThreadSuggestion = ({
prompt,
autoSend,
}: UseApplyThreadSuggestionProps) => {
const threadStore = useThreadStore();
const composerStore = useThreadComposerStore();
const threadRuntime = useThreadRuntime();

const append = useAppendMessage();
const disabled = useThread((t) => t.isDisabled);
const callback = useCallback(() => {
const thread = threadStore.getState();
const composer = composerStore.getState();
if (autoSend && !thread.isRunning) {
append(prompt);
composer.setText("");
if (autoSend && !threadRuntime.getState().isRunning) {
threadRuntime.append(prompt);
threadRuntime.composer.setText("");
} else {
composer.setText(prompt);
threadRuntime.composer.setText(prompt);
}
}, [threadStore, composerStore, autoSend, append, prompt]);
}, [threadRuntime, autoSend, prompt]);

if (disabled) return null;
return callback;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"use client";
import {
useMessageStore,
useMessageRuntime,
useMessageUtilsStore,
} from "../../context/react/MessageContext";
import { useThreadStore } from "../../context/react/ThreadContext";
import { useThreadRuntime } from "../../context/react/ThreadContext";
import { useCombinedStore } from "../../utils/combined/useCombinedStore";

export enum HideAndFloatStatus {
Expand All @@ -23,12 +23,12 @@ export const useActionBarFloatStatus = ({
autohide,
autohideFloat,
}: UseActionBarFloatStatusProps) => {
const threadStore = useThreadStore();
const messageStore = useMessageStore();
const threadRuntime = useThreadRuntime();
const messageRuntime = useMessageRuntime();
const messageUtilsStore = useMessageUtilsStore();

return useCombinedStore(
[threadStore, messageStore, messageUtilsStore],
[threadRuntime, messageRuntime, messageUtilsStore],
(t, m, mu) => {
if (hideWhenRunning && t.isRunning) return HideAndFloatStatus.Hidden;

Expand Down
27 changes: 15 additions & 12 deletions packages/react/src/primitives/composer/ComposerInput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@ import TextareaAutosize, {
} from "react-textarea-autosize";
import {
useComposer,
useComposerStore,
useComposerRuntime,
} from "../../context/react/ComposerContext";
import { useThread, useThreadStore } from "../../context/react/ThreadContext";
import {
useThread,
useThreadRuntime,
useThreadViewportStore,
} from "../../context/react/ThreadContext";
import { useEscapeKeydown } from "@radix-ui/react-use-escape-keydown";
import { useOnComposerFocus } from "../../utils/hooks/useOnComposerFocus";

Expand Down Expand Up @@ -52,8 +56,9 @@ export const ComposerPrimitiveInput = forwardRef<
},
forwardedRef,
) => {
const threadStore = useThreadStore();
const composerStore = useComposerStore();
const threadRuntime = useThreadRuntime();
const composerRuntime = useComposerRuntime();
const threadViewportStore = useThreadViewportStore({ optional: true });

const value = useComposer((c) => {
if (!c.isEditing) return "";
Expand All @@ -69,9 +74,8 @@ export const ComposerPrimitiveInput = forwardRef<
useEscapeKeydown((e) => {
if (!cancelOnEscape) return;

const composer = composerStore.getState();
if (composer.canCancel) {
composer.cancel();
if (composerRuntime.getState().canCancel) {
composerRuntime.cancel();
e.preventDefault();
}
});
Expand All @@ -83,7 +87,7 @@ export const ComposerPrimitiveInput = forwardRef<
if (e.nativeEvent.isComposing) return;

if (e.key === "Enter" && e.shiftKey === false) {
const { isRunning } = threadStore.getState();
const { isRunning } = threadRuntime.getState();

if (!isRunning) {
e.preventDefault();
Expand All @@ -108,7 +112,7 @@ export const ComposerPrimitiveInput = forwardRef<
useEffect(() => focus(), [focus]);

useOnComposerFocus(() => {
if (composerStore.getState().type === "thread") {
if (composerRuntime.type === "thread") {
focus();
}
});
Expand All @@ -121,9 +125,8 @@ export const ComposerPrimitiveInput = forwardRef<
ref={ref}
disabled={isDisabled}
onChange={composeEventHandlers(onChange, (e) => {
const composerState = composerStore.getState();
if (!composerState.isEditing) return;
return composerState.setText(e.target.value);
if (!composerRuntime.getState().isEditing) return;
return composerRuntime.setText(e.target.value);
})}
onKeyDown={composeEventHandlers(onKeyDown, handleKeyPress)}
/>
Expand Down
17 changes: 7 additions & 10 deletions packages/react/src/primitives/message/MessageContent.tsx
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"use client";

import { type ComponentType, type FC, memo, useMemo } from "react";
import { useContentPart, useThreadRuntime, useToolUIs } from "../../context";
import {
useContentPart,
useContentPartRuntime,
useToolUIs,
} from "../../context";
import {
useMessage,
useMessageRuntime,
useMessageStore,
} from "../../context/react/MessageContext";
import { ContentPartRuntimeProvider } from "../../context/providers/ContentPartRuntimeProvider";
import { ContentPartPrimitiveText } from "../contentPart/ContentPartText";
Expand Down Expand Up @@ -85,8 +88,7 @@ const MessageContentPartComponent: FC<MessageContentPartComponentProps> = ({
tools: { by_name = {}, Fallback = undefined } = {},
} = {},
}) => {
const messageStore = useMessageStore();
const threadRuntime = useThreadRuntime();
const contentPartRuntime = useContentPartRuntime();

const part = useContentPart();

Expand Down Expand Up @@ -115,12 +117,7 @@ const MessageContentPartComponent: FC<MessageContentPartComponentProps> = ({
case "tool-call": {
const Tool = by_name[part.toolName] || Fallback;
const addResult = (result: any) =>
threadRuntime.addToolResult({
messageId: messageStore.getState().id,
toolName: part.toolName,
toolCallId: part.toolCallId,
result,
});
contentPartRuntime.addToolResult(result);
return (
<ToolUIDisplay {...part} part={part} UI={Tool} addResult={addResult} />
);
Expand Down

0 comments on commit f7c583b

Please sign in to comment.