Skip to content

Commit

Permalink
feat: useThreadModelConfig API (#1017)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonom authored Oct 15, 2024
1 parent 8eaabf3 commit 0edadd1
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 17 deletions.
5 changes: 5 additions & 0 deletions .changeset/nervous-students-unite.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@assistant-ui/react": patch
---

feat: useThreadModelConfig API
1 change: 0 additions & 1 deletion packages/react/src/api/AssistantRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ export class AssistantRuntimeImpl
return this._core.registerModelConfigProvider(provider);
}

// TODO events for thread switching
/**
* @deprecated Thread is now static and never gets updated. This will be removed in 0.6.0.
*/
Expand Down
8 changes: 3 additions & 5 deletions packages/react/src/api/ThreadRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
ThreadRuntimeCore,
SpeechState,
SubmittedFeedback,
ThreadRuntimeEventType,
} from "../runtimes/core/ThreadRuntimeCore";
import { ExportedMessageRepository } from "../runtimes/utils/MessageRepository";
import {
Expand Down Expand Up @@ -133,10 +134,7 @@ export type ThreadRuntime = {
*/
stopSpeaking: () => void;

unstable_on(
event: "switched-to" | "run-start",
callback: () => void,
): Unsubscribe;
unstable_on(event: ThreadRuntimeEventType, callback: () => void): Unsubscribe;

// Legacy methods with deprecations

Expand Down Expand Up @@ -501,7 +499,7 @@ export class ThreadRuntimeImpl
>();

public unstable_on(
event: "switched-to" | "run-start",
event: ThreadRuntimeEventType,
callback: () => void,
): Unsubscribe {
let subject = this._eventListenerNestedSubscriptions.get(event);
Expand Down
24 changes: 22 additions & 2 deletions packages/react/src/context/react/ThreadContext.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"use client";

import { createContext } from "react";
import { createContext, useEffect, useState } from "react";
import type { ThreadViewportState } from "../stores/ThreadViewport";
import { ReadonlyStore } from "../ReadonlyStore";
import { UseBoundStore } from "zustand";
import { createContextHook } from "./utils/createContextHook";
import { createContextStoreHook } from "./utils/createContextStoreHook";
import { ThreadRuntime } from "../../api";
import { ThreadState } from "../../api/ThreadRuntime";
import { ThreadMessage } from "../../types";
import { ModelConfig, ThreadMessage } from "../../types";
import { ThreadComposerState } from "../../api/ComposerRuntime";

export type ThreadContextValue = {
Expand Down Expand Up @@ -88,3 +88,23 @@ export const {
useViewport: useThreadViewport,
useViewportStore: useThreadViewportStore,
} = createContextStoreHook(useThreadContext, "useViewport");

export function useThreadModelConfig(options?: {
optional?: false | undefined;
}): ModelConfig;
export function useThreadModelConfig(options?: {
optional?: boolean | undefined;
}): ModelConfig | null;
export function useThreadModelConfig(options?: {
optional?: boolean | undefined;
}): ModelConfig | null {
const [, rerender] = useState({});

const runtime = useThreadRuntime(options);
useEffect(() => {
return runtime?.unstable_on("model-config-update", () => rerender({}));
}, [runtime]);

if (!runtime) return null;
return runtime?.getModelConfig();
}
1 change: 1 addition & 0 deletions packages/react/src/context/react/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ export {
useThreadRuntime,
useThread,
useThreadComposer,
useThreadModelConfig,

/**
* @deprecated Use `useThread().messages` instead. This will be removed in 0.6.0.
Expand Down
11 changes: 8 additions & 3 deletions packages/react/src/runtimes/core/BaseThreadRuntimeCore.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import {
SpeechState,
RuntimeCapabilities,
SubmittedFeedback,
ThreadRuntimeEventType,
} from "../core/ThreadRuntimeCore";
import { DefaultEditComposerRuntimeCore } from "../composer/DefaultEditComposerRuntimeCore";
import { SpeechSynthesisAdapter } from "../speech";
Expand Down Expand Up @@ -52,7 +53,11 @@ export abstract class BaseThreadRuntimeCore implements ThreadRuntimeCore {

public readonly composer = new DefaultThreadComposerRuntimeCore(this);

constructor(private configProvider: ModelConfigProvider) {}
constructor(private configProvider: ModelConfigProvider) {
this.configProvider.subscribe?.(() => {
this._notifyEventSubscribers("model-config-update");
});
}

public getModelConfig() {
return this.configProvider.getModelConfig();
Expand Down Expand Up @@ -94,7 +99,7 @@ export abstract class BaseThreadRuntimeCore implements ThreadRuntimeCore {
for (const callback of this._subscriptions) callback();
}

public _notifyEventSubscribers(event: "switched-to" | "run-start") {
public _notifyEventSubscribers(event: ThreadRuntimeEventType) {
const subscribers = this._eventSubscribers.get(event);
if (!subscribers) return;

Expand Down Expand Up @@ -173,7 +178,7 @@ export abstract class BaseThreadRuntimeCore implements ThreadRuntimeCore {

private _eventSubscribers = new Map<string, Set<() => void>>();

public unstable_on(event: "switched-to" | "run-start", callback: () => void) {
public unstable_on(event: ThreadRuntimeEventType, callback: () => void) {
const subscribers = this._eventSubscribers.get(event);
if (!subscribers) {
this._eventSubscribers.set(event, new Set([callback]));
Expand Down
10 changes: 6 additions & 4 deletions packages/react/src/runtimes/core/ThreadRuntimeCore.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ export type SubmittedFeedback = Readonly<{
type: "negative" | "positive";
}>;

export type ThreadRuntimeEventType =
| "switched-to"
| "run-start"
| "model-config-update";

export type ThreadRuntimeCore = Readonly<{
getMessageById: (messageId: string) =>
| {
Expand Down Expand Up @@ -86,8 +91,5 @@ export type ThreadRuntimeCore = Readonly<{
import(repository: ExportedMessageRepository): void;
export(): ExportedMessageRepository;

unstable_on(
event: "switched-to" | "run-start",
callback: () => void,
): Unsubscribe;
unstable_on(event: ThreadRuntimeEventType, callback: () => void): Unsubscribe;
}>;
6 changes: 5 additions & 1 deletion packages/react/src/types/ModelConfigTypes.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { z } from "zod";
import type { JSONSchema7 } from "json-schema";
import { Unsubscribe } from "./Unsubscribe";

export const LanguageModelV1CallSettingsSchema = z.object({
maxTokens: z.number().int().positive().optional(),
Expand Down Expand Up @@ -47,7 +48,10 @@ export type ModelConfig = {
config?: LanguageModelConfig | undefined;
};

export type ModelConfigProvider = { getModelConfig: () => ModelConfig };
export type ModelConfigProvider = {
getModelConfig: () => ModelConfig;
subscribe?: (callback: () => void) => Unsubscribe;
};

export const mergeModelConfigs = (
configSet: Set<ModelConfigProvider>,
Expand Down
18 changes: 17 additions & 1 deletion packages/react/src/utils/ProxyConfigProvider.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
"use client";
import {
type ModelConfigProvider,
mergeModelConfigs,
Expand All @@ -13,8 +12,25 @@ export class ProxyConfigProvider implements ModelConfigProvider {

registerModelConfigProvider(provider: ModelConfigProvider) {
this._providers.add(provider);
const unsubscribe = provider.subscribe?.(() => {
this.notifySubscribers();
});
this.notifySubscribers();
return () => {
this._providers.delete(provider);
unsubscribe?.();
this.notifySubscribers();
};
}

private _subscribers = new Set<() => void>();

notifySubscribers() {
for (const callback of this._subscribers) callback();
}

subscribe(callback: () => void) {
this._subscribers.add(callback);
return () => this._subscribers.delete(callback);
}
}

0 comments on commit 0edadd1

Please sign in to comment.