Skip to content

Commit

Permalink
feat: MessagePrimitive.Attachments (#785)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonom authored Sep 9, 2024
1 parent c845fcf commit 3b0f20b
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 0 deletions.
5 changes: 5 additions & 0 deletions .changeset/slow-pianos-fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@assistant-ui/react": patch
---

feat: MessagePrimitive.Attachments
77 changes: 77 additions & 0 deletions packages/react/src/context/providers/MessageAttachmentProvider.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"use client";

import { type FC, type PropsWithChildren, useEffect, useState } from "react";
import { create } from "zustand";
import type { MessageState } from "../stores";
import { useMessageContext } from "../react";
import { MessageAttachmentState } from "../stores/Attachment";
import {
AttachmentContext,
AttachmentContextValue,
} from "../react/AttachmentContext";
import { writableStore } from "../ReadonlyStore";

type MessageAttachmentProviderProps = PropsWithChildren<{
attachmentIndex: number;
}>;

const getAttachment = (
{ message }: MessageState,
useAttachment: AttachmentContextValue["useAttachment"] | undefined,
partIndex: number,
) => {
if (message.role !== "user") return null;

const attachments = message.attachments;
let attachment = attachments[partIndex];
if (!attachment) return null;

// if the attachment is the same, don't update
const currentState = useAttachment?.getState();
if (currentState && currentState.attachment === attachment) return null;

return Object.freeze({ attachment });
};

const useMessageAttachmentContext = (partIndex: number) => {
const { useMessage } = useMessageContext();
const [context] = useState<AttachmentContextValue & { type: "message" }>(
() => {
const useAttachment = create<MessageAttachmentState>(
() => getAttachment(useMessage.getState(), undefined, partIndex)!,
);

return { type: "message", useAttachment };
},
);

useEffect(() => {
const syncAttachment = (messageState: MessageState) => {
const newState = getAttachment(
messageState,
context.useAttachment,
partIndex,
);
if (!newState) return;
writableStore(context.useAttachment).setState(newState, true);
};

syncAttachment(useMessage.getState());
return useMessage.subscribe(syncAttachment);
}, [context, useMessage, partIndex]);

return context;
};

export const MessageAttachmentProvider: FC<MessageAttachmentProviderProps> = ({
attachmentIndex: partIndex,
children,
}) => {
const context = useMessageAttachmentContext(partIndex);

return (
<AttachmentContext.Provider value={context}>
{children}
</AttachmentContext.Provider>
);
};
88 changes: 88 additions & 0 deletions packages/react/src/primitives/message/MessageAttachments.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"use client";

import { ComponentType, type FC, memo } from "react";
import { useMessageContext } from "../../context";
import { useAttachmentContext } from "../../context/react/AttachmentContext";
import { MessageAttachmentProvider } from "../../context/providers/MessageAttachmentProvider";
import type { MessageAttachment } from "../../context/stores/Attachment";

export type MessagePrimitiveAttachmentsProps = {
components:
| {
Image?: ComponentType | undefined;
Document?: ComponentType | undefined;
File?: ComponentType | undefined;
Attachment?: ComponentType | undefined;
}
| undefined;
};

const getComponent = (
components: MessagePrimitiveAttachmentsProps["components"],
attachment: MessageAttachment,
) => {
const type = attachment.type;
switch (type) {
case "image":
return components?.Image ?? components?.Attachment;
case "document":
return components?.Document ?? components?.Attachment;
case "file":
return components?.File ?? components?.Attachment;
default:
const _exhaustiveCheck: never = type;
throw new Error(`Unknown attachment type: ${_exhaustiveCheck}`);
}
};

const AttachmentComponent: FC<{
components: MessagePrimitiveAttachmentsProps["components"];
}> = ({ components }) => {
const { useAttachment } = useAttachmentContext({ type: "message" });
const Component = useAttachment((a) =>
getComponent(components, a.attachment),
);

if (!Component) return null;
return <Component />;
};

const MessageAttachmentImpl: FC<
MessagePrimitiveAttachmentsProps & { attachmentIndex: number }
> = ({ components, attachmentIndex }) => {
return (
<MessageAttachmentProvider attachmentIndex={attachmentIndex}>
<AttachmentComponent components={components} />
</MessageAttachmentProvider>
);
};

const MessageAttachment = memo(
MessageAttachmentImpl,
(prev, next) =>
prev.attachmentIndex === next.attachmentIndex &&
prev.components?.Image === next.components?.Image &&
prev.components?.Document === next.components?.Document &&
prev.components?.File === next.components?.File &&
prev.components?.Attachment === next.components?.Attachment,
);

export const MessagePrimitiveAttachments: FC<
MessagePrimitiveAttachmentsProps
> = ({ components }) => {
const { useMessage } = useMessageContext();
const attachmentsCount = useMessage(({ message }) => {
if (message.role !== "user") return 0;
return message.attachments.length;
});

return Array.from({ length: attachmentsCount }, (_, index) => (
<MessageAttachment
key={index}
attachmentIndex={index}
components={components}
/>
));
};

MessagePrimitiveAttachments.displayName = "MessagePrimitive.Attachments";
1 change: 1 addition & 0 deletions packages/react/src/primitives/message/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ export { MessagePrimitiveRoot as Root } from "./MessageRoot";
export { MessagePrimitiveIf as If } from "./MessageIf";
export { MessagePrimitiveContent as Content } from "./MessageContent";
export { MessagePrimitiveInProgress as InProgress } from "./MessageInProgress";
export { MessagePrimitiveAttachments as Attachments } from "./MessageAttachments";

0 comments on commit 3b0f20b

Please sign in to comment.