Skip to content

Commit

Permalink
feat(openai): add LLM (#131)
Browse files Browse the repository at this point in the history
  • Loading branch information
nbsp authored Nov 8, 2024
1 parent 64b3e53 commit c3bc309
Show file tree
Hide file tree
Showing 18 changed files with 1,296 additions and 75 deletions.
6 changes: 6 additions & 0 deletions .changeset/soft-months-tickle.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@livekit/agents": patch
"@livekit/agents-plugin-openai": minor
---

add OpenAI LLM
2 changes: 1 addition & 1 deletion agents/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"dependencies": {
"@livekit/mutex": "^1.0.0",
"@livekit/protocol": "^1.21.0",
"@livekit/rtc-node": "^0.11.0",
"@livekit/rtc-node": "^0.11.1",
"commander": "^12.0.0",
"livekit-server-sdk": "^2.6.1",
"pino": "^8.19.0",
Expand Down
29 changes: 20 additions & 9 deletions agents/src/llm/chat_context.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// SPDX-FileCopyrightText: 2024 LiveKit, Inc.
//
// SPDX-License-Identifier: Apache-2.0
import type { AudioFrame } from '@livekit/rtc-node';
import type { CallableFunctionResult, FunctionContext } from './function_context.js';
import type { AudioFrame, VideoFrame } from '@livekit/rtc-node';
import type { CallableFunctionResult, FunctionCallInfo } from './function_context.js';

export enum ChatRole {
SYSTEM,
Expand All @@ -12,7 +12,7 @@ export enum ChatRole {
}

export interface ChatImage {
image: string | AudioFrame;
image: string | VideoFrame;
inferenceWidth?: number;
inferenceHeight?: number;
/**
Expand All @@ -39,7 +39,7 @@ export class ChatMessage {
readonly id?: string;
readonly name?: string;
readonly content?: ChatContent | ChatContent[];
readonly toolCalls?: FunctionContext;
readonly toolCalls?: FunctionCallInfo[];
readonly toolCallId?: string;
readonly toolException?: Error;

Expand All @@ -57,7 +57,7 @@ export class ChatMessage {
id?: string;
name?: string;
content?: ChatContent | ChatContent[];
toolCalls?: FunctionContext;
toolCalls?: FunctionCallInfo[];
toolCallId?: string;
toolException?: Error;
}) {
Expand All @@ -84,7 +84,7 @@ export class ChatMessage {
});
}

static createToolCalls(toolCalls: FunctionContext, text = '') {
static createToolCalls(toolCalls: FunctionCallInfo[], text = '') {
return new ChatMessage({
role: ChatRole.ASSISTANT,
toolCalls,
Expand Down Expand Up @@ -116,21 +116,32 @@ export class ChatMessage {

/** Returns a structured clone of this message. */
copy(): ChatMessage {
return structuredClone(this);
return new ChatMessage({
role: this.role,
id: this.id,
name: this.name,
content: this.content,
toolCalls: this.toolCalls,
toolCallId: this.toolCallId,
toolException: this.toolException,
});
}
}

export class ChatContext {
messages: ChatMessage[] = [];
metadata: { [id: string]: any } = {};

append(msg: { text?: string; images: ChatImage[]; role: ChatRole }): ChatContext {
append(msg: { text?: string; images?: ChatImage[]; role: ChatRole }): ChatContext {
this.messages.push(ChatMessage.create(msg));
return this;
}

/** Returns a structured clone of this context. */
copy(): ChatContext {
return structuredClone(this);
const ctx = new ChatContext();
ctx.messages.push(...this.messages.map((msg) => msg.copy()));
ctx.metadata = structuredClone(this.metadata);
return ctx;
}
}
3 changes: 2 additions & 1 deletion agents/src/llm/function_context.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
// SPDX-License-Identifier: Apache-2.0
import { describe, expect, it } from 'vitest';
import { z } from 'zod';
import { CallableFunction, oaiParams } from './function_context.js';
import type { CallableFunction } from './function_context.js';
import { oaiParams } from './function_context.js';

describe('function_context', () => {
describe('oaiParams', () => {
Expand Down
25 changes: 23 additions & 2 deletions agents/src/llm/function_context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,16 @@ export interface CallableFunction<P extends z.ZodTypeAny = any, R = any> {
}

/** A function that has been called but is not yet running */
export interface DeferredFunction<P extends z.ZodTypeAny = any, R = any> {
export interface FunctionCallInfo<P extends z.ZodTypeAny = any, R = any> {
name: string;
func: CallableFunction<P, R>;
toolCallId: string;
rawParams: string;
params: inferParameters<P>;
task?: PromiseLike<CallableFunctionResult>;
}

/** A currently-running function call, called by the LLM. */
/** The result of a ran FunctionCallInfo. */
export interface CallableFunctionResult {
name: string;
toolCallId: string;
Expand Down Expand Up @@ -97,3 +98,23 @@ export const oaiParams = (p: z.AnyZodObject) => {
required: requiredProperties,
};
};

/** @internal */
export const oaiBuildFunctionInfo = (
fncCtx: FunctionContext,
toolCallId: string,
fncName: string,
rawArgs: string,
): FunctionCallInfo => {
if (!fncCtx[fncName]) {
throw new Error(`AI function ${fncName} not found`);
}

return {
name: fncName,
func: fncCtx[fncName],
toolCallId,
rawParams: rawArgs,
params: JSON.parse(rawArgs),
};
};
2 changes: 2 additions & 0 deletions agents/src/llm/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
// SPDX-License-Identifier: Apache-2.0
export {
type CallableFunction,
type FunctionCallInfo,
type CallableFunctionResult,
type FunctionContext,
type inferParameters,
oaiParams,
oaiBuildFunctionInfo,
} from './function_context.js';

export {
Expand Down
26 changes: 11 additions & 15 deletions agents/src/llm/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,12 @@
// SPDX-License-Identifier: Apache-2.0
import { AsyncIterableQueue } from '../utils.js';
import type { ChatContext, ChatRole } from './chat_context.js';
import type {
CallableFunctionResult,
DeferredFunction,
FunctionContext,
} from './function_context.js';
import type { FunctionCallInfo, FunctionContext } from './function_context.js';

export interface ChoiceDelta {
role: ChatRole;
content?: string;
toolCalls?: FunctionContext;
toolCalls?: FunctionCallInfo[];
}

export interface CompletionUsage {
Expand Down Expand Up @@ -54,19 +50,19 @@ export abstract class LLM {
export abstract class LLMStream implements AsyncIterableIterator<ChatChunk> {
protected queue = new AsyncIterableQueue<ChatChunk>();
protected closed = false;
protected _functionCalls: FunctionCallInfo[] = [];

#chatCtx: ChatContext;
#fncCtx?: FunctionContext;
#functionCalls: DeferredFunction[] = [];

constructor(chatCtx: ChatContext, fncCtx?: FunctionContext) {
this.#chatCtx = chatCtx;
this.#fncCtx = fncCtx;
}

/** List of called functions from this stream. */
get functionCalls(): DeferredFunction[] {
return this.#functionCalls;
get functionCalls(): FunctionCallInfo[] {
return this._functionCalls;
}

/** The function context of this stream. */
Expand All @@ -80,15 +76,15 @@ export abstract class LLMStream implements AsyncIterableIterator<ChatChunk> {
}

/** Execute all deferred functions of this stream concurrently. */
async executeFunctions(): Promise<CallableFunctionResult[]> {
return Promise.all(
this.#functionCalls.map((f) =>
f.func.execute(f.params).then(
executeFunctions(): FunctionCallInfo[] {
this._functionCalls.forEach(
(f) =>
(f.task = f.func.execute(f.params).then(
(result) => ({ name: f.name, toolCallId: f.toolCallId, result }),
(error) => ({ name: f.name, toolCallId: f.toolCallId, error }),
),
),
)),
);
return this._functionCalls;
}

next(): Promise<IteratorResult<ChatChunk>> {
Expand Down
2 changes: 1 addition & 1 deletion examples/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"@livekit/agents-plugin-deepgram": "workspace:*",
"@livekit/agents-plugin-elevenlabs": "workspace:*",
"@livekit/agents-plugin-openai": "workspace:*",
"@livekit/rtc-node": "^0.11.0",
"@livekit/rtc-node": "^0.11.1",
"zod": "^3.23.8"
},
"version": null
Expand Down
2 changes: 1 addition & 1 deletion plugins/deepgram/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
},
"dependencies": {
"@livekit/agents": "workspace:*",
"@livekit/rtc-node": "^0.11.0",
"@livekit/rtc-node": "^0.11.1",
"ws": "^8.16.0"
}
}
2 changes: 1 addition & 1 deletion plugins/elevenlabs/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
},
"dependencies": {
"@livekit/agents": "workspace:*",
"@livekit/rtc-node": "^0.11.0",
"@livekit/rtc-node": "^0.11.1",
"ws": "^8.16.0"
}
}
4 changes: 3 additions & 1 deletion plugins/openai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
},
"dependencies": {
"@livekit/agents": "workspace:*",
"@livekit/rtc-node": "^0.11.0",
"@livekit/rtc-node": "^0.11.1",
"openai": "^4.70.2",
"sharp": "^0.33.5",
"ws": "^8.16.0"
}
}
2 changes: 2 additions & 0 deletions plugins/openai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
//
// SPDX-License-Identifier: Apache-2.0
export * as realtime from './realtime/index.js';
export * from './models.js';
export { type LLMOptions, LLM, LLMStream } from './llm.js';
Loading

0 comments on commit c3bc309

Please sign in to comment.