From f693207ab381f8643c67abcef5539be74b08e5ff Mon Sep 17 00:00:00 2001 From: al Date: Tue, 30 Apr 2024 12:05:30 -0400 Subject: [PATCH] feat: persona templates and updates to options for and use of prompt formatting --- package.json | 2 +- src/inference.ts | 38 ++++++++++++++++++++++++++++---------- src/types.ts | 3 ++- 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/package.json b/package.json index 23dc6e3..37442ae 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@libertai/libertai-js", - "version": "0.0.3", + "version": "0.0.4", "description": "In-browser SDK for interacting with LibertAI Decentralized AI Network", "keywords": [], "type": "module", diff --git a/src/inference.ts b/src/inference.ts index 298d6bd..8003204 100644 --- a/src/inference.ts +++ b/src/inference.ts @@ -24,12 +24,14 @@ export class LlamaCppApiEngine { * @param {Message[]} messages - The sequence of messages to generate an answer for * @param {Model} model - The model to use for inference * @param {Persona} persona - The persona to use for inference + * @param {string} targetUser - The user to generate the answer for, if different from Message[-1].role * @param {boolean} debug - Whether to print debug information */ async *generateAnswer( messages: Message[], model: Model, persona: Persona, + targetUser: string | null = null, debug: boolean = false ): AsyncGenerator<{ content: string; stopped: boolean }> { const maxTries = model.maxTries; @@ -39,7 +41,7 @@ export class LlamaCppApiEngine { ); // Prepare the prompt - const prompt = this.preparePrompt(messages, model, persona); + const prompt = this.preparePrompt(messages, model, persona, targetUser); if (debug) { console.log( @@ -137,34 +139,50 @@ export class LlamaCppApiEngine { private preparePrompt( messages: Message[], model: Model, - persona: Persona + persona: Persona, + // Allow caller to specify a target user, if different from Message[-1].role + targetUser: string | null = null ): string { let usedTokens = 0; const maxTokens = model.maxTokens; const promptFormat = model.promptFormat; + // Get the target user + if (targetUser === null) { + targetUser = messages[messages.length - 1].role; + } + + // Set {{char}} based on persona.name + // Set {{user}} based on targetUser + // Set {{model}} based on model.name + let description = persona.description; + description = description.replace(/\{\{char\}\}/g, persona.name); + description = description.replace(/\{\{user\}\}/g, targetUser); + description = description.replace(/\{\{model\}\}/g, model.name); + // Prepare our system prompt let systemPrompt = ''; - systemPrompt += `${promptFormat.userPrepend}system${promptFormat.lineSeparator}`; - systemPrompt += `You are ${persona.name}${promptFormat.lineSeparator}`; - systemPrompt += `${persona.description}${promptFormat.lineSeparator}`; - systemPrompt += `${promptFormat.stopSequence}${promptFormat.lineSeparator}`; + systemPrompt += `${promptFormat.userPrepend}system${promptFormat.userAppend}`; + systemPrompt += `${description}`; + systemPrompt += `${promptFormat.stopSequence}${promptFormat.logStart}`; + systemPrompt += `${promptFormat.lineSeparator}`; // Determine how many tokens we have left usedTokens = calculateTokenLength(systemPrompt); // Iterate over messagse in reverse order // to generate the chat log - let chatLog = `${promptFormat.userPrepend}${persona.name.toLowerCase()}${promptFormat.lineSeparator}`; + let chatLog = `${promptFormat.userPrepend}${persona.name.toLowerCase()}${promptFormat.userAppend}`; for (let i = messages.length - 1; i >= 0; i--) { const message = messages[i]; const timestamp_string = message.timestamp ? ` (at ${message.timestamp.toString()})` : ''; let messageLog = ''; - messageLog += `${promptFormat.userPrepend}${message.role.toLowerCase()}${timestamp_string}${promptFormat.lineSeparator}`; - messageLog += `${message.content}${promptFormat.lineSeparator}`; - messageLog += `${promptFormat.stopSequence}${promptFormat.lineSeparator}`; + messageLog += `${promptFormat.userPrepend}${message.role.toLowerCase()}${timestamp_string}${promptFormat.userAppend}`; + messageLog += `${message.content}`; + messageLog += `${promptFormat.stopSequence}`; + messageLog += `${promptFormat.lineSeparator}`; const messageTokens = calculateTokenLength(messageLog); if (usedTokens + messageTokens <= maxTokens) { diff --git a/src/types.ts b/src/types.ts index c10e081..834f3f2 100644 --- a/src/types.ts +++ b/src/types.ts @@ -50,7 +50,8 @@ export interface PromptFormat { userAppend: string; // Character to separate messages lineSeparator: string; - + // Token to denote the start of any additional logs + logStart: string; // Default stop sequence for the model. This will be used to // generate prompts for the model stopSequence: string;