Skip to content

Commit

Permalink
feat: thread disabling support (#590)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonom authored Jul 27, 2024
1 parent 6de45fb commit 9dc942f
Show file tree
Hide file tree
Showing 15 changed files with 69 additions and 17 deletions.
7 changes: 7 additions & 0 deletions .changeset/five-wasps-double.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@assistant-ui/react-playground": patch
"@assistant-ui/react-ai-sdk": patch
"@assistant-ui/react": patch
---

feat: useThread.isDisabled flag
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ export class VercelUseAssistantThreadRuntime implements ReactThreadRuntime {
>;

public messages: readonly ThreadMessage[] = [];
public readonly isDisabled = false;
public isRunning = false;

constructor(public vercel: ReturnType<typeof useAssistant>) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ export class VercelUseChatThreadRuntime implements ReactThreadRuntime {
public readonly capabilities = CAPABILITIES;

public messages: ThreadMessage[] = [];
public readonly isDisabled = false;
public isRunning = false;

constructor(public vercel: ReturnType<typeof useChat>) {
Expand Down
1 change: 1 addition & 0 deletions packages/react-playground/src/lib/playground-runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ export class PlaygroundThreadRuntime implements ReactThreadRuntime {

public tools: Record<string, Tool<any, any>> = {};

public readonly isDisabled = false;
public get isRunning() {
return this.abortController != null;
}
Expand Down
8 changes: 7 additions & 1 deletion packages/react/src/context/providers/ThreadProvider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,20 @@ export const ThreadProvider: FC<PropsWithChildren<ThreadProviderProps>> = ({
useCallback(
(thread: ReactThreadRuntime) => {
const onThreadUpdate = () => {
if (thread.isRunning !== context.useThread.getState().isRunning) {
const threadState = context.useThread.getState();
if (
thread.isRunning !== threadState.isRunning ||
thread.isDisabled !== threadState.isDisabled
) {
(context.useThread as unknown as StoreApi<ThreadState>).setState(
Object.freeze({
isRunning: thread.isRunning,
isDisabled: thread.isDisabled,
}) satisfies ThreadState,
true,
);
}

if (thread.messages !== context.useThreadMessages.getState()) {
(
context.useThreadMessages as unknown as StoreApi<ThreadMessagesState>
Expand Down
2 changes: 2 additions & 0 deletions packages/react/src/context/stores/Thread.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ import { ThreadRuntimeStore } from "./ThreadRuntime";

export type ThreadState = Readonly<{
isRunning: boolean;
isDisabled: boolean;
}>;

export const makeThreadStore = (
runtimeRef: ReadonlyStore<ThreadRuntimeStore>,
) => {
return create<ThreadState>(() => ({
isDisabled: runtimeRef.getState().isDisabled,
isRunning: runtimeRef.getState().isRunning,
}));
};
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export const useActionBarReload = () => {

const disabled = useCombinedStore(
[useThread, useMessage],
(t, m) => t.isRunning || m.message.role !== "assistant",
(t, m) => t.isRunning || t.isDisabled || m.message.role !== "assistant",
);

const callback = useCallback(() => {
Expand Down
13 changes: 11 additions & 2 deletions packages/react/src/primitive-hooks/composer/useComposerSend.tsx
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import { useCallback } from "react";
import { useComposerContext, useThreadContext } from "../../context";
import { useCombinedStore } from "../../utils/combined/useCombinedStore";

export const useComposerSend = () => {
const { useViewport, useComposer: useNewComposer } = useThreadContext();
const {
useThread,
useViewport,
useComposer: useNewComposer,
} = useThreadContext();
const { useComposer } = useComposerContext();

const disabled = useComposer((c) => !c.isEditing || c.value.length === 0);
const disabled = useCombinedStore(
[useThread, useComposer],
(t, c) =>
t.isDisabled || t.isRunning || !c.isEditing || c.value.length === 0,
);

const callback = useCallback(() => {
const composerState = useComposer.getState();
Expand Down
3 changes: 3 additions & 0 deletions packages/react/src/primitive-hooks/thread/useThreadIf.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { useCombinedStore } from "../../utils/combined/useCombinedStore";
type ThreadIfFilters = {
empty: boolean | undefined;
running: boolean | undefined;
disabled: boolean | undefined;
};

export type UseThreadIfProps = RequireAtLeastOne<ThreadIfFilters>;
Expand All @@ -20,6 +21,8 @@ export const useThreadIf = (props: UseThreadIfProps) => {
if (props.empty === false && messages.length === 0) return false;
if (props.running === true && !thread.isRunning) return false;
if (props.running === false && thread.isRunning) return false;
if (props.disabled === true && thread.isDisabled) return false;
if (props.disabled === false && thread.isDisabled) return false;

return true;
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export const useThreadSuggestion = ({
const { useThread, useComposer } = useThreadContext();

const append = useAppendMessage();
const disabled = useThread((t) => t.isRunning);
const disabled = useThread((t) => t.isDisabled);
const callback = useCallback(() => {
const thread = useThread.getState();
const composer = useComposer.getState();
Expand Down
16 changes: 12 additions & 4 deletions packages/react/src/primitives/composer/ComposerInput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@ export const ComposerPrimitiveInput = forwardRef<
ComposerPrimitiveInputProps
>(
(
{ autoFocus = false, asChild, disabled, onChange, onKeyDown, ...rest },
{
autoFocus = false,
asChild,
disabled: disabledProp,
onChange,
onKeyDown,
...rest
},
forwardedRef,
) => {
const { useThread } = useThreadContext();
Expand All @@ -40,6 +47,7 @@ export const ComposerPrimitiveInput = forwardRef<

const Component = asChild ? Slot : TextareaAutosize;

const isDisabled = useThread((t) => t.isDisabled) ?? disabledProp ?? false;
const textareaRef = useRef<HTMLTextAreaElement>(null);
const ref = useComposedRefs(forwardedRef, textareaRef);

Expand All @@ -52,7 +60,7 @@ export const ComposerPrimitiveInput = forwardRef<
});

const handleKeyPress = (e: KeyboardEvent) => {
if (disabled) return;
if (isDisabled) return;

if (e.key === "Enter" && e.shiftKey === false) {
const isRunning = useThread.getState().isRunning;
Expand All @@ -64,7 +72,7 @@ export const ComposerPrimitiveInput = forwardRef<
}
};

const autoFocusEnabled = autoFocus && !disabled;
const autoFocusEnabled = autoFocus && !isDisabled;
const focus = useCallback(() => {
const textarea = textareaRef.current;
if (!textarea || !autoFocusEnabled) return;
Expand All @@ -90,7 +98,7 @@ export const ComposerPrimitiveInput = forwardRef<
value={value}
{...rest}
ref={ref}
disabled={disabled}
disabled={isDisabled}
onChange={composeEventHandlers(onChange, (e) => {
const composerState = useComposer.getState();
if (!composerState.isEditing) return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type ExternalStoreMessageConverterAdapter<T> = {

type ExternalStoreAdapterBase<T> = {
threadId?: string | undefined;
isDisabled?: boolean | undefined;
isRunning?: boolean | undefined;
messages: T[];
setMessages?: ((messages: T[]) => void) | undefined;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { type StoreApi, type UseBoundStore, create } from "zustand";
import { create } from "zustand";
import { ReactThreadRuntime } from "../core";
import { MessageRepository } from "../utils/MessageRepository";
import { AppendMessage, ThreadMessage, Unsubscribe } from "../../types";
Expand All @@ -19,9 +19,7 @@ export class ExternalStoreThreadRuntime implements ReactThreadRuntime {
private repository = new MessageRepository();
private assistantOptimisticId: string | null = null;

private useStore: UseBoundStore<
StoreApi<{ store: ExternalStoreAdapter<any> }>
>;
private useStore;

public get capabilities() {
return {
Expand All @@ -33,10 +31,15 @@ export class ExternalStoreThreadRuntime implements ReactThreadRuntime {
};
}

public messages: ThreadMessage[] = [];
public isRunning = false;
public messages;
public isDisabled;
public isRunning;

constructor(public store: ExternalStoreAdapter<any>) {
this.isDisabled = store.isDisabled ?? false;
this.isRunning = store.isRunning ?? false;
this.messages = store.messages;

this.useStore = create(() => ({
store,
}));
Expand Down Expand Up @@ -107,7 +110,11 @@ export class ExternalStoreThreadRuntime implements ReactThreadRuntime {
}
}

private updateData = (isRunning: boolean, vm: ThreadMessage[]) => {
private updateData = (
isDisabled: boolean,
isRunning: boolean,
vm: ThreadMessage[],
) => {
for (let i = 0; i < vm.length; i++) {
const message = vm[i]!;
const parent = vm[i - 1];
Expand All @@ -134,6 +141,7 @@ export class ExternalStoreThreadRuntime implements ReactThreadRuntime {
);

this.messages = this.repository.getMessages();
this.isDisabled = isDisabled;
this.isRunning = isRunning;

for (const callback of this._subscriptions) callback();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import { getAutoStatus } from "./auto-status";
import { fromThreadMessageLike } from "./ThreadMessageLike";

type UpdateDataCallback = (
isDisabled: boolean,
isRunning: boolean,
messages: ThreadMessage[],
) => void;
Expand Down Expand Up @@ -50,14 +51,16 @@ export const useExternalStoreSync = <T extends WeakKey>(

useEffect(() => {
updateData(
adapter.isDisabled ?? false,
adapter.isRunning ?? false,
converter.convertMessages(adapter.messages, convertCallback),
);
}, [
updateData,
converter,
convertCallback,
adapter.messages,
adapter.isDisabled,
adapter.isRunning,
adapter.messages,
]);
};
2 changes: 2 additions & 0 deletions packages/react/src/runtimes/local/LocalThreadRuntime.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ export class LocalThreadRuntime implements ThreadRuntime {

public readonly capabilities = CAPABILITIES;

public readonly isDisabled = true;

public get messages() {
return this.repository.getMessages();
}
Expand Down

0 comments on commit 9dc942f

Please sign in to comment.