Skip to content

Commit

Permalink
feat(tools): reduce number of iterations needed for SQLTool (#174)
Browse files Browse the repository at this point in the history
  • Loading branch information
abughali authored Nov 18, 2024
1 parent 33f1db5 commit cc221a3
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 21 deletions.
3 changes: 3 additions & 0 deletions src/tools/database/sql.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ describe("SQLTool", () => {
vi.spyOn(tool as any, "connection").mockResolvedValue(mockSequelize as unknown as Sequelize);

const result = await tool.run({
action: "QUERY",
query: "SELECT * FROM users;",
});

Expand All @@ -84,6 +85,7 @@ describe("SQLTool", () => {
});

const result = await tool.run({
action: "QUERY",
query: "DELETE FROM users;",
});

Expand Down Expand Up @@ -111,6 +113,7 @@ describe("SQLTool", () => {
vi.spyOn(tool as any, "connection").mockResolvedValue(mockSequelize as unknown as Sequelize);

const result = await tool.run({
action: "QUERY",
query: "SELECT * FROM users;",
});

Expand Down
87 changes: 66 additions & 21 deletions src/tools/database/sql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,45 @@ import {
BaseToolOptions,
BaseToolRunOptions,
JSONToolOutput,
ToolInputValidationError,
} from "@/tools/base.js";
import { z } from "zod";
import { Sequelize, Options } from "sequelize";
import { Provider, getMetadata } from "@/tools/database/metadata.js";
import { Cache } from "@/cache/decoratorCache.js";
import { ValidationError } from "ajv";
import { AnyToolSchemaLike } from "@/internals/helpers/schema.js";

interface ToolOptions extends BaseToolOptions {
provider: Provider;
connection: Options;
}

type ToolRunOptions = BaseToolRunOptions;

export const SQLToolAction = {
GetMetadata: "GET_METADATA",
Query: "QUERY",
} as const;

export class SQLTool extends Tool<JSONToolOutput<any>, ToolOptions, ToolRunOptions> {
name = "SQLTool";

description =
"Converts natural language to SQL query and executes it using the provided tables schema.";
description = `Converts natural language to SQL query and executes it. IMPORTANT: strictly follow this order of actions:
1. ${SQLToolAction.GetMetadata} - get database tables structure (metadata)
2. ${SQLToolAction.Query} - execute the generated SQL query`;

inputSchema() {
return z.object({
query: z.string({ description: "The SQL query to be executed." }).min(1),
action: z
.nativeEnum(SQLToolAction)
.describe(
`The action to perform. ${SQLToolAction.GetMetadata} get database tables structure, ${SQLToolAction.Query} execute the SQL query`,
),
query: z
.string()
.optional()
.describe(`The SQL query to be executed, required for ${SQLToolAction.Query} action`),
});
}

Expand Down Expand Up @@ -77,6 +94,18 @@ export class SQLTool extends Tool<JSONToolOutput<any>, ToolOptions, ToolRunOptio
}
}

protected validateInput(
schema: AnyToolSchemaLike,
input: unknown,
): asserts input is ToolInput<this> {
super.validateInput(schema, input);
if (input.action === SQLToolAction.Query && !input.query) {
throw new ToolInputValidationError(
`SQL Query is required for ${SQLToolAction.Query} action.`,
);
}
}

static {
this.register();
}
Expand All @@ -97,8 +126,29 @@ export class SQLTool extends Tool<JSONToolOutput<any>, ToolOptions, ToolRunOptio
}

protected async _run(
{ query }: ToolInput<this>,
_options?: ToolRunOptions,
input: ToolInput<this>,
_options: ToolRunOptions | undefined,
): Promise<JSONToolOutput<any>> {
const { provider, connection } = this.options;
const { schema } = connection;

if (input.action === SQLToolAction.GetMetadata) {
const sequelize = await this.connection();
const metadata = await getMetadata(sequelize, provider, schema);
return new JSONToolOutput(metadata);
}

if (input.action === SQLToolAction.Query) {
return await this.executeQuery(input.query!, provider, schema);
}

throw new ToolError(`Invalid action specified: ${input.action}`);
}

protected async executeQuery(
query: string,
provider: Provider,
schema: string | undefined,
): Promise<JSONToolOutput<any>> {
if (!this.isReadOnlyQuery(query)) {
return new JSONToolOutput({
Expand All @@ -107,29 +157,24 @@ export class SQLTool extends Tool<JSONToolOutput<any>, ToolOptions, ToolRunOptio
});
}

let metadata = "";
const provider = this.options.provider;
const sequelize = await this.connection();
const schema = this.options.connection.schema;

try {
metadata = await getMetadata(sequelize, provider, schema);

const sequelize = await this.connection();
const [results] = await sequelize.query(query);

if (Array.isArray(results) && results.length > 0) {
return new JSONToolOutput({ success: true, results });
} else {
return new JSONToolOutput({
success: false,
message: `No rows selected`,
});
}

return new JSONToolOutput({
success: false,
message: `No rows selected`,
});
} catch (error) {
const errorMessage = `Based on this database schema structure: ${metadata},
generate a correct query that retrieves data using the appropriate ${provider} dialect.
const schemaHint = schema
? `Fully qualify the table names by appending the schema name "${schema}" as a prefix, for example: ${schema}.table_name`
: "";
const errorMessage = `Generate a correct query that retrieves data using the appropriate ${provider} dialect.
${schemaHint}
The original request was: ${query}, and the error was: ${error.message}.`;

throw new ToolError(errorMessage);
}
}
Expand Down
1 change: 1 addition & 0 deletions tests/e2e/tools/database/sql.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ describe("SQLTool", () => {

it("Returns an error for invalid query", async () => {
const response = await instance.run({
action: "QUERY",
query: "DELETE FROM users",
});

Expand Down

0 comments on commit cc221a3

Please sign in to comment.