From 0d753341437998339b7f100adf80f6866e209c42 Mon Sep 17 00:00:00 2001 From: Tarrence van As Date: Wed, 20 Nov 2024 21:35:01 -0500 Subject: [PATCH] Improve knowledge embeddings --- .../src/actions/summarize_conversation.ts | 7 +- packages/client-discord/src/messages.ts | 5 +- packages/client-github/src/index.ts | 45 +----- packages/core/src/generation.ts | 26 ++-- packages/core/src/index.ts | 1 + packages/core/src/knowledge.ts | 129 ++++++++++++++++++ packages/core/src/runtime.ts | 118 ++++++---------- packages/core/src/types.ts | 1 + .../plugin-bootstrap/src/evaluators/fact.ts | 2 +- packages/plugin-bootstrap/src/index.ts | 4 + 10 files changed, 201 insertions(+), 137 deletions(-) create mode 100644 packages/core/src/knowledge.ts diff --git a/packages/client-discord/src/actions/summarize_conversation.ts b/packages/client-discord/src/actions/summarize_conversation.ts index 86d3921dc5..2a78c85afe 100644 --- a/packages/client-discord/src/actions/summarize_conversation.ts +++ b/packages/client-discord/src/actions/summarize_conversation.ts @@ -251,7 +251,12 @@ const summarizeAction = { const model = models[runtime.character.settings.model]; const chunkSize = model.settings.maxContextLength - 1000; - const chunks = await splitChunks(formattedMemories, chunkSize, 0); + const chunks = await splitChunks( + formattedMemories, + chunkSize, + "gpt-4o-mini", + 0 + ); const datestr = new Date().toUTCString().replace(/:/g, "-"); diff --git a/packages/client-discord/src/messages.ts b/packages/client-discord/src/messages.ts index 63b53b44e1..6c4069c0d6 100644 --- a/packages/client-discord/src/messages.ts +++ b/packages/client-discord/src/messages.ts @@ -430,13 +430,13 @@ export class MessageManager { await this.runtime.messageManager.createMemory(memory); } - let state = (await this.runtime.composeState(userMessage, { + let state = await this.runtime.composeState(userMessage, { discordClient: this.client, discordMessage: message, agentName: this.runtime.character.name || this.client.user?.displayName, - })) as State; + }); if (!canSendMessage(message.channel).canSend) { return elizaLogger.warn( @@ -649,6 +649,7 @@ export class MessageManager { message: DiscordMessage ): Promise<{ processedContent: string; attachments: Media[] }> { let processedContent = message.content; + let attachments: Media[] = []; // Process code blocks in the message content diff --git a/packages/client-github/src/index.ts b/packages/client-github/src/index.ts index 17ad91a294..ac03a9df36 100644 --- a/packages/client-github/src/index.ts +++ b/packages/client-github/src/index.ts @@ -10,12 +10,8 @@ import { AgentRuntime, Client, IAgentRuntime, - Content, - Memory, + knowledge, stringToUuid, - embeddingZeroVector, - splitChunks, - embed, } from "@ai16z/eliza"; import { validateGithubConfig } from "./enviroment"; @@ -112,11 +108,8 @@ export class GitHubClient { relativePath ); - const memory: Memory = { + await knowledge.set(this.runtime, { id: knowledgeId, - agentId: this.runtime.agentId, - userId: this.runtime.agentId, - roomId: this.runtime.agentId, content: { text: content, hash: contentHash, @@ -128,39 +121,7 @@ export class GitHubClient { owner: this.config.owner, }, }, - embedding: embeddingZeroVector, - }; - - await this.runtime.documentsManager.createMemory(memory); - - // Only split if content exceeds 4000 characters - const fragments = - content.length > 4000 - ? await splitChunks(content, 2000, 200) - : [content]; - - for (const fragment of fragments) { - // Skip empty fragments - if (!fragment.trim()) continue; - - // Add file path context to the fragment before embedding - const fragmentWithPath = `File: ${relativePath}\n\n${fragment}`; - const embedding = await embed(this.runtime, fragmentWithPath); - - await this.runtime.knowledgeManager.createMemory({ - // We namespace the knowledge base uuid to avoid id - // collision with the document above. - id: stringToUuid(knowledgeId + fragment), - roomId: this.runtime.agentId, - agentId: this.runtime.agentId, - userId: this.runtime.agentId, - content: { - source: knowledgeId, - text: fragment, - }, - embedding, - }); - } + }); } } diff --git a/packages/core/src/generation.ts b/packages/core/src/generation.ts index a174fd011c..24cbf8fd38 100644 --- a/packages/core/src/generation.ts +++ b/packages/core/src/generation.ts @@ -463,34 +463,38 @@ export async function generateShouldRespond({ * Splits content into chunks of specified size with optional overlapping bleed sections * @param content - The text content to split into chunks * @param chunkSize - The maximum size of each chunk in tokens - * @param bleed - Number of characters to overlap between chunks (default: 100) * @param model - The model name to use for tokenization (default: runtime.model) + * @param bleed - Number of characters to overlap between chunks (default: 100) * @returns Promise resolving to array of text chunks with bleed sections */ export async function splitChunks( content: string, chunkSize: number, + model: string, bleed: number = 100 ): Promise { - const encoding = encoding_for_model("gpt-4o-mini"); - + const encoding = encoding_for_model(model as TiktokenModel); const tokens = encoding.encode(content); const chunks: string[] = []; const textDecoder = new TextDecoder(); for (let i = 0; i < tokens.length; i += chunkSize) { - const chunk = tokens.slice(i, i + chunkSize); - const decodedChunk = textDecoder.decode(encoding.decode(chunk)); + let chunk = tokens.slice(i, i + chunkSize); // Append bleed characters from the previous chunk - const startBleed = i > 0 ? content.slice(i - bleed, i) : ""; + if (i > 0) { + chunk = new Uint32Array([...tokens.slice(i - bleed, i), ...chunk]); + } + // Append bleed characters from the next chunk - const endBleed = - i + chunkSize < tokens.length - ? content.slice(i + chunkSize, i + chunkSize + bleed) - : ""; + if (i + chunkSize < tokens.length) { + chunk = new Uint32Array([ + ...chunk, + ...tokens.slice(i + chunkSize, i + chunkSize + bleed), + ]); + } - chunks.push(startBleed + decodedChunk + endBleed); + chunks.push(textDecoder.decode(encoding.decode(chunk))); } return chunks; diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 425be803e0..96877411cc 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -20,3 +20,4 @@ export * from "./parsing.ts"; export * from "./uuid.ts"; export * from "./enviroment.ts"; export * from "./cache.ts"; +export { default as knowledge } from "./knowledge.ts"; diff --git a/packages/core/src/knowledge.ts b/packages/core/src/knowledge.ts new file mode 100644 index 0000000000..faee57b7a2 --- /dev/null +++ b/packages/core/src/knowledge.ts @@ -0,0 +1,129 @@ +import { UUID } from "crypto"; + +import { AgentRuntime } from "./runtime.ts"; +import { embed } from "./embedding.ts"; +import { Content, ModelClass, type Memory } from "./types.ts"; +import { stringToUuid } from "./uuid.ts"; +import { embeddingZeroVector } from "./memory.ts"; +import { splitChunks } from "./generation.ts"; +import { models } from "./models.ts"; +import elizaLogger from "./logger.ts"; + +async function get(runtime: AgentRuntime, message: Memory): Promise { + const processed = preprocess(message.content.text); + elizaLogger.log(`Querying knowledge for: ${processed}`); + const embedding = await embed(runtime, processed); + const fragments = await runtime.knowledgeManager.searchMemoriesByEmbedding( + embedding, + { + roomId: message.agentId, + agentId: message.agentId, + count: 3, + match_threshold: 0.1, + } + ); + + const uniqueSources = [ + ...new Set( + fragments.map((memory) => { + elizaLogger.log( + `Matched fragment: ${memory.content.text} with similarity: ${message.similarity}` + ); + return memory.content.source; + }) + ), + ]; + + const knowledgeDocuments = await Promise.all( + uniqueSources.map((source) => + runtime.documentsManager.getMemoryById(source as UUID) + ) + ); + + const knowledge = knowledgeDocuments + .filter((memory) => memory !== null) + .map((memory) => memory.content.text); + return knowledge; +} + +export type KnowledgeItem = { + id: UUID; + content: Content; +}; + +async function set(runtime: AgentRuntime, item: KnowledgeItem) { + await runtime.documentsManager.createMemory({ + embedding: embeddingZeroVector, + id: item.id, + agentId: runtime.agentId, + roomId: runtime.agentId, + userId: runtime.agentId, + createdAt: Date.now(), + content: item.content, + }); + + const preprocessed = preprocess(item.content.text); + const fragments = await splitChunks( + preprocessed, + 10, + models[runtime.character.modelProvider].model?.[ModelClass.EMBEDDING], + 5 + ); + + for (const fragment of fragments) { + const embedding = await embed(runtime, fragment); + await runtime.knowledgeManager.createMemory({ + // We namespace the knowledge base uuid to avoid id + // collision with the document above. + id: stringToUuid(item.id + fragment), + roomId: runtime.agentId, + agentId: runtime.agentId, + userId: runtime.agentId, + createdAt: Date.now(), + content: { + source: item.id, + text: fragment, + }, + embedding, + }); + } +} + +export function preprocess(content: string): string { + return ( + content + // Remove code blocks and their content + .replace(/```[\s\S]*?```/g, "") + // Remove inline code + .replace(/`.*?`/g, "") + // Convert headers to plain text with emphasis + .replace(/#{1,6}\s*(.*)/g, "$1") + // Remove image links but keep alt text + .replace(/!\[(.*?)\]\(.*?\)/g, "$1") + // Remove links but keep text + .replace(/\[(.*?)\]\(.*?\)/g, "$1") + // Remove HTML tags + .replace(/<[^>]*>/g, "") + // Remove horizontal rules + .replace(/^\s*[-*_]{3,}\s*$/gm, "") + // Remove comments + .replace(/\/\*[\s\S]*?\*\//g, "") + .replace(/\/\/.*/g, "") + // Normalize whitespace + .replace(/\s+/g, " ") + // Remove multiple newlines + .replace(/\n{3,}/g, "\n\n") + // strip all special characters + .replace(/[^a-zA-Z0-9\s]/g, "") + // Remove Discord mentions + .replace(/<@!?\d+>/g, "") + .trim() + .toLowerCase() + ); +} + +export default { + get, + set, + process, +}; diff --git a/packages/core/src/runtime.ts b/packages/core/src/runtime.ts index cde94da023..1e12375fae 100644 --- a/packages/core/src/runtime.ts +++ b/packages/core/src/runtime.ts @@ -14,8 +14,8 @@ import { } from "./evaluators.ts"; import { generateText } from "./generation.ts"; import { formatGoalsAsString, getGoals } from "./goals.ts"; -import { elizaLogger, embed, splitChunks } from "./index.ts"; -import { embeddingZeroVector, MemoryManager } from "./memory.ts"; +import { elizaLogger } from "./index.ts"; +import { MemoryManager } from "./memory.ts"; import { formatActors, formatMessages, getActorDetails } from "./messages.ts"; import { parseJsonArrayFromText } from "./parsing.ts"; import { formatPosts } from "./posts.ts"; @@ -44,6 +44,7 @@ import { } from "./types.ts"; import { stringToUuid } from "./uuid.ts"; import { v4 as uuidv4 } from "uuid"; +import knowledge from "./knowledge.ts"; /** * Represents the runtime environment for an agent, handling message processing, @@ -222,11 +223,21 @@ export class AgentRuntime implements IAgentRuntime { opts.character?.id ?? opts?.agentId ?? stringToUuid(opts.character?.name ?? uuidv4()); + this.character = opts.character || defaultCharacter; + + // By convention, we create a user and room using the agent id. + // Memories related to it are considered global context for the agent. + this.ensureRoomExists(this.agentId); + this.ensureUserExists( + this.agentId, + this.character.name, + this.character.name + ); + this.ensureParticipantExists(this.agentId, this.agentId); elizaLogger.success("Agent ID", this.agentId); this.fetch = (opts.fetch as typeof fetch) ?? this.fetch; - this.character = opts.character || defaultCharacter; if (!opts.databaseAdapter) { throw new Error("No database adapter provided"); } @@ -348,60 +359,28 @@ export class AgentRuntime implements IAgentRuntime { * then chunks the content into fragments, embeds each fragment, and creates fragment memories. * @param knowledge An array of knowledge items containing id, path, and content. */ - private async processCharacterKnowledge(knowledge: string[]) { - // ensure the room exists and the agent exists in the room - await this.ensureRoomExists(this.agentId); - - await this.ensureUserExists( - this.agentId, - this.character.name, - this.character.name - ); - - await this.ensureParticipantExists(this.agentId, this.agentId); - - for (const knowledgeItem of knowledge) { - const knowledgeId = stringToUuid(knowledgeItem); + private async processCharacterKnowledge(items: string[]) { + for (const item of items) { + const knowledgeId = stringToUuid(item); const existingDocument = await this.documentsManager.getMemoryById(knowledgeId); - if (!existingDocument) { - elizaLogger.success( - "Processing knowledge for ", - this.character.name, - " - ", - knowledgeItem.slice(0, 100) - ); - await this.documentsManager.createMemory({ - embedding: embeddingZeroVector, - id: knowledgeId, - agentId: this.agentId, - roomId: this.agentId, - userId: this.agentId, - createdAt: Date.now(), - content: { - text: knowledgeItem, - }, - }); - - const fragments = await splitChunks(knowledgeItem, 1200, 200); - for (const fragment of fragments) { - const embedding = await embed(this, fragment); - await this.knowledgeManager.createMemory({ - // We namespace the knowledge base uuid to avoid id - // collision with the document above. - id: stringToUuid(knowledgeId + fragment), - roomId: this.agentId, - agentId: this.agentId, - userId: this.agentId, - createdAt: Date.now(), - content: { - source: knowledgeId, - text: fragment, - }, - embedding, - }); - } + if (existingDocument) { + return; } + + console.log( + "Processing knowledge for ", + this.character.name, + " - ", + item.slice(0, 100) + ); + + await knowledge.set(this, { + id: knowledgeId, + content: { + text: item, + }, + }); } } @@ -935,33 +914,8 @@ Text: ${attachment.text} .join(" "); } - async function getKnowledge( - runtime: AgentRuntime, - message: Memory - ): Promise { - const embedding = await embed(runtime, message.content.text); - - const memories = - await runtime.knowledgeManager.searchMemoriesByEmbedding( - embedding, - { - roomId: message.agentId, - agentId: message.agentId, - count: 3, - } - ); - - const knowledge = memories.map((memory) => memory.content.text); - - return knowledge; - } - - const formatKnowledge = (knowledge: string[]) => { - return knowledge.map((knowledge) => `- ${knowledge}`).join("\n"); - }; - const formattedKnowledge = formatKnowledge( - await getKnowledge(this, message) + await knowledge.get(this, message) ); const initialState = { @@ -1243,3 +1197,7 @@ Text: ${attachment.text} } as State; } } + +const formatKnowledge = (knowledge: string[]) => { + return knowledge.map((knowledge) => `- ${knowledge}`).join("\n"); +}; diff --git a/packages/core/src/types.ts b/packages/core/src/types.ts index 32eab8378e..59247b10fb 100644 --- a/packages/core/src/types.ts +++ b/packages/core/src/types.ts @@ -175,6 +175,7 @@ export interface Memory { embedding?: number[]; // An optional embedding vector representing the semantic content of the memory. roomId: UUID; // The room or conversation ID associated with the memory. unique?: boolean; // Whether the memory is unique or not + similarity?: number; // embedding match similarity } /** diff --git a/packages/plugin-bootstrap/src/evaluators/fact.ts b/packages/plugin-bootstrap/src/evaluators/fact.ts index 97c22b6e1a..15857f3d11 100644 --- a/packages/plugin-bootstrap/src/evaluators/fact.ts +++ b/packages/plugin-bootstrap/src/evaluators/fact.ts @@ -13,7 +13,7 @@ import { export const formatFacts = (facts: Memory[]) => { const messageStrings = facts .reverse() - .map((fact: Memory) => `${(fact.content as Content)?.content}`); + .map((fact: Memory) => fact.content.text); const finalMessageStrings = messageStrings.join("\n"); return finalMessageStrings; }; diff --git a/packages/plugin-bootstrap/src/index.ts b/packages/plugin-bootstrap/src/index.ts index ea004debce..22de71d068 100644 --- a/packages/plugin-bootstrap/src/index.ts +++ b/packages/plugin-bootstrap/src/index.ts @@ -12,6 +12,10 @@ import { boredomProvider } from "./providers/boredom.ts"; import { factsProvider } from "./providers/facts.ts"; import { timeProvider } from "./providers/time.ts"; +export * as actions from "./actions"; +export * as evaluators from "./evaluators"; +export * as providers from "./providers"; + export const bootstrapPlugin: Plugin = { name: "bootstrap", description: "Agent bootstrap with basic actions and evaluators",