Skip to content

Commit

Permalink
fixup! feat(agents): improve template overriding
Browse files Browse the repository at this point in the history
Signed-off-by: Radek Ježek <[email protected]>
  • Loading branch information
jezekra1 committed Jan 10, 2025
1 parent ddc142a commit 95c2424
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 24 deletions.
13 changes: 6 additions & 7 deletions src/agents/bee/runners/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import { Serializable } from "@/internals/serializable.js";
import {
BeeAgentRunIteration,
BeeAgentTemplates,
BeeCallbacks,
BeeIterationToolResult,
BeeMeta,
Expand All @@ -33,6 +32,8 @@ import { GetRunContext } from "@/context.js";
import { Emitter } from "@/emitter/emitter.js";
import * as R from "remeda";
import { PromptTemplate } from "@/template.js";
import type { BeeAgentTemplates } from "@/agents/bee/types.js";
import { Cache } from "@/cache/decoratorCache.js";

export interface BeeRunnerLLMInput {
meta: BeeMeta;
Expand All @@ -52,16 +53,13 @@ export abstract class BaseRunner extends Serializable {
public readonly iterations: BeeAgentRunIteration[] = [];
protected readonly failedAttemptsCounter: RetryCounter;

public templates: BeeAgentTemplates;

constructor(
protected readonly input: BeeInput,
protected readonly options: BeeRunOptions,
protected readonly run: GetRunContext<BeeAgent>,
) {
super();
this.failedAttemptsCounter = new RetryCounter(options?.execution?.totalMaxRetries, AgentError);
this.templates = this.resolveTemplates();
}

async createIteration() {
Expand Down Expand Up @@ -100,7 +98,7 @@ export abstract class BaseRunner extends Serializable {

abstract tool(input: BeeRunnerToolInput): Promise<{ output: string; success: boolean }>;

protected abstract get _defaultTemplates(): BeeAgentTemplates;
protected abstract get defaultTemplates(): BeeAgentTemplates;

protected resolveInputTemplate<T extends BeeAgentTemplates[keyof BeeAgentTemplates]>(
defaultTemplate: T,
Expand All @@ -112,9 +110,10 @@ export abstract class BaseRunner extends Serializable {
return update instanceof PromptTemplate ? update : update(defaultTemplate);
}

protected resolveTemplates(): BeeAgentTemplates {
@Cache({ enumerable: false })
public get templates(): BeeAgentTemplates {
const templatesUpdate = this.input.templates ?? {};
return R.mapValues(this._defaultTemplates, (template, key) =>
return R.mapValues(this.defaultTemplates, (template, key) =>
this.resolveInputTemplate(template, templatesUpdate[key] as typeof template | undefined),
) as BeeAgentTemplates;
}
Expand Down
3 changes: 2 additions & 1 deletion src/agents/bee/runners/default/runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ import { Cache } from "@/cache/decoratorCache.js";
import { shallowCopy } from "@/serializer/utils.js";

export class DefaultRunner extends BaseRunner {
protected get _defaultTemplates() {
@Cache({ enumerable: false })
protected get defaultTemplates() {
return {
system: BeeSystemPrompt,
assistant: BeeAssistantPrompt,
Expand Down
4 changes: 3 additions & 1 deletion src/agents/bee/runners/granite/runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ import {
GraniteBeeUserPrompt,
} from "@/agents/bee/runners/granite/prompts.js";
import { BeeToolNoResultsPrompt, BeeUserEmptyPrompt } from "@/agents/bee/prompts.js";
import { Cache } from "@/cache/decoratorCache.js";

export class GraniteRunner extends DefaultRunner {
protected get _defaultTemplates() {
@Cache({ enumerable: false })
protected get defaultTemplates() {
return {
system: GraniteBeeSystemPrompt,
assistant: GraniteBeeAssistantPrompt,
Expand Down
14 changes: 4 additions & 10 deletions src/template.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -191,15 +191,14 @@ describe("Prompt Template", () => {
<T extends ZodType>(template: PromptTemplate<T>) =>
template.fork((config) => ({
...config,
template: "Hello {{name}}!",
template: "Hi {{name}}!",
customTags: ["{{", "}}"],
functions: { formatDate: () => "Today" },
})),
<T extends ZodType>(template: PromptTemplate<T>) =>
template.fork((config) => {
config.template = "Hello {{name}}!";
config.template = "Hi {{name}}!";
config.customTags = ["{{", "}}"];
config.functions.formatDate = () => "Today";
}),
])("Forks", (forkFn) => {
const template = new PromptTemplate({
Expand All @@ -211,13 +210,8 @@ describe("Prompt Template", () => {
escape: false,
});
const forked = forkFn(template);
expect(template.render({ name: "Tomas" })).toEqual(forked.render({ name: "Tomas" }));
// Configs are deeply copied
// @ts-expect-error protected property
const [templateConfig, forkedConfig] = [template.config, forked.config];
expect(templateConfig.template).not.toEqual(forkedConfig.template);
expect(templateConfig.functions).not.toEqual(forkedConfig.functions);
expect(templateConfig).not.toEqual(forkedConfig);
expect(template.render({ name: "Tomas" })).toEqual("Hello Tomas!");
expect(forked.render({ name: "Tomas" })).toEqual("Hi Tomas!");
});
});
test("Custom function", () => {
Expand Down
12 changes: 7 additions & 5 deletions src/template.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
* limitations under the License.
*/

import { FrameworkError } from "@/errors.js";
import { FrameworkError, ValueError } from "@/errors.js";
import { ObjectLike, PlainObject } from "@/internals/types.js";
import * as R from "remeda";
import { pickBy } from "remeda";
import { clone, identity, isPlainObject, pickBy } from "remeda";
import Mustache from "mustache";
import { Serializable } from "@/internals/serializable.js";
import { z, ZodType } from "zod";
Expand Down Expand Up @@ -121,8 +120,11 @@ export class PromptTemplate<T extends ZodType> extends Serializable {
fork<R extends ZodType>(
customizer: Customizer<T, SchemaObject> | Customizer<T, R>,
): PromptTemplate<T | R> {
const config = R.clone(this.config);
const config = clone(this.config);
const newConfig = customizer?.(config) ?? config;
if (!isPlainObject(newConfig)) {
throw new ValueError("Return type from customizer must be a config or nothing.");
}
return new PromptTemplate(newConfig);
}

Expand All @@ -146,7 +148,7 @@ export class PromptTemplate<T extends ZodType> extends Serializable {
{
tags: this.config.customTags,
...(!this.config.escape && {
escape: R.identity(),
escape: identity(),
}),
},
);
Expand Down

0 comments on commit 95c2424

Please sign in to comment.