diff --git a/.changeset/beige-sheep-rest.md b/.changeset/beige-sheep-rest.md new file mode 100644 index 000000000..aa6e3e3cf --- /dev/null +++ b/.changeset/beige-sheep-rest.md @@ -0,0 +1,5 @@ +--- +"@assistant-ui/react": patch +--- + +feat: Primitive Hook useThreadViewportAutoScroll diff --git a/apps/www/pages/reference/primitives/Thread.mdx b/apps/www/pages/reference/primitives/Thread.mdx index 97d090844..ab8a0280a 100644 --- a/apps/www/pages/reference/primitives/Thread.mdx +++ b/apps/www/pages/reference/primitives/Thread.mdx @@ -59,6 +59,23 @@ This primitive renders a `
` element unless `asChild` is set. ]} /> +#### `useThreadViewportAutoScroll` + +Returns a ref, which when set, will implement the viewport auto scroll behavior. +Only useful you are creating a custom viewport component. + +```tsx +import { useThreadViewportAutoScroll } from "@assistant-ui/react"; + +const Viewport = () => { + const autoScrollRef = useThreadViewportAutoScroll(); + + return ( +
...
+ ); +} +``` + ### Messages Renders all messages. This primitive renders a separate component for each message. diff --git a/packages/react/src/primitive-hooks/thread/useThreadViewportAutoScroll.tsx b/packages/react/src/primitive-hooks/thread/useThreadViewportAutoScroll.tsx new file mode 100644 index 000000000..12273595e --- /dev/null +++ b/packages/react/src/primitive-hooks/thread/useThreadViewportAutoScroll.tsx @@ -0,0 +1,88 @@ +"use client"; +import { useComposedRefs } from "@radix-ui/react-compose-refs"; +import { useRef } from "react"; +import { useThreadContext } from "../../context/react/ThreadContext"; +import { useOnResizeContent } from "../../utils/hooks/useOnResizeContent"; +import { useOnScrollToBottom } from "../../utils/hooks/useOnScrollToBottom"; +import { StoreApi } from "zustand"; +import { ThreadViewportState } from "../../context"; +import { useManagedRef } from "../../utils/hooks/useManagedRef"; + +export type UseThreadViewportAutoScrollProps = { + autoScroll?: boolean | undefined; +}; + +export const useThreadViewportAutoScroll = ({ + autoScroll = true, +}: UseThreadViewportAutoScrollProps) => { + const divRef = useRef(null); + + const { useViewport } = useThreadContext(); + + const firstRenderRef = useRef(true); + const lastScrollTop = useRef(0); + + // bug: when ScrollToBottom's button changes its disabled state, the scroll stops + // fix: delay the state change until the scroll is done + const isScrollingToBottomRef = useRef(false); + + const scrollToBottom = () => { + const div = divRef.current; + if (!div || !autoScroll) return; + + const behavior = firstRenderRef.current ? "instant" : "auto"; + firstRenderRef.current = false; + + isScrollingToBottomRef.current = true; + div.scrollTo({ top: div.scrollHeight, behavior }); + }; + + const handleScroll = () => { + const div = divRef.current; + if (!div) return; + + const isAtBottom = useViewport.getState().isAtBottom; + const newIsAtBottom = div.scrollHeight - div.scrollTop <= div.clientHeight; + + if (!newIsAtBottom && lastScrollTop.current < div.scrollTop) { + // ignore scroll down + } else { + isScrollingToBottomRef.current = newIsAtBottom; + + if (newIsAtBottom !== isAtBottom) { + (useViewport as unknown as StoreApi).setState({ + isAtBottom: newIsAtBottom, + }); + } + } + + lastScrollTop.current = div.scrollTop; + }; + + const resizeRef = useOnResizeContent(() => { + if ( + !isScrollingToBottomRef.current && + !useViewport.getState().isAtBottom && + !firstRenderRef.current + ) { + handleScroll(); + } else { + scrollToBottom(); + } + }); + + const scrollRef = useManagedRef((el) => { + el.addEventListener("scroll", handleScroll); + return () => { + el.removeEventListener("scroll", handleScroll); + }; + }); + + const autoScrollRef = useComposedRefs(resizeRef, scrollRef, divRef); + + useOnScrollToBottom(() => { + scrollToBottom(); + }); + + return autoScrollRef; +}; diff --git a/packages/react/src/primitives/thread/ThreadViewport.tsx b/packages/react/src/primitives/thread/ThreadViewport.tsx index 375869dc0..ac6207a51 100644 --- a/packages/react/src/primitives/thread/ThreadViewport.tsx +++ b/packages/react/src/primitives/thread/ThreadViewport.tsx @@ -1,90 +1,30 @@ "use client"; -import { composeEventHandlers } from "@radix-ui/primitive"; import { useComposedRefs } from "@radix-ui/react-compose-refs"; import { Primitive } from "@radix-ui/react-primitive"; +import { type ElementRef, forwardRef, ComponentPropsWithoutRef } from "react"; import { - type ElementRef, - forwardRef, - useRef, - ComponentPropsWithoutRef, -} from "react"; -import { useThreadContext } from "../../context/react/ThreadContext"; -import { useOnResizeContent } from "../../utils/hooks/useOnResizeContent"; -import { useOnScrollToBottom } from "../../utils/hooks/useOnScrollToBottom"; -import { StoreApi } from "zustand"; -import { ThreadViewportState } from "../../context"; + UseThreadViewportAutoScrollProps, + useThreadViewportAutoScroll, +} from "../../primitive-hooks/thread/useThreadViewportAutoScroll"; type ThreadViewportElement = ElementRef; type PrimitiveDivProps = ComponentPropsWithoutRef; -type ThreadViewportProps = PrimitiveDivProps & { - autoScroll?: boolean; -}; +type ThreadViewportProps = PrimitiveDivProps & UseThreadViewportAutoScrollProps; export const ThreadViewport = forwardRef< ThreadViewportElement, ThreadViewportProps ->(({ autoScroll = true, onScroll, children, ...rest }, forwardedRef) => { - const divRef = useRef(null); - const ref = useComposedRefs(forwardedRef, divRef); - - const { useViewport } = useThreadContext(); - - // TODO find a more elegant implementation for this - - const firstRenderRef = useRef(true); - const isScrollingToBottomRef = useRef(false); - const lastScrollTop = useRef(0); - - const scrollToBottom = () => { - const div = divRef.current; - if (!div || !autoScroll) return; - - const behavior = firstRenderRef.current ? "instant" : "auto"; - firstRenderRef.current = false; - - isScrollingToBottomRef.current = true; - div.scrollTo({ top: div.scrollHeight, behavior }); - }; - - useOnResizeContent(divRef, () => { - if (!isScrollingToBottomRef.current && !useViewport.getState().isAtBottom) { - handleScroll(); - } else { - scrollToBottom(); - } - }); - - useOnScrollToBottom(() => { - scrollToBottom(); +>(({ autoScroll, onScroll, children, ...rest }, forwardedRef) => { + const autoScrollRef = useThreadViewportAutoScroll({ + autoScroll, }); - const handleScroll = () => { - const div = divRef.current; - if (!div) return; - - const isAtBottom = useViewport.getState().isAtBottom; - const newIsAtBottom = div.scrollHeight - div.scrollTop <= div.clientHeight; - - if (!newIsAtBottom && lastScrollTop.current < div.scrollTop) { - // ignore scroll down - } else if (newIsAtBottom !== isAtBottom) { - isScrollingToBottomRef.current = false; - (useViewport as unknown as StoreApi).setState({ - isAtBottom: newIsAtBottom, - }); - } - - lastScrollTop.current = div.scrollTop; - }; + const ref = useComposedRefs(forwardedRef, autoScrollRef); return ( - + {children} ); diff --git a/packages/react/src/utils/hooks/useManagedRef.ts b/packages/react/src/utils/hooks/useManagedRef.ts new file mode 100644 index 000000000..7713dbab0 --- /dev/null +++ b/packages/react/src/utils/hooks/useManagedRef.ts @@ -0,0 +1,24 @@ +import { useCallback, useRef } from "react"; + +export const useManagedRef = ( + callback: (node: TNode) => (() => void) | void, +) => { + const cleanupRef = useRef<(() => void) | void>(); + + const ref = useCallback( + (el: TNode | null) => { + // Call the previous cleanup function + if (cleanupRef.current) { + cleanupRef.current(); + } + + // Call the new callback and store its cleanup function + if (el) { + cleanupRef.current = callback(el); + } + }, + [callback], + ); + + return ref; +}; diff --git a/packages/react/src/utils/hooks/useOnResizeContent.tsx b/packages/react/src/utils/hooks/useOnResizeContent.tsx index 5b24d6d06..f4bd8de22 100644 --- a/packages/react/src/utils/hooks/useOnResizeContent.tsx +++ b/packages/react/src/utils/hooks/useOnResizeContent.tsx @@ -1,48 +1,49 @@ import { useCallbackRef } from "@radix-ui/react-use-callback-ref"; -import { type MutableRefObject, useEffect } from "react"; +import { useCallback } from "react"; +import { useManagedRef } from "./useManagedRef"; -export const useOnResizeContent = ( - ref: MutableRefObject, - callback: () => void, -) => { +export const useOnResizeContent = (callback: () => void) => { const callbackRef = useCallbackRef(callback); - useEffect(() => { - const el = ref.current; - if (!el) return; - - const resizeObserver = new ResizeObserver(() => { - callbackRef(); - }); - - const mutationObserver = new MutationObserver((mutations) => { - for (const mutation of mutations) { - for (const node of mutation.addedNodes) { - if (node instanceof Element) { - resizeObserver.observe(node); + + const refCallback = useCallback( + (el: HTMLElement) => { + const resizeObserver = new ResizeObserver(() => { + callbackRef(); + }); + + const mutationObserver = new MutationObserver((mutations) => { + for (const mutation of mutations) { + for (const node of mutation.addedNodes) { + if (node instanceof Element) { + resizeObserver.observe(node); + } } - } - for (const node of mutation.removedNodes) { - if (node instanceof Element) { - resizeObserver.unobserve(node); + for (const node of mutation.removedNodes) { + if (node instanceof Element) { + resizeObserver.unobserve(node); + } } } - } - callbackRef(); - }); + callbackRef(); + }); - resizeObserver.observe(el); - mutationObserver.observe(el, { childList: true }); + resizeObserver.observe(el); + mutationObserver.observe(el, { childList: true }); + + // Observe existing children + for (const child of el.children) { + resizeObserver.observe(child); + } - // Observe existing children - for (const child of el.children) { - resizeObserver.observe(child); - } + return () => { + resizeObserver.disconnect(); + mutationObserver.disconnect(); + }; + }, + [callbackRef], + ); - return () => { - resizeObserver.disconnect(); - mutationObserver.disconnect(); - }; - }, [ref, callbackRef]); + return useManagedRef(refCallback); };