Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(tool): add sql database query tool #24

Merged
merged 10 commits into from
Sep 24, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat(tool): add sql database query tool
Mahmoud Abughali committed Sep 22, 2024
commit fc1baa3435f791fd89b5d9002ee55f4374917d81
23 changes: 19 additions & 4 deletions .env.template
Original file line number Diff line number Diff line change
@@ -3,14 +3,29 @@ BEE_FRAMEWORK_LOG_PRETTY=true
BEE_FRAMEWORK_LOG_LEVEL="info"
BEE_FRAMEWORK_LOG_SINGLE_LINE="false"

# Tools
CODE_INTERPRETER_URL=http://127.0.0.1:50051

# For BAM LLM Adapter
# GENAI_API_KEY=

# For OpenAI LLM Adapter
# OPENAI_API_KEY=

# For Groq LLM Adapter
# GROQ_API_KEY=
# GROQ_API_KEY=

# Tools
CODE_INTERPRETER_URL=http://127.0.0.1:50051

# SQL Tool Connection Properties

# For SQLite
# DB_DIALECT=sqlite
# DB_STORAGE=sqlite_file.db

# For other databases
# DB_NAME=database_name
# DB_USERNAME=username
# DB_PASSWORD=password
# DB_HOST=host
# DB_PORT=1521
# DB_DIALECT=oracle # Or mariadb, mysql, mssql, postgres, db2
# DB_SCHEMA=schema_name # Optional for mariadb, mysql, mssql, postgres
26 changes: 14 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
@@ -95,18 +95,20 @@ To run this example, be sure that you have installed [ollama](https://ollama.com

### 🛠️ Tools

| Name | Description |
| ------------------------------------------------------------------------- | ------------------------------------------------------------------------- |
| `PythonTool` | Run arbitrary Python code in the remote environment. |
| `WikipediaTool` | Search for data on Wikipedia. |
| `DuckDuckGoTool` | Search for data on DuckDuckGo. |
| `LLMTool` | Uses an LLM to process input data. |
| `DynamicTool` | Construct to create dynamic tools. |
| `ArXivTool` | Retrieves research articles published on arXiv. |
| `WebCrawlerTool` | Retrieves content of an arbitrary website. |
| `CustomTool` | Runs your own Python function in the remote environment. |
| `OpenMeteoTool` | Retrieves current, previous, or upcoming weather for a given destination. |
| ➕ [Request](https://github.com/i-am-bee/bee-agent-framework/discussions) | |
| Name | Description |
| ------------------------------------------------------------------------- | ------------------------------------------------------------------------------------ |
| `PythonTool` | Run arbitrary Python code in the remote environment. |
| `WikipediaTool` | Search for data on Wikipedia. |
| `DuckDuckGoTool` | Search for data on DuckDuckGo. |
| `SQLTool` | Executing SQL queries against various databases. [Instructions](./docs/sql-tool.md). |
| `CustomTool` | Runs your own Python function in the remote environment. |
| `LLMTool` | Uses an LLM to process input data. |
| `DynamicTool` | Construct to create dynamic tools. |
| `ArXivTool` | Retrieves research articles published on arXiv. |
| `WebCrawlerTool` | Retrieves content of an arbitrary website. |
| `CustomTool` | Runs your own Python function in the remote environment. |
| `OpenMeteoTool` | Retrieves current, previous, or upcoming weather for a given destination. |
| ➕ [Request](https://github.com/i-am-bee/bee-agent-framework/discussions) | |

### 🔌️ Adapters (LLM - Inference providers)

149 changes: 149 additions & 0 deletions docs/sql-tool.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# 🛢️ SQLTool

## Description

This tool executes SQL queries and guides the agent in constructing correct queries based on the structures of all available tables. It simplifies data retrieval, allowing users without advanced SQL knowledge to effectively query databases using natural language. The tool is designed for safety, allowing only `SELECT` queries to prevent any accidental changes to the database.

## Requirements

To use SQLTool with different databases, ensure you have the following:

- **Sequelize**: Version 6
- **Database Connector Package**: The appropriate package for your database (e.g., `ibm_db`, `mysql2`, `sqlite3`, etc.)

## Installation

Follow the table below to install the required packages for your database:

| Database | Required Package(s) |
| -------------------- | --------------------- |
| IBM Db2 for LUW | `node-gyp`, `ibm_db` |
| SQLite | `sqlite3` |
| MariaDB | `mariadb` (version 2) |
| MySQL | `mysql2` |
| PostgreSQL | `pg` |
| Microsoft SQL Server | `tedious` |
| Oracle | `oracledb` |

To install Sequelize, run the following command:

```bash
yarn add sequelize
```

Depending on the database you're using, install the required package(s). For example, if you're using IBM Db2 for LUW, run:

```bash
yarn add node-gyp ibm_db
```

Replace `node-gyp` `ibm_db` with the appropriate package(s) for your database from the table above.

## Database Connection Configuration

Configure the connection to your database using environment variables. Sample configuration is provided in `.env.template` file to guide you in setting up your `.env` file with appropriate values for your database.

```bash
# SQL Tool Connection Properties

# For SQLite
DB_DIALECT=sqlite
DB_STORAGE=sqlite_file.db

# For other databases
DB_NAME=database_name
DB_USERNAME=username
DB_PASSWORD=password
DB_HOST=host
DB_PORT=1521
DB_DIALECT=oracle # Or mariadb, mysql, mssql, postgres, db2
DB_SCHEMA=schema_name # Optional for mariadb, mysql, mssql, postgres
```

### Usage

In order to use the SQLTool, you need to pass the database provider name to the tool.
Supported values are: `mysql`, `mariadb`, `postgres`, `mssql`, `db2`, `sqlite`, `oracle`.

```js
import { SQLTool } from "bee-agent-framework/tools/database/sql";

const sqlTool = new SQLTool({
provider: "sqlite",
});
```

## Examples

Below are examples showcasing how to perform a search on the [chinook](https://github.com/lerocha/chinook-database) SQLite database:

**_Question 1_**

_Which country's customers spent the most?_

```
Agent (thought) 🤖 : I need to access the database to find the answer, and I have a SQLTool that can help me with that.
Agent (tool_name) 🤖 : SQLTool
Agent (tool_caption) 🤖 : Executing SQL query to find the country with the highest customer spending
Agent (tool_input) 🤖 : {"query": "SELECT country, SUM(amount) as total_spent FROM customer_spending GROUP BY country ORDER BY total_spent DESC LIMIT 1"}
Agent (tool_output) 🤖 : {"error":"Based on this database schema structure: Table 'album' with columns: album_id (INTEGER), title (NVARCHAR(160)), artist_id (INTEGER); Table 'artist' with columns: artist_id (INTEGER), name (NVARCHAR(120)); Table 'customer' with columns: customer_id (INTEGER), first_name (NVARCHAR(40)), last_name (NVARCHAR(20)), company (NVARCHAR(80)), address (NVARCHAR(70)), city (NVARCHAR(40)), state (NVARCHAR(40)), country (NVARCHAR(40)), postal_code (NVARCHAR(10)), phone (NVARCHAR(24)), fax (NVARCHAR(24)), email (NVARCHAR(60)), support_rep_id (INTEGER); Table 'employee' with columns: employee_id (INTEGER), last_name (NVARCHAR(20)), first_name (NVARCHAR(20)), title (NVARCHAR(30)), reports_to (INTEGER), birthdate (DATETIME), hire_date (DATETIME), address (NVARCHAR(70)), city (NVARCHAR(40)), state (NVARCHAR(40)), country (NVARCHAR(40)), postal_code (NVARCHAR(10)), phone (NVARCHAR(24)), fax (NVARCHAR(24)), email (NVARCHAR(60)); Table 'genre' with columns: genre_id (INTEGER), name (NVARCHAR(120)); Table 'invoice' with columns: invoice_id (INTEGER), customer_id (INTEGER), invoice_date (DATETIME), billing_address (NVARCHAR(70)), billing_city (NVARCHAR(40)), billing_state (NVARCHAR(40)), billing_country (NVARCHAR(40)), billing_postal_code (NVARCHAR(10)), total (NUMERIC(10,2)); Table 'invoice_line' with columns: invoice_line_id (INTEGER), invoice_id (INTEGER), track_id (INTEGER), unit_price (NUMERIC(10,2)), quantity (INTEGER); Table 'media_type' with columns: media_type_id (INTEGER), name (NVARCHAR(120)); Table 'playlist' with columns: playlist_id (INTEGER), name (NVARCHAR(120)); Table 'playlist_track' with columns: playlist_id (INTEGER), track_id (INTEGER); Table 'track' with columns: track_id (INTEGER), name (NVARCHAR(200)), album_id (INTEGER), media_type_id (INTEGER), genre_id (INTEGER), composer (NVARCHAR(220)), milliseconds (INTEGER), bytes (INTEGER), unit_price (NUMERIC(10,2)), \n generate a correct query that retrieves data using the appropriate sqlite dialect. \n The original request was: SELECT country, SUM(amount) as total_spent FROM customer_spending GROUP BY country ORDER BY total_spent DESC LIMIT 1, and the error was: SQLITE_ERROR: no such table: customer_spending."}
```

The agent initially retrieved incorrect information from the database, but corrected itself on the second attempt and successfully generated the correct query.

```
Agent (thought) 🤖 : It seems that the SQLTool didn't work as expected, and I need to adjust the query to match the available tables in the database.
Agent (tool_name) 🤖 : SQLTool
Agent (tool_caption) 🤖 : Executing SQL query to find the country with the highest customer spending
Agent (tool_input) 🤖 : {"query":"SELECT i.billing_country, SUM(i.total) as total_spent FROM invoice i GROUP BY i.billing_country ORDER BY total_spent DESC LIMIT 1"}
Agent (tool_output) 🤖 : {"success":true,"results":[{"billing_country":"USA","total_spent":1040.49}]}
Agent (thought) 🤖 : The SQLTool was successful this time, and I have the answer to the question.
Agent (final_answer) 🤖 : The country whose customers spent the most is the USA, with a total spending of $1040.49.
```

**_Generated SQL_**

```sql
SELECT country,
SUM(amount) AS total_spent
FROM customer_spending
GROUP BY country
ORDER BY total_spent DESC
LIMIT 1;
```

**_Answer_**

_The country whose customers spent the most is the USA, with a total spending of $1040.49._

**_Question 2_**

_Show the top 3 best selling artists in terms of revenue._

**_Generated SQL_**

```sql
SELECT a.name,
SUM(il.unit_price * il.quantity) AS total_revenue
FROM invoice_line il
JOIN
track t ON il.track_id = t.track_id
JOIN
album al ON t.album_id = al.album_id
JOIN
artist a ON al.artist_id = a.artist_id
GROUP BY a.name
ORDER BY total_revenue DESC
LIMIT 3;
```

**_Answer_**

_The top 3 best selling artists in terms of revenue are:_
_- Queen with a total revenue of \$190.08_
_- Jimi Hendrix with a total revenue of \$185.13_
_- Red Hot Chili Peppers with a total revenue of \$128.77._

## Sample Agent

A complete sample of an SQL agent is available [here](../examples/agents/sql.ts).
57 changes: 57 additions & 0 deletions examples/agents/sql.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import "dotenv/config.js";
import { BeeAgent } from "bee-agent-framework/agents/bee/agent";
import { GroqChatLLM } from "bee-agent-framework/adapters/groq/chat";
import { SQLTool } from "bee-agent-framework/tools/database/sql";
import { FrameworkError } from "bee-agent-framework/errors";
import { UnconstrainedMemory } from "bee-agent-framework/memory/unconstrainedMemory";

const llm = new GroqChatLLM({
modelId: "llama3-70b-8192",
parameters: {
temperature: 0,
max_tokens: 8192,
},
});

const sqlTool = new SQLTool({
provider: "sqlite",
});

const agent = new BeeAgent({
llm,
memory: new UnconstrainedMemory(),
tools: [sqlTool],
});

const question = "which country's customers spent the most?";

try {
const response = await agent
.run(
{ prompt: "From the database: " + question },
abughali marked this conversation as resolved.
Show resolved Hide resolved
{
execution: {
maxRetriesPerStep: 5,
totalMaxRetries: 10,
maxIterations: 15,
},
},
)
.observe((emitter) => {
emitter.on("error", ({ error }) => {
console.log(`Agent 🤖 : `, FrameworkError.ensure(error).dump());
});
emitter.on("retry", () => {
console.log(`Agent 🤖 : `, "retrying the action...");
});
emitter.on("update", async ({ data, update, meta }) => {
console.log(`Agent (${update.key}) 🤖 : `, update.value);
});
});

console.log(`Agent 🤖 : `, response.result.text);
} catch (error) {
console.error(FrameworkError.ensure(error).dump());
} finally {
process.exit(0);
}
6 changes: 5 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -119,7 +119,9 @@
"groq-sdk": "^0.7.0",
"ollama": "^0.5.8",
"openai": "^4.56.0",
"openai-chat-tokens": "^0.2.8"
"openai-chat-tokens": "^0.2.8",
"sequelize": "^6.37.3",
"sqlite3": "^5.1.7"
abughali marked this conversation as resolved.
Show resolved Hide resolved
},
"devDependencies": {
"@commitlint/cli": "^19.4.1",
@@ -162,6 +164,8 @@
"prettier": "^3.3.3",
"release-it": "^17.6.0",
"rimraf": "^6.0.1",
"sequelize": "^6.37.3",
"sqlite3": "^5.1.7",
"strip-ansi": "^7.1.0",
"temp-dir": "^3.0.0",
"tsc-files": "^1.1.4",
56 changes: 56 additions & 0 deletions src/tools/database/connection.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/**
* Copyright 2024 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { Sequelize, Dialect } from "sequelize";

let dbSchema: string | undefined = undefined;

export async function connectSql(): Promise<Sequelize> {
abughali marked this conversation as resolved.
Show resolved Hide resolved
try {
const dbName = process.env.DB_NAME as string;
const dbUser = process.env.DB_USERNAME as string;
const dbPassword = process.env.DB_PASSWORD;
const dbHost = process.env.DB_HOST;
const dbPort = process.env.DB_PORT ? Number(process.env.DB_PORT) : undefined;
const dbDialect = process.env.DB_DIALECT as Dialect;
const dbStorage = process.env.DB_STORAGE;

dbSchema = process.env.DB_SCHEMA;

const sequelize = new Sequelize(dbName, dbUser, dbPassword, {
host: dbHost,
dialect: dbDialect,
port: dbPort,
schema: dbSchema,
storage: dbStorage,
logging: false,
});

await sequelize.authenticate();
return sequelize;
} catch (error) {
throw Error(`Unable to connect to the SQL database: ${error.message}`);
abughali marked this conversation as resolved.
Show resolved Hide resolved
}
}

export function getSchema(): string | undefined {
return dbSchema || undefined;
}

export function isSupported(provider: string): boolean {
const supportedProviders = ["mysql", "mariadb", "postgres", "mssql", "db2", "sqlite", "oracle"];
return supportedProviders.includes(provider);
}
Loading