Skip to content

Commit

Permalink
feat: ThreadRuntime Events API (experimental) (#988)
Browse files Browse the repository at this point in the history
replace ComposerRuntime.focus with the new Events API
  • Loading branch information
Yonom authored Oct 13, 2024
1 parent f7c583b commit ad52c51
Show file tree
Hide file tree
Showing 21 changed files with 187 additions and 159 deletions.
6 changes: 3 additions & 3 deletions apps/docs/components/shadcn/Shadcn.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { ArchiveIcon, EditIcon, MenuIcon, ShareIcon } from "lucide-react";
import Link from "next/link";
import { Thread, useSwitchToNewThread } from "@assistant-ui/react";
import { Thread, useAssistantRuntime } from "@assistant-ui/react";
import { makeMarkdownText } from "@assistant-ui/react-markdown";
import remarkGfm from "remark-gfm";
import { makePrismAsyncSyntaxHighlighter } from "@assistant-ui/react-syntax-highlighter";
Expand Down Expand Up @@ -61,11 +61,11 @@ const ButtonWithTooltip: FC<ButtonWithTooltipProps> = ({
};

const TopLeft: FC = () => {
const switchToNewThread = useSwitchToNewThread();
const runtime = useAssistantRuntime();

return (
<ButtonWithTooltip
onClick={switchToNewThread}
onClick={() => runtime.switchToNewThread()}
variant="ghost"
className="flex w-full justify-between px-3"
tooltip="New Chat"
Expand Down
5 changes: 5 additions & 0 deletions packages/react-playground/src/lib/playground-runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,11 @@ export class PlaygroundThreadRuntimeCore implements INTERNAL.ThreadRuntimeCore {
throw new Error("PlaygroundRuntime does not support feedback.");
}

public unstable_on() {
// events not supported in playground
return () => {};
}

public deleteMessage(messageId: string) {
this.setMessages(this.messages.filter((m) => m.id !== messageId));
}
Expand Down
41 changes: 1 addition & 40 deletions packages/react/src/api/ComposerRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,6 @@ type LegacyThreadComposerState = Readonly<{
send: () => void;
/** @deprecated Use `useComposerRuntime().cancel` instead. This will be removed in 0.6.0. */
cancel: () => void;

// TODO replace with events
/** @deprecated This feature is being removed in 0.6.0. Submit feedback if you need it. */
focus: () => void;
/** @deprecated This feature is being removed in 0.6.0. Submit feedback if you need it. */
onFocus: (listener: () => void) => Unsubscribe;
}>;

type BaseComposerState = {
Expand Down Expand Up @@ -124,8 +118,6 @@ const METHOD_NOT_SUPPORTED = () => {
const EMPTY_ARRAY = Object.freeze([]);
const getThreadComposerState = (
runtime: ThreadComposerRuntimeCore | undefined,
focus: () => void,
onFocus: (listener: () => void) => Unsubscribe,
): ThreadComposerState => {
return Object.freeze({
type: "thread",
Expand All @@ -142,8 +134,6 @@ const getThreadComposerState = (
// edit: beginEdit,
send: runtime?.send.bind(runtime) ?? METHOD_NOT_SUPPORTED,
cancel: runtime?.cancel.bind(runtime) ?? METHOD_NOT_SUPPORTED,
focus: focus,
onFocus: onFocus,
reset: runtime?.reset.bind(runtime) ?? METHOD_NOT_SUPPORTED,

addAttachment: runtime?.addAttachment.bind(runtime) ?? METHOD_NOT_SUPPORTED,
Expand Down Expand Up @@ -336,12 +326,6 @@ export type ThreadComposerRuntime = Omit<
*/
attachments: readonly PendingAttachment[];

/** @deprecated This feature is being removed in 0.6.0. Submit feedback if you need it. */
focus(): void;

/** @deprecated This feature is being removed in 0.6.0. Submit feedback if you need it. */
onFocus(callback: () => void): Unsubscribe;

getAttachmentByIndex(
idx: number,
): AttachmentRuntime & { source: "thread-composer" };
Expand All @@ -359,12 +343,7 @@ export class ThreadComposerRuntimeImpl

constructor(core: ThreadComposerRuntimeCoreBinding) {
const stateBinding = new LazyMemoizeSubject({
getState: () =>
getThreadComposerState(
core.getState(),
this.focus.bind(this),
this.onFocus.bind(this),
),
getState: () => getThreadComposerState(core.getState()),
subscribe: (callback) => core.subscribe(callback),
});
super({
Expand All @@ -382,24 +361,6 @@ export class ThreadComposerRuntimeImpl
return this._getState();
}

// TODO replace with events
private _focusListeners = new Set<() => void>();

/**
* @deprecated This feature is being removed in 0.6.0. Submit feedback if you need it.
*/
public focus() {
this._focusListeners.forEach((callback) => callback());
}

/**
* @deprecated This feature is being removed in 0.6.0. Submit feedback if you need it.
*/
public onFocus(callback: () => void) {
this._focusListeners.add(callback);
return () => this._focusListeners.delete(callback);
}

public getAttachmentByIndex(idx: number) {
return new ThreadComposerAttachmentRuntimeImpl(
new ShallowMemoizeSubject({
Expand Down
38 changes: 36 additions & 2 deletions packages/react/src/api/ThreadRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ import {
} from "./MessageRuntime";
import { NestedSubscriptionSubject } from "./subscribable/NestedSubscriptionSubject";
import { ShallowMemoizeSubject } from "./subscribable/ShallowMemoizeSubject";
import { SubscribableWithState } from "./subscribable/Subscribable";
import {
Subscribable,
SubscribableWithState,
} from "./subscribable/Subscribable";
import {
ThreadComposerRuntime,
ThreadComposerRuntimeImpl,
Expand Down Expand Up @@ -64,7 +67,10 @@ const toAppendMessage = (
} as AppendMessage;
};

export type ThreadRuntimeCoreBinding = SubscribableWithState<ThreadRuntimeCore>;
export type ThreadRuntimeCoreBinding =
SubscribableWithState<ThreadRuntimeCore> & {
outerSubscribe(callback: () => void): Unsubscribe;
};

export type ThreadState = Readonly<{
threadId: string;
Expand Down Expand Up @@ -112,6 +118,10 @@ export type ThreadRuntime = {
import(repository: ExportedMessageRepository): void;
getMesssageByIndex(idx: number): MessageRuntime;
stopSpeaking: () => void;
unstable_on(
event: "switched-to" | "run-start",
callback: () => void,
): Unsubscribe;

// Legacy methods with deprecations

Expand Down Expand Up @@ -271,6 +281,7 @@ export class ThreadRuntimeImpl implements ThreadRuntimeCore, ThreadRuntime {
this._threadBinding = {
getState: () => threadBinding.getState(),
getStateState: () => stateBinding.getState(),
outerSubscribe: (callback) => threadBinding.outerSubscribe(callback),
subscribe: (callback) => threadBinding.subscribe(callback),
};
}
Expand Down Expand Up @@ -413,4 +424,27 @@ export class ThreadRuntimeImpl implements ThreadRuntimeCore, ThreadRuntime {
this._threadBinding,
);
}

private _eventListenerNestedSubscriptions = new Map<
string,
NestedSubscriptionSubject<Subscribable>
>();

public unstable_on(
event: "switched-to" | "run-start",
callback: () => void,
): Unsubscribe {
let subject = this._eventListenerNestedSubscriptions.get(event);
if (!subject) {
subject = new NestedSubscriptionSubject({
getState: () => ({
subscribe: (callback) =>
this._threadBinding.getState().unstable_on(event, callback),
}),
subscribe: (callback) => this._threadBinding.outerSubscribe(callback),
});
this._eventListenerNestedSubscriptions.set(event, subject);
}
return subject.subscribe(callback);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ export class NestedSubscriptionSubject<TState extends Subscribable | undefined>
return this.binding.getState();
}

public outerSubscribe(callback: () => void) {
return this.binding.subscribe(callback);
}

protected _connect(): Unsubscribe {
const callback = () => {
this.notifySubscribers();
Expand All @@ -36,7 +40,7 @@ export class NestedSubscriptionSubject<TState extends Subscribable | undefined>
callback();
};

const outerUnsubscribe = this.binding.subscribe(onRuntimeUpdate);
const outerUnsubscribe = this.outerSubscribe(onRuntimeUpdate);
return () => {
outerUnsubscribe?.();
innerUnsubscribe?.();
Expand Down
13 changes: 2 additions & 11 deletions packages/react/src/hooks/useAppendMessage.tsx
Original file line number Diff line number Diff line change
@@ -1,27 +1,18 @@
import { useCallback } from "react";
import { useThreadViewportStore } from "../context";
import {
useThreadComposerStore,
useThreadRuntime,
} from "../context/react/ThreadContext";
import { useThreadRuntime } from "../context/react/ThreadContext";
import { CreateAppendMessage } from "../api/ThreadRuntime";

/**
* @deprecated Use `useThreadRuntime().append()` instead. This will be removed in 0.6.
*/
export const useAppendMessage = () => {
const threadRuntime = useThreadRuntime();
const threadViewportStore = useThreadViewportStore();
const threadComposerStore = useThreadComposerStore();

const append = useCallback(
(message: CreateAppendMessage) => {
threadRuntime.append(message);

threadViewportStore.getState().scrollToBottom();
threadComposerStore.getState().focus();
},
[threadRuntime, threadViewportStore, threadComposerStore],
[threadRuntime],
);

return append;
Expand Down
5 changes: 1 addition & 4 deletions packages/react/src/hooks/useSwitchToNewThread.tsx
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import { useCallback } from "react";
import { useThreadComposerStore } from "../context/react/ThreadContext";
import { useAssistantRuntime } from "../context";

/**
* @deprecated Use `useRuntimeActions().switchToNewThread()` instead. This will be removed in 0.6.0.
*/
export const useSwitchToNewThread = () => {
const assistantRuntime = useAssistantRuntime();
const threadComposerStore = useThreadComposerStore();
const switchToNewThread = useCallback(() => {
assistantRuntime.switchToNewThread();
threadComposerStore.getState().focus();
}, [assistantRuntime, threadComposerStore]);
}, [assistantRuntime]);

return switchToNewThread;
};
Original file line number Diff line number Diff line change
@@ -1,32 +1,20 @@
import { useCallback } from "react";
import {
useMessageRuntime,
useMessageStore,
} from "../../context/react/MessageContext";
import {
useThreadComposerStore,
useThreadStore,
useThreadViewportStore,
} from "../../context/react/ThreadContext";
import { useMessageRuntime } from "../../context/react/MessageContext";
import { useThreadRuntime } from "../../context/react/ThreadContext";
import { useCombinedStore } from "../../utils/combined/useCombinedStore";

export const useActionBarReload = () => {
const messageStore = useMessageStore();
const threadStore = useThreadStore();
const messageRuntime = useMessageRuntime();
const threadComposerStore = useThreadComposerStore();
const threadViewportStore = useThreadViewportStore();
const threadRuntime = useThreadRuntime();

const disabled = useCombinedStore(
[threadStore, messageStore],
[threadRuntime, messageRuntime],
(t, m) => t.isRunning || t.isDisabled || m.role !== "assistant",
);

const callback = useCallback(() => {
messageRuntime.reload();
threadViewportStore.getState().scrollToBottom();
threadComposerStore.getState().focus();
}, [messageRuntime, threadComposerStore, threadViewportStore]);
}, [messageRuntime]);

if (disabled) return null;
return callback;
Expand Down
26 changes: 8 additions & 18 deletions packages/react/src/primitive-hooks/composer/useComposerSend.tsx
Original file line number Diff line number Diff line change
@@ -1,32 +1,22 @@
import { useCallback } from "react";
import { useComposerStore } from "../../context";
import { useCombinedStore } from "../../utils/combined/useCombinedStore";
import {
useThreadComposerStore,
useThreadStore,
useThreadViewportStore,
} from "../../context/react/ThreadContext";
import { useThreadRuntime } from "../../context/react/ThreadContext";
import { useComposerRuntime } from "../../context";

export const useComposerSend = () => {
const threadStore = useThreadStore();
const threadViewportStore = useThreadViewportStore();
const composerStore = useComposerStore();
const threadComposerStore = useThreadComposerStore();
const composerRuntime = useComposerRuntime();
const threadRuntime = useThreadRuntime();

const disabled = useCombinedStore(
[threadStore, composerStore],
[threadRuntime, composerRuntime],
(t, c) => t.isRunning || !c.isEditing || c.isEmpty,
);

const callback = useCallback(() => {
const composerState = composerStore.getState();
if (!composerState.isEditing) return;
if (!composerRuntime.getState().isEditing) return;

composerState.send();

threadViewportStore.getState().scrollToBottom();
threadComposerStore.getState().focus();
}, [threadComposerStore, composerStore, threadViewportStore]);
composerRuntime.send();
}, [threadRuntime]);

if (disabled) return null;
return callback;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
import { useCallback } from "react";
import { useThreadViewport } from "../../context";
import {
useThreadComposerStore,
useThreadViewportStore,
} from "../../context/react/ThreadContext";
import { useThreadViewportStore } from "../../context/react/ThreadContext";

export const useThreadScrollToBottom = () => {
const isAtBottom = useThreadViewport((s) => s.isAtBottom);

const threadViewportStore = useThreadViewportStore();
const threadComposerStore = useThreadComposerStore();

const handleScrollToBottom = useCallback(() => {
threadViewportStore.getState().scrollToBottom();
threadComposerStore.getState().focus();
}, [threadViewportStore, threadComposerStore]);
}, [threadViewportStore]);

if (isAtBottom) return null;
return handleScrollToBottom;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
"use client";
import { useComposedRefs } from "@radix-ui/react-compose-refs";
import { useRef } from "react";
import { useThreadViewportStore } from "../../context/react/ThreadContext";
import { useEffect, useRef } from "react";
import {
useThreadRuntime,
useThreadViewportStore,
} from "../../context/react/ThreadContext";
import { useOnResizeContent } from "../../utils/hooks/useOnResizeContent";
import { useOnScrollToBottom } from "../../utils/hooks/useOnScrollToBottom";
import { useManagedRef } from "../../utils/hooks/useManagedRef";
import { writableStore } from "../../context/ReadonlyStore";

export type UseThreadViewportAutoScrollProps = {
autoScroll?: boolean | undefined;
unstable_scrollToBottomOnRunStart?: boolean | undefined;
};

export const useThreadViewportAutoScroll = <TElement extends HTMLElement>({
autoScroll = true,
unstable_scrollToBottomOnRunStart = true,
}: UseThreadViewportAutoScrollProps) => {
const divRef = useRef<TElement>(null);

Expand Down Expand Up @@ -81,5 +86,13 @@ export const useThreadViewportAutoScroll = <TElement extends HTMLElement>({
scrollToBottom("auto");
});

// autoscroll on run start
const threadRuntime = useThreadRuntime();
useEffect(() => {
if (!unstable_scrollToBottomOnRunStart) return undefined;

return threadRuntime.unstable_on("run-start", focus);
}, [unstable_scrollToBottomOnRunStart]);

return autoScrollRef;
};
Loading

0 comments on commit ad52c51

Please sign in to comment.