Skip to content

Commit

Permalink
feat: MessageState.submittedFeedback state (#972)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonom authored Oct 11, 2024
1 parent 899b963 commit 8c80f2a
Show file tree
Hide file tree
Showing 13 changed files with 70 additions and 47 deletions.
5 changes: 5 additions & 0 deletions .changeset/silver-suns-destroy.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@assistant-ui/react": patch
---

feat: MessageState.submittedFeedback state
6 changes: 5 additions & 1 deletion packages/react-playground/src/lib/playground-runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ export class PlaygroundThreadRuntimeCore implements INTERNAL.ThreadRuntimeCore {
public readonly capabilities = CAPABILITIES;
public readonly extras = undefined;
public readonly suggestions: readonly ThreadSuggestion[] = [];
public readonly speech = null;
public readonly speech = undefined;
public readonly adapters = undefined;

private configProvider = new ProxyConfigProvider();
Expand Down Expand Up @@ -282,6 +282,10 @@ export class PlaygroundThreadRuntimeCore implements INTERNAL.ThreadRuntimeCore {
throw new Error("PlaygroundRuntime does not support speaking.");
}

public getSubmittedFeedback() {
return undefined;
}

public submitFeedback(): never {
throw new Error("PlaygroundRuntime does not support feedback.");
}
Expand Down
8 changes: 6 additions & 2 deletions packages/react/src/api/MessageRuntime.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import { SpeechState } from "../runtimes/core/ThreadRuntimeCore";
import {
SpeechState,
SubmittedFeedback,
} from "../runtimes/core/ThreadRuntimeCore";
import {
ThreadMessage,
ThreadAssistantContentPart,
Expand Down Expand Up @@ -104,7 +107,8 @@ export type MessageState = ThreadMessage & {
branchNumber: number;
branchCount: number;

speech: SpeechState | null;
speech: SpeechState | undefined;
submittedFeedback: SubmittedFeedback | undefined;
};

export type MessageStateBinding = SubscribableWithState<MessageState>;
Expand Down
27 changes: 21 additions & 6 deletions packages/react/src/api/ThreadRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {
SubmitFeedbackOptions,
ThreadRuntimeCore,
SpeechState,
SubmittedFeedback,
} from "../runtimes/core/ThreadRuntimeCore";
import { ExportedMessageRepository } from "../runtimes/utils/MessageRepository";
import {
Expand Down Expand Up @@ -73,7 +74,7 @@ export type ThreadState = Readonly<{
messages: readonly ThreadMessage[];
suggestions: readonly ThreadSuggestion[];
extras: unknown;
speech: SpeechState | null;
speech: SpeechState | undefined;
}>;

export const getThreadState = (runtime: ThreadRuntimeCore): ThreadState => {
Expand All @@ -92,6 +93,7 @@ export const getThreadState = (runtime: ThreadRuntimeCore): ThreadState => {
speech: runtime.speech,
});
};

export type ThreadRuntime = {
composer: ThreadComposerRuntime;
getState(): ThreadState;
Expand Down Expand Up @@ -146,7 +148,7 @@ export type ThreadRuntime = {
/**
* @deprecated Use `getState().speechState` instead. This will be removed in 0.6.0.
*/
speech: SpeechState | null;
speech: SpeechState | undefined;

/**
* @deprecated Use `getState().extras` instead. This will be removed in 0.6.0.
Expand All @@ -173,6 +175,11 @@ export type ThreadRuntime = {
*/
speak: (messageId: string) => void;

/**
* @deprecated Use `getMesssageById(id).getState().submittedFeedback` instead. This will be removed in 0.6.0.
*/
getSubmittedFeedback: (messageId: string) => SubmittedFeedback | undefined;

/**
* @deprecated Use `getMesssageById(id).submitFeedback({ type })` instead. This will be removed in 0.6.0.
*/
Expand Down Expand Up @@ -336,6 +343,10 @@ export class ThreadRuntimeImpl implements ThreadRuntimeCore, ThreadRuntime {
return this._threadBinding.getState().stopSpeaking();
}

public getSubmittedFeedback(messageId: string) {
return this._threadBinding.getState().getSubmittedFeedback(messageId);
}

/**
* @deprecated Use `getMesssageById(id).submitFeedback({ type })` instead. This will be removed in 0.6.0.
*/
Expand Down Expand Up @@ -375,9 +386,10 @@ export class ThreadRuntimeImpl implements ThreadRuntimeCore, ThreadRuntime {
const message = messages[idx];
if (!message) return SKIP_UPDATE;

const branches = this._threadBinding
.getState()
.getBranches(message.id);
const thread = this._threadBinding.getState();

const branches = thread.getBranches(message.id);
const submittedFeedback = thread.getSubmittedFeedback(message.id);

return {
...message,
Expand All @@ -390,7 +402,10 @@ export class ThreadRuntimeImpl implements ThreadRuntimeCore, ThreadRuntime {
branchNumber: branches.indexOf(message.id) + 1,
branchCount: branches.length,

speech: speechState?.messageId === message.id ? speechState : null,
speech:
speechState?.messageId === message.id ? speechState : undefined,

submittedFeedback,
} satisfies MessageState;
},
subscribe: (callback) => this._threadBinding.subscribe(callback),
Expand Down
9 changes: 0 additions & 9 deletions packages/react/src/context/stores/MessageUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@ export type MessageUtilsState = Readonly<{

isHovering: boolean;
setIsHovering: (value: boolean) => void;

/** @deprecated This will be moved to `useMessage().submittedFeedback`. This will be removed in 0.6.0. */
submittedFeedback: "positive" | "negative" | null;
/** @deprecated This will be moved to `useMessageRuntime().submitFeedback()` instead. This will be removed in 0.6.0. */
setSubmittedFeedback: (feedback: "positive" | "negative" | null) => void;
}>;

export const makeMessageUtilsStore = () =>
Expand All @@ -24,9 +19,5 @@ export const makeMessageUtilsStore = () =>
setIsHovering: (value) => {
set({ isHovering: value });
},
submittedFeedback: null,
setSubmittedFeedback: (feedback) => {
set({ submittedFeedback: feedback });
},
};
});
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
import { useCallback } from "react";
import { useMessageRuntime, useMessageUtilsStore } from "../../context";
import { useMessageRuntime } from "../../context";

export const useActionBarFeedbackNegative = () => {
const messageRuntime = useMessageRuntime();
const messageUtilsStore = useMessageUtilsStore();

const callback = useCallback(() => {
messageRuntime.submitFeedback({
type: "negative",
});
messageUtilsStore.getState().setSubmittedFeedback("negative");
}, [messageUtilsStore, messageRuntime]);
messageRuntime.submitFeedback({ type: "negative" });
}, [messageRuntime]);

return callback;
};
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
import { useCallback } from "react";
import { useMessageRuntime, useMessageUtilsStore } from "../../context";
import { useMessageRuntime } from "../../context";

export const useActionBarFeedbackPositive = () => {
const messageRuntime = useMessageRuntime();
const messageUtilsStore = useMessageUtilsStore();

const callback = useCallback(() => {
messageRuntime.submitFeedback({
type: "positive",
});
messageUtilsStore.getState().setSubmittedFeedback("positive");
}, [messageUtilsStore, messageRuntime]);
messageRuntime.submitFeedback({ type: "positive" });
}, [messageRuntime]);

return callback;
};
6 changes: 3 additions & 3 deletions packages/react/src/primitive-hooks/message/useMessageIf.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ export const useMessageIf = (props: UseMessageIfProps) => {
return useCombinedStore(
[messageStore, messageUtilsStore],
(
{ role, attachments, branchCount, isLast, speech },
{ isCopied, isHovering, submittedFeedback },
{ role, attachments, branchCount, isLast, speech, submittedFeedback },
{ isCopied, isHovering },
) => {
if (props.hasBranches === true && branchCount < 2) return false;

Expand Down Expand Up @@ -57,7 +57,7 @@ export const useMessageIf = (props: UseMessageIfProps) => {

if (
props.submittedFeedback !== undefined &&
submittedFeedback !== props.submittedFeedback
( submittedFeedback?.type ?? null) !== props.submittedFeedback
)
return false;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { forwardRef } from "react";
import { useActionBarFeedbackNegative } from "../../primitive-hooks/actionBar/useActionBarFeedbackNegative";
import { ActionButtonProps } from "../../utils/createActionButton";
import { composeEventHandlers } from "@radix-ui/primitive";
import { useMessageUtils } from "../../context";
import { useMessage } from "../../context";
import { Primitive } from "@radix-ui/react-primitive";

/**
Expand All @@ -22,8 +22,8 @@ export const ActionBarPrimitiveFeedbackNegative = forwardRef<
ActionBarPrimitiveFeedbackNegative.Element,
ActionBarPrimitiveFeedbackNegative.Props
>(({ onClick, disabled, ...props }, forwardedRef) => {
const isSubmitted = useMessageUtils(
(u) => u.submittedFeedback === "negative",
const isSubmitted = useMessage(
(u) => u.submittedFeedback?.type === "negative",
);
const callback = useActionBarFeedbackNegative();
return (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { forwardRef } from "react";
import { useActionBarFeedbackPositive } from "../../primitive-hooks/actionBar/useActionBarFeedbackPositive";
import { ActionButtonProps } from "../../utils/createActionButton";
import { composeEventHandlers } from "@radix-ui/primitive";
import { useMessageUtils } from "../../context";
import { useMessage } from "../../context";
import { Primitive } from "@radix-ui/react-primitive";

/**
Expand All @@ -22,8 +22,8 @@ export const ActionBarPrimitiveFeedbackPositive = forwardRef<
ActionBarPrimitiveFeedbackPositive.Element,
ActionBarPrimitiveFeedbackPositive.Props
>(({ onClick, disabled, ...props }, forwardedRef) => {
const isSubmitted = useMessageUtils(
(u) => u.submittedFeedback === "positive",
const isSubmitted = useMessage(
(u) => u.submittedFeedback?.type === "positive",
);
const callback = useActionBarFeedbackPositive();
return (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ export const useActionBarFloatStatus = ({
// floating status
if (
autohideFloat === "always" ||
(autohideFloat === "single-branch" && m.branches.length <= 1)
(autohideFloat === "single-branch" && m.branchCount <= 1)
)
return HideAndFloatStatus.Floating;

Expand Down
15 changes: 11 additions & 4 deletions packages/react/src/runtimes/core/BaseThreadRuntimeCore.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
ThreadRuntimeCore,
SpeechState,
RuntimeCapabilities,
SubmittedFeedback,
} from "../core/ThreadRuntimeCore";
import { DefaultEditComposerRuntimeCore } from "../composer/DefaultEditComposerRuntimeCore";
import { SpeechSynthesisAdapter } from "../speech";
Expand Down Expand Up @@ -93,17 +94,23 @@ export abstract class BaseThreadRuntimeCore implements ThreadRuntimeCore {
return () => this._subscriptions.delete(callback);
}

private _submittedFeedback: Record<string, SubmittedFeedback> = {};

public getSubmittedFeedback(messageId: string) {
return this._submittedFeedback[messageId];
}

public submitFeedback({ messageId, type }: SubmitFeedbackOptions) {
const adapter = this.adapters?.feedback;
if (!adapter) throw new Error("Feedback adapter not configured");

const { message } = this.repository.getMessage(messageId);
adapter.submit({ message, type });
this._submittedFeedback[messageId] = { type };
}

// TODO speech runtime?
private _stopSpeaking: Unsubscribe | undefined;
public speech: SpeechState | null = null;
public speech: SpeechState | undefined;

public speak(messageId: string) {
const adapter = this.adapters?.speech;
Expand All @@ -117,7 +124,7 @@ export abstract class BaseThreadRuntimeCore implements ThreadRuntimeCore {
const unsub = utterance.subscribe(() => {
if (utterance.status.type === "ended") {
this._stopSpeaking = undefined;
this.speech = null;
this.speech = undefined;
} else {
this.speech = { messageId, status: utterance.status };
}
Expand All @@ -128,7 +135,7 @@ export abstract class BaseThreadRuntimeCore implements ThreadRuntimeCore {
this._stopSpeaking = () => {
utterance.cancel();
unsub();
this.speech = null;
this.speech = undefined;
this._stopSpeaking = undefined;
};
}
Expand Down
7 changes: 6 additions & 1 deletion packages/react/src/runtimes/core/ThreadRuntimeCore.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ export type SpeechState = Readonly<{
status: SpeechSynthesisAdapter.Status;
}>;

export type SubmittedFeedback = Readonly<{
type: "negative" | "positive";
}>;

export type ThreadRuntimeCore = Readonly<{
getBranches: (messageId: string) => readonly string[];
switchToBranch: (branchId: string) => void;
Expand All @@ -52,6 +56,7 @@ export type ThreadRuntimeCore = Readonly<{
speak: (messageId: string) => void;
stopSpeaking: () => void;

getSubmittedFeedback: (messageId: string) => SubmittedFeedback | undefined;
submitFeedback: (feedback: SubmitFeedbackOptions) => void;

getModelConfig: () => ModelConfig;
Expand All @@ -60,7 +65,7 @@ export type ThreadRuntimeCore = Readonly<{
getEditComposer: (messageId: string) => ComposerRuntimeCore | undefined;
beginEdit: (messageId: string) => void;

speech: SpeechState | null;
speech: SpeechState | undefined;

capabilities: Readonly<RuntimeCapabilities>;
threadId: string;
Expand Down

0 comments on commit 8c80f2a

Please sign in to comment.