Skip to content

Commit

Permalink
refactor: combine exec and query database tools
Browse files Browse the repository at this point in the history
- Combine the `Run Database Query` and `Exec Database Statement` tools
  into a single `Run Database Command` tool that allows for executing both
  statements and queries. The new tool returns the raw SQL output
  instead of JSON
- Shell out to the sqlite3 binary instead of using a dependency-free go
  module
- Add the `list_database_tables` and `list_database_table_rows`
  system tools to retain database support in the user UI

Signed-off-by: Nick Hale <[email protected]>
  • Loading branch information
njhale committed Jan 27, 2025
1 parent 7e1f923 commit 5151d11
Show file tree
Hide file tree
Showing 12 changed files with 211 additions and 177 deletions.
1 change: 1 addition & 0 deletions database/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module obot-platform/database
go 1.23.3

require (
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/gptscript-ai/go-gptscript v0.9.6-0.20241023195750-c09e0f56b39b
github.com/ncruces/go-sqlite3 v0.20.3
)
Expand Down
2 changes: 2 additions & 0 deletions database/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ github.com/go-openapi/swag v0.22.8 h1:/9RjDSQ0vbFR+NyjGMkFTsA1IA0fmhKSThmfGZjicb
github.com/go-openapi/swag v0.22.8/go.mod h1:6QT22icPLEqAM/z/TChgb4WAveCHF92+2gF0CNjHpPI=
github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM=
github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
github.com/gptscript-ai/go-gptscript v0.9.6-0.20241023195750-c09e0f56b39b h1:EDd5OCtZ43YVSzKuQlXLiXCIQ6qhsrqLqY5Ows5ohlY=
github.com/gptscript-ai/go-gptscript v0.9.6-0.20241023195750-c09e0f56b39b/go.mod h1:/FVuLwhz+sIfsWUgUHWKi32qT0i6+IXlUlzs70KKt/Q=
github.com/invopop/yaml v0.2.0 h1:7zky/qH+O0DwAyoobXUqvVBwgBFRxKoQ/3FjcVpjTMY=
Expand Down
21 changes: 6 additions & 15 deletions database/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package main
import (
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"errors"
"fmt"
Expand Down Expand Up @@ -65,28 +64,20 @@ func main() {
}
}

// Open the SQLite database
db, err := sql.Open("sqlite3", dbFile.Name())
if err != nil {
fmt.Printf("Error opening DB: %v\n", err)
os.Exit(1)
}
defer db.Close()

// Run the requested command
var result string
switch command {
case "listDatabaseTables":
result, err = cmd.ListDatabaseTables(ctx, db)
case "execDatabaseStatement":
result, err = cmd.ExecDatabaseStatement(ctx, db, os.Getenv("STATEMENT"))
result, err = cmd.ListDatabaseTables(ctx, dbFile)
case "listDatabaseTableRows":
result, err = cmd.ListDatabaseTableRows(ctx, dbFile, os.Getenv("TABLE"))
case "runDatabaseCommand":
result, err = cmd.RunDatabaseCommand(ctx, dbFile, os.Getenv("SQLITE3_ARGS"))
if err == nil {
err = saveWorkspaceDB(ctx, g, dbWorkspacePath, dbFile, initialDBData)
}
case "runDatabaseQuery":
result, err = cmd.RunDatabaseQuery(ctx, db, os.Getenv("QUERY"))
case "databaseContext":
result, err = cmd.DatabaseContext(ctx, db)
result, err = cmd.DatabaseContext(ctx, dbFile)
default:
err = fmt.Errorf("unknown command: %s", command)
}
Expand Down
42 changes: 42 additions & 0 deletions database/pkg/cmd/command.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package cmd

import (
"bytes"
"context"
"fmt"
"os"
"os/exec"
"strings"

"github.com/google/shlex"
)

// RunDatabaseCommand runs a sqlite3 command against the database and returns the output from the sqlite3 CLI.
func RunDatabaseCommand(ctx context.Context, dbFile *os.File, sqlite3Args string) (string, error) {
// Remove the "sqlite3" prefix and trim whitespace
sqlite3Args = strings.TrimPrefix(strings.TrimSpace(sqlite3Args), "sqlite3")

// Split the arguments using shlex
args, err := shlex.Split(sqlite3Args)
if err != nil {
return "", fmt.Errorf("error parsing sqlite3 args: %w", err)
}

// Append the database file name as the first argument
args = append([]string{dbFile.Name()}, args...)

// Build the sqlite3 command
cmd := exec.CommandContext(ctx, "sqlite3", args...)

// Redirect command output
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr

// Run the command and capture errors
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("error executing sqlite3: %w, stderr: %s", err, stderr.String())
}

return stdout.String(), nil
}
70 changes: 29 additions & 41 deletions database/pkg/cmd/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,69 +2,57 @@ package cmd

import (
"context"
"database/sql"
"fmt"
"os"
"strings"
)

// DatabaseContext returns the context text for the SQLite tools.
// The resulting string contains all of schemas in the database.
func DatabaseContext(ctx context.Context, db *sql.DB) (string, error) {
// Build the markdown output
var out strings.Builder
out.WriteString(`# START INSTRUCTIONS: Database Tools
// DatabaseContext generates a markdown-formatted string with instructions
// and the database's current schemas.
func DatabaseContext(ctx context.Context, dbFile *os.File) (string, error) {
var builder strings.Builder

// Add usage instructions
builder.WriteString(`# START INSTRUCTIONS: Run Database Command tool
You have access to tools for interacting with a SQLite database.
The "List Database Tables" tool returns a list of tables in the database.
The "Exec Database Statement" tool only accepts valid SQLite3 statements.
The "Run Database Query" tool only accepts valid SQLite3 queries.
The "Run Database Command" tool is a wrapper around the sqlite3 CLI, and the "sqlite3_args" argument will be passed to it.
Do not include a database file argument in the "sqlite3_args" argument. The database file is automatically passed to the sqlite3 CLI.
**Ensure that all SQL statements are properly encapsulated in double quotes** to be recognized as complete inputs by the SQLite interface.
For example, use "\"CREATE TABLE example (id INTEGER);\"", not CREATE TABLE example (id INTEGER);.
This means you should wrap the entire SQL command string in double quotes, ensuring it is treated as a single argument.
Display all results from these tools and their schemas in markdown format.
If the user refers to creating or modifying tables assume they mean a SQLite3 table and not writing a table
in a markdown file.
If the user refers to creating or modifying tables, assume they mean a SQLite3 table and not writing a table in a markdown file.
# END INSTRUCTIONS: Database Tools
# END INSTRUCTIONS: Run Database Command tool
`)

// Add the schemas section
schemas, err := getSchemas(ctx, db)
schemas, err := getSchemas(ctx, dbFile)
if err != nil {
return "", fmt.Errorf("failed to retrieve schemas: %w", err)
}
if schemas != "" {
out.WriteString("# START CURRENT DATABASE SCHEMAS\n")
out.WriteString(schemas)
out.WriteString("\n# END CURRENT DATABASE SCHEMAS\n")
builder.WriteString("# START CURRENT DATABASE SCHEMAS\n")
builder.WriteString(schemas)
builder.WriteString("\n# END CURRENT DATABASE SCHEMAS\n")
} else {
out.WriteString("# DATABASE HAS NO TABLES\n")
builder.WriteString("# DATABASE HAS NO TABLES\n")
}

return out.String(), nil
return builder.String(), nil
}

// getSchemas returns an SQL string containing all schemas in the database.
func getSchemas(ctx context.Context, db *sql.DB) (string, error) {
query := "SELECT sql FROM sqlite_master WHERE type IN ('table', 'index', 'view', 'trigger') AND name NOT LIKE 'sqlite_%' ORDER BY name"
// getSchemas retrieves all schemas from the database using the sqlite3 CLI.
func getSchemas(ctx context.Context, dbFile *os.File) (string, error) {
query := `SELECT sql FROM sqlite_master WHERE type IN ('table', 'index', 'view', 'trigger') AND name NOT LIKE 'sqlite_%' ORDER BY name;`

rows, err := db.QueryContext(ctx, query)
// Execute the query using the RunDatabaseCommand function
output, err := RunDatabaseCommand(ctx, dbFile, fmt.Sprintf("%q", query))
if err != nil {
return "", fmt.Errorf("failed to query sqlite_master: %w", err)
}
defer rows.Close()

var out strings.Builder
for rows.Next() {
var schema string
if err := rows.Scan(&schema); err != nil {
return "", fmt.Errorf("failed to scan schema: %w", err)
}
if schema != "" {
out.WriteString(fmt.Sprintf("\n%s\n", schema))
}
}

if rows.Err() != nil {
return "", fmt.Errorf("error iterating over schemas: %w", rows.Err())
return "", fmt.Errorf("error querying schemas: %w", err)
}

return out.String(), nil
// Return raw output as-is
return strings.TrimSpace(output), nil
}
16 changes: 0 additions & 16 deletions database/pkg/cmd/exec.go

This file was deleted.

62 changes: 0 additions & 62 deletions database/pkg/cmd/query.go

This file was deleted.

65 changes: 65 additions & 0 deletions database/pkg/cmd/rows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package cmd

import (
"context"
"encoding/json"
"fmt"
"os"
"strings"
)

type Output struct {
Columns []string `json:"columns"`
Rows []map[string]any `json:"rows"`
}

// ListDatabaseTableRows lists all rows from the specified table using RunDatabaseCommand and returns a JSON object containing the results.
func ListDatabaseTableRows(ctx context.Context, dbFile *os.File, table string) (string, error) {
if table == "" {
return "", fmt.Errorf("table name cannot be empty")
}

// Build the query to fetch all rows from the table
query := fmt.Sprintf("SELECT * FROM %q;", table)

// Execute the query using RunDatabaseCommand
rawOutput, err := RunDatabaseCommand(ctx, dbFile, fmt.Sprintf("%q", query))
if err != nil {
return "", fmt.Errorf("error executing query for table %q: %w", table, err)
}

// Split raw output into rows
lines := strings.Split(strings.TrimSpace(rawOutput), "\n")
if len(lines) == 0 {
return "", fmt.Errorf("no output from query for table %q", table)
}

// The first line contains column names
columns := strings.Split(lines[0], "|")
output := Output{
Columns: columns,
Rows: []map[string]any{},
}

// Process the remaining lines as rows
for _, line := range lines[1:] {
values := strings.Split(line, "|")
rowData := map[string]any{}
for i, col := range columns {
if i < len(values) {
rowData[col] = values[i]
} else {
rowData[col] = nil
}
}
output.Rows = append(output.Rows, rowData)
}

// Marshal the result to JSON
content, err := json.Marshal(output)
if err != nil {
return "", fmt.Errorf("error marshalling output to JSON: %w", err)
}

return string(content), nil
}
Loading

0 comments on commit 5151d11

Please sign in to comment.