From 69828ab44a8515657a68364fb06bfbe1b040c679 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Radek=20Je=C5=BEek?= Date: Thu, 9 Jan 2025 17:52:10 +0100 Subject: [PATCH 1/4] feat(agents): improve template overriding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Radek Ježek --- examples/agents/bee_advanced.ts | 64 ++++++------------ src/agents/bee/agent.ts | 6 +- src/agents/bee/runners/base.ts | 24 ++++++- src/agents/bee/runners/default/runner.ts | 38 +++++------ src/agents/bee/runners/granite/prompts.ts | 81 ++++++++++------------- src/agents/bee/runners/granite/runner.ts | 42 +++++------- src/experimental/workflows/agent.ts | 12 ++-- src/template.test.ts | 31 ++++++--- src/template.ts | 11 ++- 9 files changed, 149 insertions(+), 160 deletions(-) diff --git a/examples/agents/bee_advanced.ts b/examples/agents/bee_advanced.ts index 562b090f..f57c1ef6 100644 --- a/examples/agents/bee_advanced.ts +++ b/examples/agents/bee_advanced.ts @@ -8,16 +8,6 @@ import { DuckDuckGoSearchToolSearchType, } from "bee-agent-framework/tools/search/duckDuckGoSearch"; import { OpenMeteoTool } from "bee-agent-framework/tools/weather/openMeteo"; -import { - BeeAssistantPrompt, - BeeSchemaErrorPrompt, - BeeSystemPrompt, - BeeToolErrorPrompt, - BeeToolInputErrorPrompt, - BeeToolNoResultsPrompt, - BeeUserEmptyPrompt, -} from "bee-agent-framework/agents/bee/prompts"; -import { PromptTemplate } from "bee-agent-framework/template"; import { BAMChatLLM } from "bee-agent-framework/adapters/bam/chat"; import { UnconstrainedMemory } from "bee-agent-framework/memory/unconstrainedMemory"; import { z } from "zod"; @@ -32,40 +22,30 @@ const agent = new BeeAgent({ memory: new UnconstrainedMemory(), // You can override internal templates templates: { - user: new PromptTemplate({ - schema: z - .object({ - input: z.string(), - }) - .passthrough(), - template: `User: {{input}}`, - }), - system: BeeSystemPrompt.fork((old) => ({ - ...old, - defaults: { - instructions: "You are a helpful assistant that uses tools to answer questions.", - }, - })), - toolError: BeeToolErrorPrompt, - toolInputError: BeeToolInputErrorPrompt, - toolNoResultError: BeeToolNoResultsPrompt.fork((old) => ({ - ...old, - template: `${old.template}\nPlease reformat your input.`, - })), - toolNotFoundError: new PromptTemplate({ - schema: z - .object({ - tools: z.array(z.object({ name: z.string() }).passthrough()), - }) - .passthrough(), - template: `Tool does not exist! + user: (template) => + template.fork((config) => { + config.schema = z.object({ input: z.string() }).passthrough(); + config.template = `User: {{input}}`; + }), + system: (template) => + template.fork((config) => { + config.defaults.instructions = + "You are a helpful assistant that uses tools to answer questions."; + }), + toolNoResultError: (template) => + template.fork((config) => { + config.template += `\nPlease reformat your input.`; + }), + toolNotFoundError: (template) => + template.fork((config) => { + config.schema = z + .object({ tools: z.array(z.object({ name: z.string() }).passthrough()) }) + .passthrough(); + config.template = `Tool does not exist! {{#tools.length}} Use one of the following tools: {{#trim}}{{#tools}}{{name}},{{/tools}}{{/trim}} -{{/tools.length}}`, - }), - schemaError: BeeSchemaErrorPrompt, - assistant: BeeAssistantPrompt, - userEmpty: BeeUserEmptyPrompt, +{{/tools.length}}`; + }), }, tools: [ new DuckDuckGoSearchTool({ diff --git a/src/agents/bee/agent.ts b/src/agents/bee/agent.ts index 5ed76105..a83aca97 100644 --- a/src/agents/bee/agent.ts +++ b/src/agents/bee/agent.ts @@ -42,7 +42,11 @@ export interface BeeInput { tools: AnyTool[]; memory: BaseMemory; meta?: Omit; - templates?: Partial; + templates?: Partial<{ + [K in keyof BeeAgentTemplates]: + | BeeAgentTemplates[K] + | ((oldTemplate: BeeAgentTemplates[K]) => BeeAgentTemplates[K]); + }>; execution?: BeeAgentExecutionConfig; } diff --git a/src/agents/bee/runners/base.ts b/src/agents/bee/runners/base.ts index dc4ec42b..a4e095d0 100644 --- a/src/agents/bee/runners/base.ts +++ b/src/agents/bee/runners/base.ts @@ -31,6 +31,8 @@ import { shallowCopy } from "@/serializer/utils.js"; import { BaseMemory } from "@/memory/base.js"; import { GetRunContext } from "@/context.js"; import { Emitter } from "@/emitter/emitter.js"; +import * as R from "remeda"; +import { PromptTemplate } from "@/template.js"; export interface BeeRunnerLLMInput { meta: BeeMeta; @@ -50,6 +52,8 @@ export abstract class BaseRunner extends Serializable { public readonly iterations: BeeAgentRunIteration[] = []; protected readonly failedAttemptsCounter: RetryCounter; + public templates: BeeAgentTemplates; + constructor( protected readonly input: BeeInput, protected readonly options: BeeRunOptions, @@ -57,6 +61,7 @@ export abstract class BaseRunner extends Serializable { ) { super(); this.failedAttemptsCounter = new RetryCounter(options?.execution?.totalMaxRetries, AgentError); + this.templates = this.resolveTemplates(); } async createIteration() { @@ -95,7 +100,24 @@ export abstract class BaseRunner extends Serializable { abstract tool(input: BeeRunnerToolInput): Promise<{ output: string; success: boolean }>; - abstract get templates(): BeeAgentTemplates; + protected abstract get _defaultTemplates(): BeeAgentTemplates; + + protected resolveInputTemplate( + defaultTemplate: T, + update?: T | ((oldTemplate: T) => T), + ): T { + if (!update) { + return defaultTemplate; + } + return update instanceof PromptTemplate ? update : update(defaultTemplate); + } + + protected resolveTemplates(): BeeAgentTemplates { + const templatesUpdate = this.input.templates ?? {}; + return R.mapValues(this._defaultTemplates, (template, key) => + this.resolveInputTemplate(template, templatesUpdate[key] as typeof template | undefined), + ) as BeeAgentTemplates; + } protected abstract initMemory(input: BeeRunInput): Promise; diff --git a/src/agents/bee/runners/default/runner.ts b/src/agents/bee/runners/default/runner.ts index 4b786959..fe07d4be 100644 --- a/src/agents/bee/runners/default/runner.ts +++ b/src/agents/bee/runners/default/runner.ts @@ -15,12 +15,7 @@ */ import { BaseRunner, BeeRunnerLLMInput, BeeRunnerToolInput } from "@/agents/bee/runners/base.js"; -import type { - BeeAgentRunIteration, - BeeAgentTemplates, - BeeParserInput, - BeeRunInput, -} from "@/agents/bee/types.js"; +import type { BeeAgentRunIteration, BeeParserInput, BeeRunInput } from "@/agents/bee/types.js"; import { Retryable } from "@/internals/helpers/retryable.js"; import { AgentError } from "@/agents/base.js"; import { @@ -48,6 +43,20 @@ import { Cache } from "@/cache/decoratorCache.js"; import { shallowCopy } from "@/serializer/utils.js"; export class DefaultRunner extends BaseRunner { + protected get _defaultTemplates() { + return { + system: BeeSystemPrompt, + assistant: BeeAssistantPrompt, + user: BeeUserPrompt, + schemaError: BeeSchemaErrorPrompt, + toolNotFoundError: BeeToolNotFoundPrompt, + toolError: BeeToolErrorPrompt, + toolInputError: BeeToolInputErrorPrompt, + userEmpty: BeeUserEmptyPrompt, + toolNoResultError: BeeToolNoResultsPrompt, + }; + } + static { this.register(); } @@ -369,23 +378,6 @@ export class DefaultRunner extends BaseRunner { return memory; } - @Cache({ enumerable: false }) - get templates(): BeeAgentTemplates { - const customTemplates = this.input.templates ?? {}; - - return { - system: customTemplates.system ?? BeeSystemPrompt, - assistant: customTemplates.assistant ?? BeeAssistantPrompt, - user: customTemplates.user ?? BeeUserPrompt, - userEmpty: customTemplates.userEmpty ?? BeeUserEmptyPrompt, - toolError: customTemplates.toolError ?? BeeToolErrorPrompt, - toolInputError: customTemplates.toolInputError ?? BeeToolInputErrorPrompt, - toolNoResultError: customTemplates.toolNoResultError ?? BeeToolNoResultsPrompt, - toolNotFoundError: customTemplates.toolNotFoundError ?? BeeToolNotFoundPrompt, - schemaError: customTemplates.schemaError ?? BeeSchemaErrorPrompt, - }; - } - protected createParser(tools: AnyTool[]) { const parserRegex = isEmpty(tools) ? new RegExp(`Thought: .+\\nFinal Answer: [\\s\\S]+`) diff --git a/src/agents/bee/runners/granite/prompts.ts b/src/agents/bee/runners/granite/prompts.ts index 97fa1cf5..5e11deb9 100644 --- a/src/agents/bee/runners/granite/prompts.ts +++ b/src/agents/bee/runners/granite/prompts.ts @@ -24,28 +24,20 @@ import { BeeUserPrompt, } from "@/agents/bee/prompts.js"; -export const GraniteBeeAssistantPrompt = BeeAssistantPrompt.fork((config) => ({ - ...config, - template: `{{#thought}}Thought: {{.}}\n{{/thought}}{{#toolName}}Tool Name: {{.}}\n{{/toolName}}{{#toolInput}}Tool Input: {{.}}\n{{/toolInput}}{{#finalAnswer}}Final Answer: {{.}}{{/finalAnswer}}`, -})); +export const GraniteBeeAssistantPrompt = BeeAssistantPrompt.fork((config) => { + config.template = `{{#thought}}Thought: {{.}}\n{{/thought}}{{#toolName}}Tool Name: {{.}}\n{{/toolName}}{{#toolInput}}Tool Input: {{.}}\n{{/toolInput}}{{#finalAnswer}}Final Answer: {{.}}{{/finalAnswer}}`; +}); -export const GraniteBeeSystemPrompt = BeeSystemPrompt.fork((config) => ({ - ...config, - defaults: { - ...config.defaults, - instructions: "", - }, - functions: { - ...config.functions, - formatDate: function () { - const date = this.createdAt ? new Date(this.createdAt) : new Date(); - return new Intl.DateTimeFormat("en-US", { - dateStyle: "full", - timeStyle: "medium", - }).format(date); - }, - }, - template: `You are an AI assistant. +export const GraniteBeeSystemPrompt = BeeSystemPrompt.fork((config) => { + config.defaults.instructions = ""; + config.functions.formatDate = function () { + const date = this.createdAt ? new Date(this.createdAt) : new Date(); + return new Intl.DateTimeFormat("en-US", { + dateStyle: "full", + timeStyle: "medium", + }).format(date); + }; + config.template = `You are an AI assistant. When the user sends a message figure out a solution and provide a final answer. {{#tools.length}} You have access to a set of tools that can be used to retrieve information and perform actions. @@ -85,38 +77,33 @@ You do not need a tool to get the current Date and Time. Use the information ava # Additional instructions {{.}} {{/instructions}} -`, -})); +`; +}); -export const GraniteBeeSchemaErrorPrompt = BeeSchemaErrorPrompt.fork((config) => ({ - ...config, - template: `Error: The generated response does not adhere to the communication structure mentioned in the system prompt. -You communicate only in instruction lines. Valid instruction lines are 'Thought' followed by 'Tool Name' and then 'Tool Input' or 'Thought' followed by 'Final Answer'.`, -})); +export const GraniteBeeSchemaErrorPrompt = BeeSchemaErrorPrompt.fork((config) => { + config.template = `Error: The generated response does not adhere to the communication structure mentioned in the system prompt. +You communicate only in instruction lines. Valid instruction lines are 'Thought' followed by 'Tool Name' and then 'Tool Input' or 'Thought' followed by 'Final Answer'.`; +}); -export const GraniteBeeUserPrompt = BeeUserPrompt.fork((config) => ({ - ...config, - template: `{{input}}`, -})); +export const GraniteBeeUserPrompt = BeeUserPrompt.fork((config) => { + config.template = `{{input}}`; +}); -export const GraniteBeeToolNotFoundPrompt = BeeToolNotFoundPrompt.fork((config) => ({ - ...config, - template: `Tool does not exist! +export const GraniteBeeToolNotFoundPrompt = BeeToolNotFoundPrompt.fork((config) => { + config.template = `Tool does not exist! {{#tools.length}} Use one of the following tools: {{#trim}}{{#tools}}{{name}},{{/tools}}{{/trim}} -{{/tools.length}}`, -})); +{{/tools.length}}`; +}); -export const GraniteBeeToolErrorPrompt = BeeToolErrorPrompt.fork((config) => ({ - ...config, - template: `The tool has failed; the error log is shown below. If the tool cannot accomplish what you want, use a different tool or explain why you can't use it. +export const GraniteBeeToolErrorPrompt = BeeToolErrorPrompt.fork((config) => { + config.template = `The tool has failed; the error log is shown below. If the tool cannot accomplish what you want, use a different tool or explain why you can't use it. -{{reason}}`, -})); +{{reason}}`; +}); -export const GraniteBeeToolInputErrorPrompt = BeeToolInputErrorPrompt.fork((config) => ({ - ...config, - template: `{{reason}} +export const GraniteBeeToolInputErrorPrompt = BeeToolInputErrorPrompt.fork((config) => { + config.template = `{{reason}} -HINT: If you're convinced that the input was correct but the tool cannot process it then use a different tool or say I don't know.`, -})); +HINT: If you're convinced that the input was correct but the tool cannot process it then use a different tool or say I don't know.`; +}); diff --git a/src/agents/bee/runners/granite/runner.ts b/src/agents/bee/runners/granite/runner.ts index e928ff34..799677bd 100644 --- a/src/agents/bee/runners/granite/runner.ts +++ b/src/agents/bee/runners/granite/runner.ts @@ -15,16 +15,11 @@ */ import { BaseMessage, Role } from "@/llms/primitives/message.js"; -import type { AnyTool } from "@/tools/base.js"; import { isEmpty } from "remeda"; +import type { AnyTool } from "@/tools/base.js"; import { DefaultRunner } from "@/agents/bee/runners/default/runner.js"; import { BaseMemory } from "@/memory/base.js"; -import type { - BeeAgentTemplates, - BeeParserInput, - BeeRunInput, - BeeRunOptions, -} from "@/agents/bee/types.js"; +import type { BeeParserInput, BeeRunInput, BeeRunOptions } from "@/agents/bee/types.js"; import { BeeAgent, BeeInput } from "@/agents/bee/agent.js"; import type { GetRunContext } from "@/context.js"; import { @@ -36,9 +31,24 @@ import { GraniteBeeToolNotFoundPrompt, GraniteBeeUserPrompt, } from "@/agents/bee/runners/granite/prompts.js"; -import { Cache } from "@/cache/decoratorCache.js"; +import { BeeToolNoResultsPrompt, BeeUserEmptyPrompt } from "@/agents/bee/prompts.js"; export class GraniteRunner extends DefaultRunner { + protected get _defaultTemplates() { + return { + system: GraniteBeeSystemPrompt, + assistant: GraniteBeeAssistantPrompt, + user: GraniteBeeUserPrompt, + schemaError: GraniteBeeSchemaErrorPrompt, + toolNotFoundError: GraniteBeeToolNotFoundPrompt, + toolError: GraniteBeeToolErrorPrompt, + toolInputError: GraniteBeeToolInputErrorPrompt, + // Note: These are from bee + userEmpty: BeeUserEmptyPrompt, + toolNoResultError: BeeToolNoResultsPrompt, + }; + } + static { this.register(); } @@ -89,22 +99,6 @@ export class GraniteRunner extends DefaultRunner { return memory; } - @Cache({ enumerable: false }) - get templates(): BeeAgentTemplates { - const customTemplates = this.input.templates ?? {}; - - return { - ...super.templates, - user: customTemplates.user ?? GraniteBeeUserPrompt, - system: customTemplates.system ?? GraniteBeeSystemPrompt, - assistant: customTemplates.assistant ?? GraniteBeeAssistantPrompt, - schemaError: customTemplates.schemaError ?? GraniteBeeSchemaErrorPrompt, - toolNotFoundError: customTemplates.toolNotFoundError ?? GraniteBeeToolNotFoundPrompt, - toolError: customTemplates.toolError ?? GraniteBeeToolErrorPrompt, - toolInputError: customTemplates.toolInputError ?? GraniteBeeToolInputErrorPrompt, - }; - } - protected createParser(tools: AnyTool[]) { const { parser } = super.createParser(tools); diff --git a/src/experimental/workflows/agent.ts b/src/experimental/workflows/agent.ts index 7d7acd40..4f3b22fa 100644 --- a/src/experimental/workflows/agent.ts +++ b/src/experimental/workflows/agent.ts @@ -19,7 +19,6 @@ import { Workflow, WorkflowRunOptions } from "@/experimental/workflows/workflow. import { BaseMessage } from "@/llms/primitives/message.js"; import { AnyTool } from "@/tools/base.js"; import { AnyChatLLM } from "@/llms/chat.js"; -import { BeeSystemPrompt } from "@/agents/bee/prompts.js"; import { BaseMemory, ReadOnlyMemory } from "@/memory/base.js"; import { z } from "zod"; import { UnconstrainedMemory } from "@/memory/unconstrainedMemory.js"; @@ -100,13 +99,10 @@ export class AgentWorkflow { execution: input.execution, ...(input.instructions && { templates: { - system: BeeSystemPrompt.fork((config) => ({ - ...config, - defaults: { - ...config.defaults, - instructions: input.instructions || config.defaults.instructions, - }, - })), + system: (template) => + template.fork((config) => { + config.defaults.instructions = input.instructions || config.defaults.instructions; + }), }, }), }); diff --git a/src/template.test.ts b/src/template.test.ts index 4b495d38..c7b85c3e 100644 --- a/src/template.test.ts +++ b/src/template.test.ts @@ -15,7 +15,7 @@ */ import { PromptTemplateError, PromptTemplate, ValidationPromptTemplateError } from "@/template.js"; -import { z } from "zod"; +import { z, ZodType } from "zod"; describe("Prompt Template", () => { describe("Rendering", () => { @@ -187,7 +187,21 @@ describe("Prompt Template", () => { expect(cloned).toEqual(template); }); - it("Forks", () => { + it.each([ + (template: PromptTemplate) => + template.fork((config) => ({ + ...config, + template: "Hello {{name}}!", + customTags: ["{{", "}}"], + functions: { formatDate: () => "Today" }, + })), + (template: PromptTemplate) => + template.fork((config) => { + config.template = "Hello {{name}}!"; + config.customTags = ["{{", "}}"]; + config.functions.formatDate = () => "Today"; + }), + ])("Forks", (forkFn) => { const template = new PromptTemplate({ template: `Hello <>!`, schema: z.object({ @@ -196,13 +210,14 @@ describe("Prompt Template", () => { customTags: ["<<", ">>"], escape: false, }); - const forked = template.fork((config) => ({ - ...config, - template: "Hello {{name}}!", - customTags: ["{{", "}}"], - })); - + const forked = forkFn(template); expect(template.render({ name: "Tomas" })).toEqual(forked.render({ name: "Tomas" })); + // Configs are deeply copied + // @ts-expect-error protected property + const [templateConfig, forkedConfig] = [template.config, forked.config]; + expect(templateConfig.template).not.toEqual(forkedConfig.template); + expect(templateConfig.functions).not.toEqual(forkedConfig.functions); + expect(templateConfig).not.toEqual(forkedConfig); }); }); test("Custom function", () => { diff --git a/src/template.ts b/src/template.ts index 05f3dbcc..8ead72fd 100644 --- a/src/template.ts +++ b/src/template.ts @@ -17,13 +17,12 @@ import { FrameworkError } from "@/errors.js"; import { ObjectLike, PlainObject } from "@/internals/types.js"; import * as R from "remeda"; +import { pickBy } from "remeda"; import Mustache from "mustache"; import { Serializable } from "@/internals/serializable.js"; import { z, ZodType } from "zod"; import { createSchemaValidator, toJsonSchema } from "@/internals/helpers/schema.js"; import type { SchemaObject, ValidateFunction } from "ajv"; -import { shallowCopy } from "@/serializer/utils.js"; -import { pickBy } from "remeda"; import { getProp } from "@/internals/helpers/object.js"; type PostInfer = T extends PlainObject @@ -55,9 +54,9 @@ type PromptTemplateConstructor = N extends ZodType } : Omit, "schema"> & { schema: T | SchemaObject }; -type Customizer = ( - config: Required>, -) => PromptTemplateConstructor; +type Customizer = + | ((config: Required>) => PromptTemplateConstructor) + | ((config: Required>) => void); export class PromptTemplateError extends FrameworkError { template: PromptTemplate; @@ -122,7 +121,7 @@ export class PromptTemplate extends Serializable { fork( customizer: Customizer | Customizer, ): PromptTemplate { - const config = shallowCopy(this.config); + const config = R.clone(this.config); const newConfig = customizer?.(config) ?? config; return new PromptTemplate(newConfig); } From 975b6f6daee9454ac33c82228c54fa8de98df09c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Radek=20Je=C5=BEek?= Date: Fri, 10 Jan 2025 13:33:27 +0100 Subject: [PATCH 2/4] fixup! feat(agents): improve template overriding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Radek Ježek --- src/agents/bee/runners/base.ts | 15 +++++++-------- src/agents/bee/runners/default/runner.ts | 3 ++- src/agents/bee/runners/granite/runner.ts | 4 +++- src/template.test.ts | 14 ++++---------- src/template.ts | 12 +++++++----- 5 files changed, 23 insertions(+), 25 deletions(-) diff --git a/src/agents/bee/runners/base.ts b/src/agents/bee/runners/base.ts index a4e095d0..5f5ac7e4 100644 --- a/src/agents/bee/runners/base.ts +++ b/src/agents/bee/runners/base.ts @@ -15,9 +15,9 @@ */ import { Serializable } from "@/internals/serializable.js"; +import type { BeeAgentTemplates } from "@/agents/bee/types.js"; import { BeeAgentRunIteration, - BeeAgentTemplates, BeeCallbacks, BeeIterationToolResult, BeeMeta, @@ -31,8 +31,9 @@ import { shallowCopy } from "@/serializer/utils.js"; import { BaseMemory } from "@/memory/base.js"; import { GetRunContext } from "@/context.js"; import { Emitter } from "@/emitter/emitter.js"; -import * as R from "remeda"; import { PromptTemplate } from "@/template.js"; +import { Cache } from "@/cache/decoratorCache.js"; +import { mapValues } from "remeda"; export interface BeeRunnerLLMInput { meta: BeeMeta; @@ -52,8 +53,6 @@ export abstract class BaseRunner extends Serializable { public readonly iterations: BeeAgentRunIteration[] = []; protected readonly failedAttemptsCounter: RetryCounter; - public templates: BeeAgentTemplates; - constructor( protected readonly input: BeeInput, protected readonly options: BeeRunOptions, @@ -61,7 +60,6 @@ export abstract class BaseRunner extends Serializable { ) { super(); this.failedAttemptsCounter = new RetryCounter(options?.execution?.totalMaxRetries, AgentError); - this.templates = this.resolveTemplates(); } async createIteration() { @@ -100,7 +98,7 @@ export abstract class BaseRunner extends Serializable { abstract tool(input: BeeRunnerToolInput): Promise<{ output: string; success: boolean }>; - protected abstract get _defaultTemplates(): BeeAgentTemplates; + protected abstract get defaultTemplates(): BeeAgentTemplates; protected resolveInputTemplate( defaultTemplate: T, @@ -112,9 +110,10 @@ export abstract class BaseRunner extends Serializable { return update instanceof PromptTemplate ? update : update(defaultTemplate); } - protected resolveTemplates(): BeeAgentTemplates { + @Cache({ enumerable: false }) + public get templates(): BeeAgentTemplates { const templatesUpdate = this.input.templates ?? {}; - return R.mapValues(this._defaultTemplates, (template, key) => + return mapValues(this.defaultTemplates, (template, key) => this.resolveInputTemplate(template, templatesUpdate[key] as typeof template | undefined), ) as BeeAgentTemplates; } diff --git a/src/agents/bee/runners/default/runner.ts b/src/agents/bee/runners/default/runner.ts index fe07d4be..b644811b 100644 --- a/src/agents/bee/runners/default/runner.ts +++ b/src/agents/bee/runners/default/runner.ts @@ -43,7 +43,8 @@ import { Cache } from "@/cache/decoratorCache.js"; import { shallowCopy } from "@/serializer/utils.js"; export class DefaultRunner extends BaseRunner { - protected get _defaultTemplates() { + @Cache({ enumerable: false }) + protected get defaultTemplates() { return { system: BeeSystemPrompt, assistant: BeeAssistantPrompt, diff --git a/src/agents/bee/runners/granite/runner.ts b/src/agents/bee/runners/granite/runner.ts index 799677bd..9185ccc8 100644 --- a/src/agents/bee/runners/granite/runner.ts +++ b/src/agents/bee/runners/granite/runner.ts @@ -32,9 +32,11 @@ import { GraniteBeeUserPrompt, } from "@/agents/bee/runners/granite/prompts.js"; import { BeeToolNoResultsPrompt, BeeUserEmptyPrompt } from "@/agents/bee/prompts.js"; +import { Cache } from "@/cache/decoratorCache.js"; export class GraniteRunner extends DefaultRunner { - protected get _defaultTemplates() { + @Cache({ enumerable: false }) + protected get defaultTemplates() { return { system: GraniteBeeSystemPrompt, assistant: GraniteBeeAssistantPrompt, diff --git a/src/template.test.ts b/src/template.test.ts index c7b85c3e..e80161f5 100644 --- a/src/template.test.ts +++ b/src/template.test.ts @@ -191,15 +191,14 @@ describe("Prompt Template", () => { (template: PromptTemplate) => template.fork((config) => ({ ...config, - template: "Hello {{name}}!", + template: "Hi {{name}}!", customTags: ["{{", "}}"], functions: { formatDate: () => "Today" }, })), (template: PromptTemplate) => template.fork((config) => { - config.template = "Hello {{name}}!"; + config.template = "Hi {{name}}!"; config.customTags = ["{{", "}}"]; - config.functions.formatDate = () => "Today"; }), ])("Forks", (forkFn) => { const template = new PromptTemplate({ @@ -211,13 +210,8 @@ describe("Prompt Template", () => { escape: false, }); const forked = forkFn(template); - expect(template.render({ name: "Tomas" })).toEqual(forked.render({ name: "Tomas" })); - // Configs are deeply copied - // @ts-expect-error protected property - const [templateConfig, forkedConfig] = [template.config, forked.config]; - expect(templateConfig.template).not.toEqual(forkedConfig.template); - expect(templateConfig.functions).not.toEqual(forkedConfig.functions); - expect(templateConfig).not.toEqual(forkedConfig); + expect(template.render({ name: "Tomas" })).toEqual("Hello Tomas!"); + expect(forked.render({ name: "Tomas" })).toEqual("Hi Tomas!"); }); }); test("Custom function", () => { diff --git a/src/template.ts b/src/template.ts index 8ead72fd..76ed3ab5 100644 --- a/src/template.ts +++ b/src/template.ts @@ -14,10 +14,9 @@ * limitations under the License. */ -import { FrameworkError } from "@/errors.js"; +import { FrameworkError, ValueError } from "@/errors.js"; import { ObjectLike, PlainObject } from "@/internals/types.js"; -import * as R from "remeda"; -import { pickBy } from "remeda"; +import { clone, identity, isPlainObject, pickBy } from "remeda"; import Mustache from "mustache"; import { Serializable } from "@/internals/serializable.js"; import { z, ZodType } from "zod"; @@ -121,8 +120,11 @@ export class PromptTemplate extends Serializable { fork( customizer: Customizer | Customizer, ): PromptTemplate { - const config = R.clone(this.config); + const config = clone(this.config); const newConfig = customizer?.(config) ?? config; + if (!isPlainObject(newConfig)) { + throw new ValueError("Return type from customizer must be a config or nothing."); + } return new PromptTemplate(newConfig); } @@ -146,7 +148,7 @@ export class PromptTemplate extends Serializable { { tags: this.config.customTags, ...(!this.config.escape && { - escape: R.identity(), + escape: identity(), }), }, ); From d9cf702a636f505ff4cf1367aa137ad764ba22cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Radek=20Je=C5=BEek?= Date: Fri, 10 Jan 2025 16:11:58 +0100 Subject: [PATCH 3/4] fixup! fixup! feat(agents): improve template overriding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Radek Ježek --- src/agents/bee/runners/base.ts | 26 +++++++++++------------- src/agents/bee/runners/default/runner.ts | 2 +- src/agents/bee/runners/granite/runner.ts | 2 +- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/src/agents/bee/runners/base.ts b/src/agents/bee/runners/base.ts index 5f5ac7e4..138b800c 100644 --- a/src/agents/bee/runners/base.ts +++ b/src/agents/bee/runners/base.ts @@ -98,24 +98,22 @@ export abstract class BaseRunner extends Serializable { abstract tool(input: BeeRunnerToolInput): Promise<{ output: string; success: boolean }>; - protected abstract get defaultTemplates(): BeeAgentTemplates; - - protected resolveInputTemplate( - defaultTemplate: T, - update?: T | ((oldTemplate: T) => T), - ): T { - if (!update) { - return defaultTemplate; - } - return update instanceof PromptTemplate ? update : update(defaultTemplate); - } + public abstract get defaultTemplates(): BeeAgentTemplates; @Cache({ enumerable: false }) public get templates(): BeeAgentTemplates { const templatesUpdate = this.input.templates ?? {}; - return mapValues(this.defaultTemplates, (template, key) => - this.resolveInputTemplate(template, templatesUpdate[key] as typeof template | undefined), - ) as BeeAgentTemplates; + + return mapValues(this.defaultTemplates, (defaultTemplate, key) => { + if (!templatesUpdate[key]) { + return defaultTemplate; + } + if (templatesUpdate[key] instanceof PromptTemplate) { + return templatesUpdate[key]; + } + const update = templatesUpdate[key] as (template: typeof defaultTemplate) => typeof template; + return update(defaultTemplate); + }) as BeeAgentTemplates; } protected abstract initMemory(input: BeeRunInput): Promise; diff --git a/src/agents/bee/runners/default/runner.ts b/src/agents/bee/runners/default/runner.ts index b644811b..fb5a1dd6 100644 --- a/src/agents/bee/runners/default/runner.ts +++ b/src/agents/bee/runners/default/runner.ts @@ -44,7 +44,7 @@ import { shallowCopy } from "@/serializer/utils.js"; export class DefaultRunner extends BaseRunner { @Cache({ enumerable: false }) - protected get defaultTemplates() { + public get defaultTemplates() { return { system: BeeSystemPrompt, assistant: BeeAssistantPrompt, diff --git a/src/agents/bee/runners/granite/runner.ts b/src/agents/bee/runners/granite/runner.ts index 9185ccc8..2ac8b26d 100644 --- a/src/agents/bee/runners/granite/runner.ts +++ b/src/agents/bee/runners/granite/runner.ts @@ -36,7 +36,7 @@ import { Cache } from "@/cache/decoratorCache.js"; export class GraniteRunner extends DefaultRunner { @Cache({ enumerable: false }) - protected get defaultTemplates() { + public get defaultTemplates() { return { system: GraniteBeeSystemPrompt, assistant: GraniteBeeAssistantPrompt, From f13efaea84f345f9731e639c5f3ec02dee8ce7a9 Mon Sep 17 00:00:00 2001 From: Tomas Dvorak Date: Fri, 10 Jan 2025 17:40:01 +0100 Subject: [PATCH 4/4] fix(templates): typescript types Signed-off-by: Tomas Dvorak --- src/agents/bee/agent.ts | 8 +++++--- src/agents/bee/runners/base.ts | 21 ++++++++------------- src/internals/helpers/object.ts | 11 +++++++++++ 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/src/agents/bee/agent.ts b/src/agents/bee/agent.ts index a83aca97..ae677267 100644 --- a/src/agents/bee/agent.ts +++ b/src/agents/bee/agent.ts @@ -37,15 +37,17 @@ import { GraniteRunner } from "@/agents/bee/runners/granite/runner.js"; import { DefaultRunner } from "@/agents/bee/runners/default/runner.js"; import { ValueError } from "@/errors.js"; +export type BeeTemplateFactory = ( + template: BeeAgentTemplates[K], +) => BeeAgentTemplates[K]; + export interface BeeInput { llm: ChatLLM; tools: AnyTool[]; memory: BaseMemory; meta?: Omit; templates?: Partial<{ - [K in keyof BeeAgentTemplates]: - | BeeAgentTemplates[K] - | ((oldTemplate: BeeAgentTemplates[K]) => BeeAgentTemplates[K]); + [K in keyof BeeAgentTemplates]: BeeAgentTemplates[K] | BeeTemplateFactory; }>; execution?: BeeAgentExecutionConfig; } diff --git a/src/agents/bee/runners/base.ts b/src/agents/bee/runners/base.ts index 138b800c..a7a630c0 100644 --- a/src/agents/bee/runners/base.ts +++ b/src/agents/bee/runners/base.ts @@ -31,9 +31,9 @@ import { shallowCopy } from "@/serializer/utils.js"; import { BaseMemory } from "@/memory/base.js"; import { GetRunContext } from "@/context.js"; import { Emitter } from "@/emitter/emitter.js"; -import { PromptTemplate } from "@/template.js"; import { Cache } from "@/cache/decoratorCache.js"; -import { mapValues } from "remeda"; +import { getProp, mapObj } from "@/internals/helpers/object.js"; +import { PromptTemplate } from "@/template.js"; export interface BeeRunnerLLMInput { meta: BeeMeta; @@ -102,18 +102,13 @@ export abstract class BaseRunner extends Serializable { @Cache({ enumerable: false }) public get templates(): BeeAgentTemplates { - const templatesUpdate = this.input.templates ?? {}; - - return mapValues(this.defaultTemplates, (defaultTemplate, key) => { - if (!templatesUpdate[key]) { - return defaultTemplate; - } - if (templatesUpdate[key] instanceof PromptTemplate) { - return templatesUpdate[key]; + return mapObj(this.defaultTemplates)((key, defaultTemplate) => { + const override = getProp(this.input.templates, [key], defaultTemplate); + if (override instanceof PromptTemplate) { + return override; } - const update = templatesUpdate[key] as (template: typeof defaultTemplate) => typeof template; - return update(defaultTemplate); - }) as BeeAgentTemplates; + return override(defaultTemplate) ?? defaultTemplate; + }); } protected abstract initMemory(input: BeeRunInput): Promise; diff --git a/src/internals/helpers/object.ts b/src/internals/helpers/object.ts index 7354501b..944de46c 100644 --- a/src/internals/helpers/object.ts +++ b/src/internals/helpers/object.ts @@ -147,3 +147,14 @@ export function customMerge>( } return finalResult; } + +export function mapObj(obj: T) { + return function (fn: (key: K, value: T[K]) => T[K]): T { + const updated: T = Object.assign({}, obj); + for (const pair of Object.entries(obj)) { + const [key, value] = pair as [K, T[K]]; + updated[key] = fn(key, value); + } + return updated; + }; +}