diff --git a/.changeset/beige-sheep-rest.md b/.changeset/beige-sheep-rest.md
new file mode 100644
index 000000000..aa6e3e3cf
--- /dev/null
+++ b/.changeset/beige-sheep-rest.md
@@ -0,0 +1,5 @@
+---
+"@assistant-ui/react": patch
+---
+
+feat: Primitive Hook useThreadViewportAutoScroll
diff --git a/apps/www/pages/reference/primitives/Thread.mdx b/apps/www/pages/reference/primitives/Thread.mdx
index 97d090844..ab8a0280a 100644
--- a/apps/www/pages/reference/primitives/Thread.mdx
+++ b/apps/www/pages/reference/primitives/Thread.mdx
@@ -59,6 +59,23 @@ This primitive renders a `
` element unless `asChild` is set.
]}
/>
+#### `useThreadViewportAutoScroll`
+
+Returns a ref, which when set, will implement the viewport auto scroll behavior.
+Only useful you are creating a custom viewport component.
+
+```tsx
+import { useThreadViewportAutoScroll } from "@assistant-ui/react";
+
+const Viewport = () => {
+ const autoScrollRef = useThreadViewportAutoScroll();
+
+ return (
+
...
+ );
+}
+```
+
### Messages
Renders all messages. This primitive renders a separate component for each message.
diff --git a/packages/react/src/primitive-hooks/thread/useThreadViewportAutoScroll.tsx b/packages/react/src/primitive-hooks/thread/useThreadViewportAutoScroll.tsx
new file mode 100644
index 000000000..12273595e
--- /dev/null
+++ b/packages/react/src/primitive-hooks/thread/useThreadViewportAutoScroll.tsx
@@ -0,0 +1,88 @@
+"use client";
+import { useComposedRefs } from "@radix-ui/react-compose-refs";
+import { useRef } from "react";
+import { useThreadContext } from "../../context/react/ThreadContext";
+import { useOnResizeContent } from "../../utils/hooks/useOnResizeContent";
+import { useOnScrollToBottom } from "../../utils/hooks/useOnScrollToBottom";
+import { StoreApi } from "zustand";
+import { ThreadViewportState } from "../../context";
+import { useManagedRef } from "../../utils/hooks/useManagedRef";
+
+export type UseThreadViewportAutoScrollProps = {
+ autoScroll?: boolean | undefined;
+};
+
+export const useThreadViewportAutoScroll =
({
+ autoScroll = true,
+}: UseThreadViewportAutoScrollProps) => {
+ const divRef = useRef(null);
+
+ const { useViewport } = useThreadContext();
+
+ const firstRenderRef = useRef(true);
+ const lastScrollTop = useRef(0);
+
+ // bug: when ScrollToBottom's button changes its disabled state, the scroll stops
+ // fix: delay the state change until the scroll is done
+ const isScrollingToBottomRef = useRef(false);
+
+ const scrollToBottom = () => {
+ const div = divRef.current;
+ if (!div || !autoScroll) return;
+
+ const behavior = firstRenderRef.current ? "instant" : "auto";
+ firstRenderRef.current = false;
+
+ isScrollingToBottomRef.current = true;
+ div.scrollTo({ top: div.scrollHeight, behavior });
+ };
+
+ const handleScroll = () => {
+ const div = divRef.current;
+ if (!div) return;
+
+ const isAtBottom = useViewport.getState().isAtBottom;
+ const newIsAtBottom = div.scrollHeight - div.scrollTop <= div.clientHeight;
+
+ if (!newIsAtBottom && lastScrollTop.current < div.scrollTop) {
+ // ignore scroll down
+ } else {
+ isScrollingToBottomRef.current = newIsAtBottom;
+
+ if (newIsAtBottom !== isAtBottom) {
+ (useViewport as unknown as StoreApi).setState({
+ isAtBottom: newIsAtBottom,
+ });
+ }
+ }
+
+ lastScrollTop.current = div.scrollTop;
+ };
+
+ const resizeRef = useOnResizeContent(() => {
+ if (
+ !isScrollingToBottomRef.current &&
+ !useViewport.getState().isAtBottom &&
+ !firstRenderRef.current
+ ) {
+ handleScroll();
+ } else {
+ scrollToBottom();
+ }
+ });
+
+ const scrollRef = useManagedRef((el) => {
+ el.addEventListener("scroll", handleScroll);
+ return () => {
+ el.removeEventListener("scroll", handleScroll);
+ };
+ });
+
+ const autoScrollRef = useComposedRefs(resizeRef, scrollRef, divRef);
+
+ useOnScrollToBottom(() => {
+ scrollToBottom();
+ });
+
+ return autoScrollRef;
+};
diff --git a/packages/react/src/primitives/thread/ThreadViewport.tsx b/packages/react/src/primitives/thread/ThreadViewport.tsx
index 375869dc0..ac6207a51 100644
--- a/packages/react/src/primitives/thread/ThreadViewport.tsx
+++ b/packages/react/src/primitives/thread/ThreadViewport.tsx
@@ -1,90 +1,30 @@
"use client";
-import { composeEventHandlers } from "@radix-ui/primitive";
import { useComposedRefs } from "@radix-ui/react-compose-refs";
import { Primitive } from "@radix-ui/react-primitive";
+import { type ElementRef, forwardRef, ComponentPropsWithoutRef } from "react";
import {
- type ElementRef,
- forwardRef,
- useRef,
- ComponentPropsWithoutRef,
-} from "react";
-import { useThreadContext } from "../../context/react/ThreadContext";
-import { useOnResizeContent } from "../../utils/hooks/useOnResizeContent";
-import { useOnScrollToBottom } from "../../utils/hooks/useOnScrollToBottom";
-import { StoreApi } from "zustand";
-import { ThreadViewportState } from "../../context";
+ UseThreadViewportAutoScrollProps,
+ useThreadViewportAutoScroll,
+} from "../../primitive-hooks/thread/useThreadViewportAutoScroll";
type ThreadViewportElement = ElementRef;
type PrimitiveDivProps = ComponentPropsWithoutRef;
-type ThreadViewportProps = PrimitiveDivProps & {
- autoScroll?: boolean;
-};
+type ThreadViewportProps = PrimitiveDivProps & UseThreadViewportAutoScrollProps;
export const ThreadViewport = forwardRef<
ThreadViewportElement,
ThreadViewportProps
->(({ autoScroll = true, onScroll, children, ...rest }, forwardedRef) => {
- const divRef = useRef(null);
- const ref = useComposedRefs(forwardedRef, divRef);
-
- const { useViewport } = useThreadContext();
-
- // TODO find a more elegant implementation for this
-
- const firstRenderRef = useRef(true);
- const isScrollingToBottomRef = useRef(false);
- const lastScrollTop = useRef(0);
-
- const scrollToBottom = () => {
- const div = divRef.current;
- if (!div || !autoScroll) return;
-
- const behavior = firstRenderRef.current ? "instant" : "auto";
- firstRenderRef.current = false;
-
- isScrollingToBottomRef.current = true;
- div.scrollTo({ top: div.scrollHeight, behavior });
- };
-
- useOnResizeContent(divRef, () => {
- if (!isScrollingToBottomRef.current && !useViewport.getState().isAtBottom) {
- handleScroll();
- } else {
- scrollToBottom();
- }
- });
-
- useOnScrollToBottom(() => {
- scrollToBottom();
+>(({ autoScroll, onScroll, children, ...rest }, forwardedRef) => {
+ const autoScrollRef = useThreadViewportAutoScroll({
+ autoScroll,
});
- const handleScroll = () => {
- const div = divRef.current;
- if (!div) return;
-
- const isAtBottom = useViewport.getState().isAtBottom;
- const newIsAtBottom = div.scrollHeight - div.scrollTop <= div.clientHeight;
-
- if (!newIsAtBottom && lastScrollTop.current < div.scrollTop) {
- // ignore scroll down
- } else if (newIsAtBottom !== isAtBottom) {
- isScrollingToBottomRef.current = false;
- (useViewport as unknown as StoreApi).setState({
- isAtBottom: newIsAtBottom,
- });
- }
-
- lastScrollTop.current = div.scrollTop;
- };
+ const ref = useComposedRefs(forwardedRef, autoScrollRef);
return (
-
+
{children}
);
diff --git a/packages/react/src/utils/hooks/useManagedRef.ts b/packages/react/src/utils/hooks/useManagedRef.ts
new file mode 100644
index 000000000..7713dbab0
--- /dev/null
+++ b/packages/react/src/utils/hooks/useManagedRef.ts
@@ -0,0 +1,24 @@
+import { useCallback, useRef } from "react";
+
+export const useManagedRef = (
+ callback: (node: TNode) => (() => void) | void,
+) => {
+ const cleanupRef = useRef<(() => void) | void>();
+
+ const ref = useCallback(
+ (el: TNode | null) => {
+ // Call the previous cleanup function
+ if (cleanupRef.current) {
+ cleanupRef.current();
+ }
+
+ // Call the new callback and store its cleanup function
+ if (el) {
+ cleanupRef.current = callback(el);
+ }
+ },
+ [callback],
+ );
+
+ return ref;
+};
diff --git a/packages/react/src/utils/hooks/useOnResizeContent.tsx b/packages/react/src/utils/hooks/useOnResizeContent.tsx
index 5b24d6d06..f4bd8de22 100644
--- a/packages/react/src/utils/hooks/useOnResizeContent.tsx
+++ b/packages/react/src/utils/hooks/useOnResizeContent.tsx
@@ -1,48 +1,49 @@
import { useCallbackRef } from "@radix-ui/react-use-callback-ref";
-import { type MutableRefObject, useEffect } from "react";
+import { useCallback } from "react";
+import { useManagedRef } from "./useManagedRef";
-export const useOnResizeContent = (
- ref: MutableRefObject,
- callback: () => void,
-) => {
+export const useOnResizeContent = (callback: () => void) => {
const callbackRef = useCallbackRef(callback);
- useEffect(() => {
- const el = ref.current;
- if (!el) return;
-
- const resizeObserver = new ResizeObserver(() => {
- callbackRef();
- });
-
- const mutationObserver = new MutationObserver((mutations) => {
- for (const mutation of mutations) {
- for (const node of mutation.addedNodes) {
- if (node instanceof Element) {
- resizeObserver.observe(node);
+
+ const refCallback = useCallback(
+ (el: HTMLElement) => {
+ const resizeObserver = new ResizeObserver(() => {
+ callbackRef();
+ });
+
+ const mutationObserver = new MutationObserver((mutations) => {
+ for (const mutation of mutations) {
+ for (const node of mutation.addedNodes) {
+ if (node instanceof Element) {
+ resizeObserver.observe(node);
+ }
}
- }
- for (const node of mutation.removedNodes) {
- if (node instanceof Element) {
- resizeObserver.unobserve(node);
+ for (const node of mutation.removedNodes) {
+ if (node instanceof Element) {
+ resizeObserver.unobserve(node);
+ }
}
}
- }
- callbackRef();
- });
+ callbackRef();
+ });
- resizeObserver.observe(el);
- mutationObserver.observe(el, { childList: true });
+ resizeObserver.observe(el);
+ mutationObserver.observe(el, { childList: true });
+
+ // Observe existing children
+ for (const child of el.children) {
+ resizeObserver.observe(child);
+ }
- // Observe existing children
- for (const child of el.children) {
- resizeObserver.observe(child);
- }
+ return () => {
+ resizeObserver.disconnect();
+ mutationObserver.disconnect();
+ };
+ },
+ [callbackRef],
+ );
- return () => {
- resizeObserver.disconnect();
- mutationObserver.disconnect();
- };
- }, [ref, callbackRef]);
+ return useManagedRef(refCallback);
};