Skip to content

Commit

Permalink
refactor: combine exec and query database tools
Browse files Browse the repository at this point in the history
- Combine the `Run Database Query` and `Exec Database Statement` tools
  into a single `Run Database Command` tool that allows for executing both
  statements and queries. The new tool returns the raw SQL output
  instead of JSON
- Shell out to the sqlite3 binary instead of using a dependency-free go
  module
- Add the `list_database_tables` and `list_database_table_rows`
  tools to retain database support in the user UI.

Signed-off-by: Nick Hale <[email protected]>
  • Loading branch information
njhale committed Jan 24, 2025
1 parent 7e1f923 commit 3a13f9e
Show file tree
Hide file tree
Showing 10 changed files with 192 additions and 177 deletions.
1 change: 1 addition & 0 deletions database/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module obot-platform/database
go 1.23.3

require (
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/gptscript-ai/go-gptscript v0.9.6-0.20241023195750-c09e0f56b39b
github.com/ncruces/go-sqlite3 v0.20.3
)
Expand Down
2 changes: 2 additions & 0 deletions database/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ github.com/go-openapi/swag v0.22.8 h1:/9RjDSQ0vbFR+NyjGMkFTsA1IA0fmhKSThmfGZjicb
github.com/go-openapi/swag v0.22.8/go.mod h1:6QT22icPLEqAM/z/TChgb4WAveCHF92+2gF0CNjHpPI=
github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM=
github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
github.com/gptscript-ai/go-gptscript v0.9.6-0.20241023195750-c09e0f56b39b h1:EDd5OCtZ43YVSzKuQlXLiXCIQ6qhsrqLqY5Ows5ohlY=
github.com/gptscript-ai/go-gptscript v0.9.6-0.20241023195750-c09e0f56b39b/go.mod h1:/FVuLwhz+sIfsWUgUHWKi32qT0i6+IXlUlzs70KKt/Q=
github.com/invopop/yaml v0.2.0 h1:7zky/qH+O0DwAyoobXUqvVBwgBFRxKoQ/3FjcVpjTMY=
Expand Down
21 changes: 6 additions & 15 deletions database/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package main
import (
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"errors"
"fmt"
Expand Down Expand Up @@ -65,28 +64,20 @@ func main() {
}
}

// Open the SQLite database
db, err := sql.Open("sqlite3", dbFile.Name())
if err != nil {
fmt.Printf("Error opening DB: %v\n", err)
os.Exit(1)
}
defer db.Close()

// Run the requested command
var result string
switch command {
case "listDatabaseTables":
result, err = cmd.ListDatabaseTables(ctx, db)
case "execDatabaseStatement":
result, err = cmd.ExecDatabaseStatement(ctx, db, os.Getenv("STATEMENT"))
result, err = cmd.ListDatabaseTables(ctx, dbFile)
case "listDatabaseTableRows":
result, err = cmd.ListDatabaseTableRows(ctx, dbFile, os.Getenv("TABLE"))
case "runDatabaseCommand":
result, err = cmd.RunDatabaseCommand(ctx, dbFile, os.Getenv("SQLITE3_ARGS"))
if err == nil {
err = saveWorkspaceDB(ctx, g, dbWorkspacePath, dbFile, initialDBData)
}
case "runDatabaseQuery":
result, err = cmd.RunDatabaseQuery(ctx, db, os.Getenv("QUERY"))
case "databaseContext":
result, err = cmd.DatabaseContext(ctx, db)
result, err = cmd.DatabaseContext(ctx, dbFile)
default:
err = fmt.Errorf("unknown command: %s", command)
}
Expand Down
39 changes: 39 additions & 0 deletions database/pkg/cmd/command.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package cmd

import (
"bytes"
"context"
"fmt"
"os"
"os/exec"
"strings"

"github.com/google/shlex"
)

// RunDatabaseCommand runs a sqlite3 command against the database and returns the output from the sqlite3 cli.
func RunDatabaseCommand(ctx context.Context, dbFile *os.File, sqlite3Args string) (string, error) {
// Remove the "sqlite3" prefix and trim whitespace
sqlite3Args = strings.TrimPrefix(strings.TrimSpace(sqlite3Args), "sqlite3")

// Split the arguments
args, err := shlex.Split(sqlite3Args)
if err != nil {
return "", fmt.Errorf("error parsing sqlite3 args: %w", err)
}

// Build the sqlite3 command
cmd := exec.CommandContext(ctx, "sqlite3", append([]string{dbFile.Name()}, args...)...)

// Redirect the command output
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr

// Run the command
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("error executing sqlite3: %w, stderr: %s", err, stderr.String())
}

return stdout.String(), nil
}
68 changes: 27 additions & 41 deletions database/pkg/cmd/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,69 +2,55 @@ package cmd

import (
"context"
"database/sql"
"fmt"
"os"
"strings"
)

// DatabaseContext returns the context text for the SQLite tools.
// The resulting string contains all of schemas in the database.
func DatabaseContext(ctx context.Context, db *sql.DB) (string, error) {
// Build the markdown output
var out strings.Builder
out.WriteString(`# START INSTRUCTIONS: Database Tools
// DatabaseContext generates a markdown-formatted string with instructions
// and the database's current schemas.
func DatabaseContext(ctx context.Context, dbFile *os.File) (string, error) {
var builder strings.Builder

// Add usage instructions
builder.WriteString(`# START INSTRUCTIONS: Run Database Command tool
You have access to tools for interacting with a SQLite database.
The "List Database Tables" tool returns a list of tables in the database.
The "Exec Database Statement" tool only accepts valid SQLite3 statements.
The "Run Database Query" tool only accepts valid SQLite3 queries.
The "Run Database Command" tool is a wrapper around the sqlite3 CLI, and the "sqlite3_args" argument will be passed to it.
Do not include a database file argument in the "sqlite3_args" argument. The database file is automatically passed to the sqlite3 CLI.
Ensure SQL statements are properly encapsulated in quotes to be recognized as complete inputs by the SQLite interface.
Display all results from these tools and their schemas in markdown format.
If the user refers to creating or modifying tables assume they mean a SQLite3 table and not writing a table
in a markdown file.
If the user refers to creating or modifying tables, assume they mean a SQLite3 table and not writing a table in a markdown file.
# END INSTRUCTIONS: Database Tools
# END INSTRUCTIONS: Run Database Command tool
`)

// Add the schemas section
schemas, err := getSchemas(ctx, db)
schemas, err := getSchemas(ctx, dbFile)
if err != nil {
return "", fmt.Errorf("failed to retrieve schemas: %w", err)
}
if schemas != "" {
out.WriteString("# START CURRENT DATABASE SCHEMAS\n")
out.WriteString(schemas)
out.WriteString("\n# END CURRENT DATABASE SCHEMAS\n")
builder.WriteString("# START CURRENT DATABASE SCHEMAS\n")
builder.WriteString(schemas)
builder.WriteString("\n# END CURRENT DATABASE SCHEMAS\n")
} else {
out.WriteString("# DATABASE HAS NO TABLES\n")
builder.WriteString("# DATABASE HAS NO TABLES\n")
}

return out.String(), nil
return builder.String(), nil
}

// getSchemas returns an SQL string containing all schemas in the database.
func getSchemas(ctx context.Context, db *sql.DB) (string, error) {
query := "SELECT sql FROM sqlite_master WHERE type IN ('table', 'index', 'view', 'trigger') AND name NOT LIKE 'sqlite_%' ORDER BY name"
// getSchemas retrieves all schemas from the database using the sqlite3 CLI.
func getSchemas(ctx context.Context, dbFile *os.File) (string, error) {
query := `SELECT sql FROM sqlite_master WHERE type IN ('table', 'index', 'view', 'trigger') AND name NOT LIKE 'sqlite_%' ORDER BY name;`

rows, err := db.QueryContext(ctx, query)
// Execute the query using the RunDatabaseCommand function
output, err := RunDatabaseCommand(ctx, dbFile, fmt.Sprintf("%q", query))
if err != nil {
return "", fmt.Errorf("failed to query sqlite_master: %w", err)
}
defer rows.Close()

var out strings.Builder
for rows.Next() {
var schema string
if err := rows.Scan(&schema); err != nil {
return "", fmt.Errorf("failed to scan schema: %w", err)
}
if schema != "" {
out.WriteString(fmt.Sprintf("\n%s\n", schema))
}
}

if rows.Err() != nil {
return "", fmt.Errorf("error iterating over schemas: %w", rows.Err())
return "", fmt.Errorf("error querying schemas: %w", err)
}

return out.String(), nil
// Return raw output as-is
return strings.TrimSpace(output), nil
}
16 changes: 0 additions & 16 deletions database/pkg/cmd/exec.go

This file was deleted.

62 changes: 0 additions & 62 deletions database/pkg/cmd/query.go

This file was deleted.

65 changes: 65 additions & 0 deletions database/pkg/cmd/rows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package cmd

import (
"context"
"encoding/json"
"fmt"
"os"
"strings"
)

type Output struct {
Columns []string `json:"columns"`
Rows []map[string]any `json:"rows"`
}

// ListDatabaseTableRows lists all rows from the specified table using RunDatabaseCommand and returns a JSON object containing the results.
func ListDatabaseTableRows(ctx context.Context, dbFile *os.File, table string) (string, error) {
if table == "" {
return "", fmt.Errorf("table name cannot be empty")
}

// Build the query to fetch all rows from the table
query := fmt.Sprintf("SELECT * FROM %q;", table)

// Execute the query using RunDatabaseCommand
rawOutput, err := RunDatabaseCommand(ctx, dbFile, fmt.Sprintf("%q", query))
if err != nil {
return "", fmt.Errorf("error executing query for table %q: %w", table, err)
}

// Split raw output into rows
lines := strings.Split(strings.TrimSpace(rawOutput), "\n")
if len(lines) == 0 {
return "", fmt.Errorf("no output from query for table %q", table)
}

// The first line contains column names
columns := strings.Split(lines[0], "|")
output := Output{
Columns: columns,
Rows: []map[string]any{},
}

// Process the remaining lines as rows
for _, line := range lines[1:] {
values := strings.Split(line, "|")
rowData := map[string]any{}
for i, col := range columns {
if i < len(values) {
rowData[col] = values[i]
} else {
rowData[col] = nil
}
}
output.Rows = append(output.Rows, rowData)
}

// Marshal the result to JSON
content, err := json.Marshal(output)
if err != nil {
return "", fmt.Errorf("error marshalling output to JSON: %w", err)
}

return string(content), nil
}
Loading

0 comments on commit 3a13f9e

Please sign in to comment.