Skip to content

Commit

Permalink
fix: avoid mutating construct options
Browse files Browse the repository at this point in the history
  • Loading branch information
Mahmoud Abughali committed Sep 27, 2024
1 parent f52e189 commit f80c33d
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 64 deletions.
66 changes: 41 additions & 25 deletions src/tools/search/googleCustomSearch.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ describe("GoogleCustomSearch Tool", () => {
maxResultsPerPage: 10,
});

(googleSearchTool as any).client = mockCustomSearchClient;
Object.defineProperty(googleSearchTool, "client", {
get: () => mockCustomSearchClient,
});
});

const generateResults = (count: number) => {
Expand Down Expand Up @@ -71,14 +73,18 @@ describe("GoogleCustomSearch Tool", () => {

expect(response).toBeInstanceOf(GoogleSearchToolOutput);
expect(response.results.length).toBe(3);
expect(mockCustomSearchClient.cse.list).toHaveBeenCalledWith({
cx: "test-cse-id",
q: query,
auth: "test-api-key",
num: 10,
start: 1,
safe: "active",
});
expect(mockCustomSearchClient.cse.list).toHaveBeenCalledWith(
{
cx: "test-cse-id",
q: query,
num: 10,
start: 1,
safe: "active",
},
{
signal: undefined,
},
);
});

it("validates maxResultsPerPage range", () => {
Expand Down Expand Up @@ -117,22 +123,32 @@ describe("GoogleCustomSearch Tool", () => {
expect(combinedResults.length).toBe(20);

expect(mockCustomSearchClient.cse.list).toHaveBeenCalledTimes(2);
expect(mockCustomSearchClient.cse.list).toHaveBeenNthCalledWith(1, {
cx: "test-cse-id",
q: query,
auth: "test-api-key",
num: 10,
start: 1,
safe: "active",
});
expect(mockCustomSearchClient.cse.list).toHaveBeenNthCalledWith(2, {
cx: "test-cse-id",
q: query,
auth: "test-api-key",
num: 10,
start: 11,
safe: "active",
});
expect(mockCustomSearchClient.cse.list).toHaveBeenNthCalledWith(
1,
{
cx: "test-cse-id",
q: query,
num: 10,
start: 1,
safe: "active",
},
{
signal: undefined,
},
);
expect(mockCustomSearchClient.cse.list).toHaveBeenNthCalledWith(
2,
{
cx: "test-cse-id",
q: query,
num: 10,
start: 11,
safe: "active",
},
{
signal: undefined,
},
);
});

it("Serializes", async () => {
Expand Down
90 changes: 51 additions & 39 deletions src/tools/search/googleCustomSearch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,17 @@ import {
SearchToolResult,
SearchToolRunOptions,
} from "./base.js";
import { Tool, ToolInput, ToolError } from "@/tools/base.js";
import { Tool, ToolInput } from "@/tools/base.js";
import { z } from "zod";
import { Cache } from "@/cache/decoratorCache.js";
import { ValueError } from "@/errors.js";
import { ValidationError } from "ajv";
import { getEnv } from "@/internals/env.js";
import { parseEnv } from "@/internals/env.js";

export interface GoogleSearchToolOptions extends SearchToolOptions {
apiKey?: string;
cseId?: string;
maxResultsPerPage?: number;
maxResultsPerPage: number;
}

type GoogleSearchToolRunOptions = SearchToolRunOptions;
Expand Down Expand Up @@ -67,12 +68,10 @@ export class GoogleSearchTool extends Tool<
Useful for when you need to answer questions or find relevant content on all or specific websites.
Output is a list of relevant websites with descriptions.`;

protected readonly client: GoogleSearchAPI.Customsearch;

@Cache()
inputSchema() {
return z.object({
query: z.string({ description: `Search query` }).min(1).max(128),
query: z.string({ description: `Search query` }).min(1).max(2048),
page: z
.number()
.int()
Expand All @@ -86,60 +85,62 @@ export class GoogleSearchTool extends Tool<
});
}

public constructor(options: GoogleSearchToolOptions) {
protected apiKey: string | undefined;
protected cseId: string | undefined;

public constructor(options: GoogleSearchToolOptions = { maxResultsPerPage: 10 }) {
super(options);

const apiKey = options.apiKey || getEnv("GOOGLE_API_KEY");
const cseId = options.cseId || getEnv("GOOGLE_CSE_ID");
this.apiKey = options.apiKey || parseEnv("GOOGLE_API_KEY", z.string());
this.cseId = options.cseId || parseEnv("GOOGLE_CSE_ID", z.string());

if (!apiKey || !cseId) {
throw new ToolError(
`"apiKey" or "cseId" must both be provided. Either set them directly or put them in ENV ("GOOGLE_API_KEY" / "GOOGLE_CSE_ID")`,
[],
{ isFatal: true, isRetryable: false },
if (!this.apiKey || !this.cseId) {
throw new ValueError(
[
`"apiKey" or "cseId" must both be provided.`,
`Either set them directly or put them in ENV ("WATSONX_ACCESS_TOKEN" / "WATSONX_API_KEY")`,
].join("\n"),
);
}

this.options.maxResultsPerPage = options.maxResultsPerPage ?? 10;

if (this.options.maxResultsPerPage < 1 || this.options.maxResultsPerPage > 10) {
if (options.maxResultsPerPage < 1 || options.maxResultsPerPage > 10) {
throw new ValidationError([
{
message: "Property range must be between 1 and 10",
propertyName: "options.maxResultsPerPage",
},
]);
}
}

this.options.apiKey = apiKey;
this.options.cseId = cseId;
this.client = this._createClient();
@Cache()
protected get client(): GoogleSearchAPI.Customsearch {
return new GoogleSearchAPI.Customsearch({
auth: this.apiKey,
});
}

static {
this.register();
}

protected _createClient() {
return new GoogleSearchAPI.Customsearch({
auth: this.options.apiKey,
});
}

protected async _run(
{ query: input, page = 1 }: ToolInput<this>,
_options?: GoogleSearchToolRunOptions,
options?: GoogleSearchToolRunOptions,
) {
const startIndex = (page - 1) * this.options.maxResultsPerPage! + 1;

const response = await this.client.cse.list({
cx: this.options.cseId,
q: input,
auth: this.options.apiKey,
num: this.options.maxResultsPerPage,
start: startIndex,
safe: "active",
});
const startIndex = (page - 1) * this.options.maxResultsPerPage + 1;
const response = await this.client.cse.list(
{
cx: this.cseId,
q: input,
num: this.options.maxResultsPerPage,
start: startIndex,
safe: "active",
},
{
signal: options?.signal,
},
);

const results = response.data.items || [];

Expand All @@ -152,8 +153,19 @@ export class GoogleSearchTool extends Tool<
);
}

loadSnapshot(snapshot: ReturnType<typeof this.createSnapshot>): void {
createSnapshot() {
return {
...super.createSnapshot(),
apiKey: this.apiKey,
cseId: this.cseId,
};
}

loadSnapshot({ apiKey, cseId, ...snapshot }: ReturnType<typeof this.createSnapshot>) {
super.loadSnapshot(snapshot);
Object.assign(this, { client: this._createClient() });
Object.assign(this, {
apiKey,
cseId,
});
}
}

0 comments on commit f80c33d

Please sign in to comment.