Skip to content

Commit

Permalink
feat: persona templates and updates to options for and use of prompt …
Browse files Browse the repository at this point in the history
…formatting
  • Loading branch information
amiller68 committed Apr 30, 2024
1 parent 762de20 commit f693207
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 12 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@libertai/libertai-js",
"version": "0.0.3",
"version": "0.0.4",
"description": "In-browser SDK for interacting with LibertAI Decentralized AI Network",
"keywords": [],
"type": "module",
Expand Down
38 changes: 28 additions & 10 deletions src/inference.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ export class LlamaCppApiEngine {
* @param {Message[]} messages - The sequence of messages to generate an answer for
* @param {Model} model - The model to use for inference
* @param {Persona} persona - The persona to use for inference
* @param {string} targetUser - The user to generate the answer for, if different from Message[-1].role
* @param {boolean} debug - Whether to print debug information
*/
async *generateAnswer(
messages: Message[],
model: Model,
persona: Persona,
targetUser: string | null = null,
debug: boolean = false
): AsyncGenerator<{ content: string; stopped: boolean }> {
const maxTries = model.maxTries;
Expand All @@ -39,7 +41,7 @@ export class LlamaCppApiEngine {
);

// Prepare the prompt
const prompt = this.preparePrompt(messages, model, persona);
const prompt = this.preparePrompt(messages, model, persona, targetUser);

if (debug) {
console.log(
Expand Down Expand Up @@ -137,34 +139,50 @@ export class LlamaCppApiEngine {
private preparePrompt(
messages: Message[],
model: Model,
persona: Persona
persona: Persona,
// Allow caller to specify a target user, if different from Message[-1].role
targetUser: string | null = null
): string {
let usedTokens = 0;
const maxTokens = model.maxTokens;
const promptFormat = model.promptFormat;

// Get the target user
if (targetUser === null) {
targetUser = messages[messages.length - 1].role;
}

// Set {{char}} based on persona.name
// Set {{user}} based on targetUser
// Set {{model}} based on model.name
let description = persona.description;
description = description.replace(/\{\{char\}\}/g, persona.name);
description = description.replace(/\{\{user\}\}/g, targetUser);
description = description.replace(/\{\{model\}\}/g, model.name);

// Prepare our system prompt
let systemPrompt = '';
systemPrompt += `${promptFormat.userPrepend}system${promptFormat.lineSeparator}`;
systemPrompt += `You are ${persona.name}${promptFormat.lineSeparator}`;
systemPrompt += `${persona.description}${promptFormat.lineSeparator}`;
systemPrompt += `${promptFormat.stopSequence}${promptFormat.lineSeparator}`;
systemPrompt += `${promptFormat.userPrepend}system${promptFormat.userAppend}`;
systemPrompt += `${description}`;
systemPrompt += `${promptFormat.stopSequence}${promptFormat.logStart}`;
systemPrompt += `${promptFormat.lineSeparator}`;

// Determine how many tokens we have left
usedTokens = calculateTokenLength(systemPrompt);

// Iterate over messages in reverse order
// to generate the chat log
let chatLog = `${promptFormat.userPrepend}${persona.name.toLowerCase()}${promptFormat.lineSeparator}`;
let chatLog = `${promptFormat.userPrepend}${persona.name.toLowerCase()}${promptFormat.userAppend}`;
for (let i = messages.length - 1; i >= 0; i--) {
const message = messages[i];
const timestamp_string = message.timestamp
? ` (at ${message.timestamp.toString()})`
: '';
let messageLog = '';
messageLog += `${promptFormat.userPrepend}${message.role.toLowerCase()}${timestamp_string}${promptFormat.lineSeparator}`;
messageLog += `${message.content}${promptFormat.lineSeparator}`;
messageLog += `${promptFormat.stopSequence}${promptFormat.lineSeparator}`;
messageLog += `${promptFormat.userPrepend}${message.role.toLowerCase()}${timestamp_string}${promptFormat.userAppend}`;
messageLog += `${message.content}`;
messageLog += `${promptFormat.stopSequence}`;
messageLog += `${promptFormat.lineSeparator}`;

const messageTokens = calculateTokenLength(messageLog);
if (usedTokens + messageTokens <= maxTokens) {
Expand Down
3 changes: 2 additions & 1 deletion src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ export interface PromptFormat {
userAppend: string;
// Character to separate messages
lineSeparator: string;

// Token to denote the start of any additional logs
logStart: string;
// Default stop sequence for the model. This will be used to
// generate prompts for the model
stopSequence: string;
Expand Down

0 comments on commit f693207

Please sign in to comment.