Skip to content

Commit

Permalink
update mma and oai.realtime to use ChatContext
Browse files Browse the repository at this point in the history
  • Loading branch information
nbsp committed Nov 1, 2024
1 parent 0d07e85 commit d72c5ba
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 16 deletions.
8 changes: 6 additions & 2 deletions agents/src/multimodal/multimodal_agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,16 @@ export class MultimodalAgent extends EventEmitter {

/**
 * Builds a MultimodalAgent around a realtime model.
 *
 * @param model - realtime model driving the agent session
 * @param chatCtx - optional initial chat context forwarded to the model session
 * @param fncCtx - optional function-calling context forwarded to the model session
 */
constructor({
  model,
  chatCtx,
  fncCtx,
}: {
  model: RealtimeModel;
  chatCtx?: llm.ChatContext;
  fncCtx?: llm.FunctionContext;
}) {
  super();
  this.model = model;
  // Both contexts are only stored here; they are handed to the model when
  // the session is created (see model.session({ fncCtx, chatCtx })).
  this.#chatCtx = chatCtx;
  this.#fncCtx = fncCtx;
}

Expand All @@ -83,6 +86,7 @@ export class MultimodalAgent extends EventEmitter {
#logger = log();
#session: RealtimeSession | null = null;
#fncCtx: llm.FunctionContext | undefined = undefined;
#chatCtx: llm.ChatContext | undefined = undefined;

#_started: boolean = false;
#_pendingFunctionCalls: Set<string> = new Set();
Expand Down Expand Up @@ -200,7 +204,7 @@ export class MultimodalAgent extends EventEmitter {
}
}

this.#session = this.model.session({ fncCtx: this.#fncCtx });
this.#session = this.model.session({ fncCtx: this.#fncCtx, chatCtx: this.#chatCtx });
this.#started = true;

// eslint-disable-next-line @typescript-eslint/no-explicit-any
Expand Down
167 changes: 153 additions & 14 deletions plugins/openai/src/realtime/realtime_model.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
// SPDX-FileCopyrightText: 2024 LiveKit, Inc.
//
// SPDX-License-Identifier: Apache-2.0
import { AsyncIterableQueue, Future, Queue, llm, log, multimodal } from '@livekit/agents';
import {
AsyncIterableQueue,
Future,
Queue,
llm,
log,
mergeFrames,
multimodal,
} from '@livekit/agents';
import { ChatRole } from '@livekit/agents/src/llm/chat_context.js';
import { AudioFrame } from '@livekit/rtc-node';
import { once } from 'node:events';
import { WebSocket } from 'ws';
Expand Down Expand Up @@ -108,6 +117,7 @@ class InputAudioBuffer {

class ConversationItem {
#session: RealtimeSession;
#logger = log();

constructor(session: RealtimeSession) {
this.#session = session;
Expand All @@ -129,12 +139,127 @@ class ConversationItem {
});
}

create(item: api_proto.ConversationItemCreateContent, previousItemId?: string): void {
this.#session.queueMsg({
type: 'conversation.item.create',
item,
previous_item_id: previousItemId,
});
// create(item: api_proto.ConversationItemCreateContent, previousItemId?: string): void {
create(message: llm.ChatMessage, previousItemId?: string): void {
if (!message.content) {
return;
}

let event: api_proto.ConversationItemCreateEvent;

if (message.toolCallId) {
if (typeof message.content !== 'string') {
throw new TypeError('message.content must be a string');
}

event = {
type: 'conversation.item.create',
previous_item_id: previousItemId,
item: {
type: 'function_call_output',
call_id: message.toolCallId,
output: message.content,
},
};
} else {
let content = message.content;
if (!Array.isArray(content)) {
content = [content];
}

if (message.role === ChatRole.USER) {
const contents: (api_proto.InputTextContent | api_proto.InputAudioContent)[] = [];
for (const c of content) {
if (typeof c === 'string') {
contents.push({
type: 'input_text',
text: c,
});
} else if (
// typescript type guard for determining ChatAudio vs ChatImage
((c: llm.ChatAudio | llm.ChatImage): c is llm.ChatAudio => {
return (c as llm.ChatAudio).frame !== undefined;
})(c)
) {
contents.push({
type: 'input_audio',
audio: Buffer.from(mergeFrames(c.frame).data.buffer).toString('base64'),
});
}
}

event = {
type: 'conversation.item.create',
previous_item_id: previousItemId,
item: {
type: 'message',
role: 'user',
content: contents,
},
};
} else if (message.role === ChatRole.ASSISTANT) {
const contents: api_proto.TextContent[] = [];
for (const c of content) {
if (typeof c === 'string') {
contents.push({
type: 'text',
text: c,
});
} else if (
// typescript type guard for determining ChatAudio vs ChatImage
((c: llm.ChatAudio | llm.ChatImage): c is llm.ChatAudio => {
return (c as llm.ChatAudio).frame !== undefined;
})(c)
) {
this.#logger.warn('audio content in assistant message is not supported');
}
}

event = {
type: 'conversation.item.create',
previous_item_id: previousItemId,
item: {
type: 'message',
role: 'assistant',
content: contents,
},
};
} else if (message.role === ChatRole.SYSTEM) {
const contents: api_proto.InputTextContent[] = [];
for (const c of content) {
if (typeof c === 'string') {
contents.push({
type: 'input_text',
text: c,
});
} else if (
// typescript type guard for determining ChatAudio vs ChatImage
((c: llm.ChatAudio | llm.ChatImage): c is llm.ChatAudio => {
return (c as llm.ChatAudio).frame !== undefined;
})(c)
) {
this.#logger.warn('audio content in system message is not supported');
}
}

event = {
type: 'conversation.item.create',
previous_item_id: previousItemId,
item: {
type: 'message',
role: 'system',
content: contents,
},
};
} else {
this.#logger
.child({ message })
.warn('chat message is not supported inside the realtime API');
return;
}
}

this.#session.queueMsg(event);
}
}

Expand Down Expand Up @@ -302,6 +427,7 @@ export class RealtimeModel extends multimodal.RealtimeModel {

session({
fncCtx,
chatCtx,
modalities = this.#defaultOpts.modalities,
instructions = this.#defaultOpts.instructions,
voice = this.#defaultOpts.voice,
Expand All @@ -313,6 +439,7 @@ export class RealtimeModel extends multimodal.RealtimeModel {
maxResponseOutputTokens = this.#defaultOpts.maxResponseOutputTokens,
}: {
fncCtx?: llm.FunctionContext;
chatCtx?: llm.ChatContext;
modalities?: ['text', 'audio'] | ['text'];
instructions?: string;
voice?: api_proto.Voice;
Expand Down Expand Up @@ -341,7 +468,10 @@ export class RealtimeModel extends multimodal.RealtimeModel {
entraToken: this.#defaultOpts.entraToken,
};

const newSession = new RealtimeSession(opts, fncCtx);
const newSession = new RealtimeSession(opts, {
chatCtx: chatCtx || new llm.ChatContext(),
fncCtx,
});
this.#sessions.push(newSession);
return newSession;
}
Expand All @@ -352,6 +482,7 @@ export class RealtimeModel extends multimodal.RealtimeModel {
}

export class RealtimeSession extends multimodal.RealtimeSession {
#chatCtx: llm.ChatContext | undefined = undefined;
#fncCtx: llm.FunctionContext | undefined = undefined;
#opts: ModelOptions;
#pendingResponses: { [id: string]: RealtimeResponse } = {};
Expand All @@ -363,10 +494,14 @@ export class RealtimeSession extends multimodal.RealtimeSession {
#closing = true;
#sendQueue = new Queue<api_proto.ClientEvent>();

constructor(opts: ModelOptions, fncCtx?: llm.FunctionContext | undefined) {
constructor(
opts: ModelOptions,
{ fncCtx, chatCtx }: { fncCtx?: llm.FunctionContext; chatCtx?: llm.ChatContext },
) {
super();

this.#opts = opts;
this.#chatCtx = chatCtx;
this.#fncCtx = fncCtx;

this.#task = this.#start();
Expand All @@ -385,6 +520,10 @@ export class RealtimeSession extends multimodal.RealtimeSession {
});
}

/** Chat context this session was created with; undefined when none was provided. */
get chatCtx(): llm.ChatContext | undefined {
return this.#chatCtx;
}

/** Function-calling context this session was created with; undefined when none was provided. */
get fncCtx(): llm.FunctionContext | undefined {
return this.#fncCtx;
}
Expand Down Expand Up @@ -869,11 +1008,11 @@ export class RealtimeSession extends multimodal.RealtimeSession {
callId: item.call_id,
});
this.conversation.item.create(
{
type: 'function_call_output',
call_id: item.call_id,
output: content,
},
llm.ChatMessage.createToolFromFunctionResult({
name: item.name,
toolCallId: item.call_id,
result: content,
}),
output.itemId,
);
this.response.create();
Expand Down

0 comments on commit d72c5ba

Please sign in to comment.