Skip to content

Commit

Permalink
feat: ThreadRuntime.Composer subscribe (#871)
Browse files Browse the repository at this point in the history
* feat: ThreadRuntime.Composer subscribe

* also patch playground
  • Loading branch information
Yonom authored Sep 22, 2024
1 parent ab6b2d8 commit 1a99132
Show file tree
Hide file tree
Showing 9 changed files with 71 additions and 27 deletions.
6 changes: 6 additions & 0 deletions .changeset/tall-scissors-destroy.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@assistant-ui/react": patch
"@assistant-ui/react-playground": patch
---

feat: ThreadRuntime.Composer subscribe
5 changes: 1 addition & 4 deletions packages/react-playground/src/lib/playground-runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,7 @@ export class PlaygroundThreadRuntime implements ReactThreadRuntime {

private configProvider = new ProxyConfigProvider();

public readonly composer = new ThreadRuntimeComposer(
this,
this.notifySubscribers.bind(this),
);
public readonly composer = new ThreadRuntimeComposer(this);

constructor(
configProvider: ModelConfigProvider,
Expand Down
37 changes: 26 additions & 11 deletions packages/react/src/context/providers/ThreadProvider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import { ThreadRuntimeWithSubscribe } from "../../runtimes/core/AssistantRuntime
import { makeThreadRuntimeStore } from "../stores/ThreadRuntime";
import { subscribeToMainThread } from "../../runtimes";
import { writableStore } from "../ReadonlyStore";
import { subscribeToMainThreadComposer } from "../../runtimes/core/subscribeToMainThread";

type ThreadProviderProps = {
provider: ThreadRuntimeWithSubscribe;
Expand Down Expand Up @@ -64,18 +65,8 @@ export const ThreadProvider: FC<PropsWithChildren<ThreadProviderProps>> = ({
}

const composerState = context.useComposer.getState();
if (
thread.composer.isEmpty !== composerState.isEmpty ||
thread.composer.text !== composerState.text ||
thread.composer.attachmentAccept !== composerState.attachmentAccept ||
thread.composer.attachments !== composerState.attachments ||
state.capabilities.cancel !== composerState.canCancel
) {
if (state.capabilities.cancel !== composerState.canCancel) {
writableStore(context.useComposer).setState({
isEmpty: thread.composer.isEmpty,
text: thread.composer.text,
attachmentAccept: thread.composer.attachmentAccept,
attachments: thread.composer.attachments,
canCancel: state.capabilities.cancel,
});
}
Expand All @@ -85,6 +76,30 @@ export const ThreadProvider: FC<PropsWithChildren<ThreadProviderProps>> = ({
return subscribeToMainThread(provider, onThreadUpdate);
}, [provider, context]);

useEffect(() => {
const onComposerUpdate = () => {
const composer = provider.thread.composer;

const composerState = context.useComposer.getState();
if (
composer.isEmpty !== composerState.isEmpty ||
composer.text !== composerState.text ||
composer.attachmentAccept !== composerState.attachmentAccept ||
composer.attachments !== composerState.attachments
) {
writableStore(context.useComposer).setState({
isEmpty: composer.isEmpty,
text: composer.text,
attachmentAccept: composer.attachmentAccept,
attachments: composer.attachments,
});
}
};

onComposerUpdate();
return subscribeToMainThreadComposer(provider, onComposerUpdate);
}, [provider, context]);

useInsertionEffect(
() =>
provider.subscribe(() => {
Expand Down
6 changes: 4 additions & 2 deletions packages/react/src/runtimes/core/AssistantRuntime.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ export type ThreadRuntimeWithSubscribe = {

export type AssistantRuntime = ThreadRuntimeWithSubscribe & {
switchToNewThread: () => void;

switchToThread(threadId: string): void;
/**
* @deprecated Use `switchToNewThread` instead. This will be removed in 0.6.0.
*/
switchToThread(threadId: null): void;
switchToThread(threadId: string): void;
switchToThread(threadId: string | null): void;

registerModelConfigProvider: (provider: ModelConfigProvider) => Unsubscribe;
};
2 changes: 2 additions & 0 deletions packages/react/src/runtimes/core/ThreadRuntime.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,7 @@ export declare namespace ThreadRuntime {
reset: () => void;

send: () => void;

subscribe: (callback: () => void) => Unsubscribe;
}>;
}
19 changes: 19 additions & 0 deletions packages/react/src/runtimes/core/subscribeToMainThread.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,22 @@ export const subscribeToMainThread = (
cleanup?.();
};
};

export const subscribeToMainThreadComposer = (
runtime: ThreadRuntimeWithSubscribe,
callback: () => void,
) => {
let cleanup = runtime.thread.composer.subscribe(callback);
const inner = () => {
cleanup?.();
cleanup = runtime.thread.composer.subscribe(callback);

callback();
};

const unsubscribe = runtime.subscribe(inner);
return () => {
unsubscribe();
cleanup?.();
};
};
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,7 @@ export class ExternalStoreThreadRuntime implements ReactThreadRuntime {

private _store!: ExternalStoreAdapter<any>;

public readonly composer = new ThreadRuntimeComposer(
this,
this.notifySubscribers.bind(this),
);
public readonly composer = new ThreadRuntimeComposer(this);

constructor(
private configProvider: ModelConfigProvider,
Expand Down
5 changes: 1 addition & 4 deletions packages/react/src/runtimes/local/LocalThreadRuntime.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,7 @@ export class LocalThreadRuntime implements ThreadRuntime {
return this.repository.getMessages();
}

public readonly composer = new ThreadRuntimeComposer(
this,
this.notifySubscribers.bind(this),
);
public readonly composer = new ThreadRuntimeComposer(this);

constructor(
private configProvider: ModelConfigProvider,
Expand Down
13 changes: 11 additions & 2 deletions packages/react/src/runtimes/utils/ThreadRuntimeComposer.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { ThreadComposerAttachment } from "../../context/stores/Attachment";
import { AppendMessage } from "../../types";
import { AppendMessage, Unsubscribe } from "../../types";
import { AttachmentAdapter } from "../attachment/AttachmentAdapter";
import { ThreadRuntime } from "../core";

Expand All @@ -17,7 +17,6 @@ export class ThreadRuntimeComposer implements ThreadRuntime.Composer {
messages: ThreadRuntime["messages"];
append: (message: AppendMessage) => void;
},
private notifySubscribers: () => void,
) {}

public setAttachmentAdapter(adapter: AttachmentAdapter | undefined) {
Expand Down Expand Up @@ -94,4 +93,14 @@ export class ThreadRuntimeComposer implements ThreadRuntime.Composer {
});
this.reset();
}

private _subscriptions = new Set<() => void>();
private notifySubscribers() {
for (const callback of this._subscriptions) callback();
}

public subscribe(callback: () => void): Unsubscribe {
this._subscriptions.add(callback);
return () => this._subscriptions.delete(callback);
}
}

0 comments on commit 1a99132

Please sign in to comment.