diff --git a/database/go.mod b/database/go.mod index 74ae14f1..dfe1f6f5 100644 --- a/database/go.mod +++ b/database/go.mod @@ -3,6 +3,7 @@ module obot-platform/database go 1.23.3 require ( + github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/gptscript-ai/go-gptscript v0.9.6-0.20241023195750-c09e0f56b39b github.com/ncruces/go-sqlite3 v0.20.3 ) diff --git a/database/go.sum b/database/go.sum index a2bdeeb3..f0f35ba1 100644 --- a/database/go.sum +++ b/database/go.sum @@ -8,6 +8,8 @@ github.com/go-openapi/swag v0.22.8 h1:/9RjDSQ0vbFR+NyjGMkFTsA1IA0fmhKSThmfGZjicb github.com/go-openapi/swag v0.22.8/go.mod h1:6QT22icPLEqAM/z/TChgb4WAveCHF92+2gF0CNjHpPI= github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/gptscript-ai/go-gptscript v0.9.6-0.20241023195750-c09e0f56b39b h1:EDd5OCtZ43YVSzKuQlXLiXCIQ6qhsrqLqY5Ows5ohlY= github.com/gptscript-ai/go-gptscript v0.9.6-0.20241023195750-c09e0f56b39b/go.mod h1:/FVuLwhz+sIfsWUgUHWKi32qT0i6+IXlUlzs70KKt/Q= github.com/invopop/yaml v0.2.0 h1:7zky/qH+O0DwAyoobXUqvVBwgBFRxKoQ/3FjcVpjTMY= diff --git a/database/main.go b/database/main.go index da4ccd27..83abec1a 100644 --- a/database/main.go +++ b/database/main.go @@ -3,7 +3,6 @@ package main import ( "context" "crypto/sha256" - "database/sql" "encoding/hex" "errors" "fmt" @@ -65,28 +64,20 @@ func main() { } } - // Open the SQLite database - db, err := sql.Open("sqlite3", dbFile.Name()) - if err != nil { - fmt.Printf("Error opening DB: %v\n", err) - os.Exit(1) - } - defer db.Close() - // Run the requested command var result string switch command { case "listDatabaseTables": - result, err = cmd.ListDatabaseTables(ctx, db) - case "execDatabaseStatement": - result, err = cmd.ExecDatabaseStatement(ctx, db, os.Getenv("STATEMENT")) + result, err = cmd.ListDatabaseTables(ctx, dbFile) + case "listDatabaseTableRows": + result, err = cmd.ListDatabaseTableRows(ctx, dbFile, os.Getenv("TABLE")) + case "runDatabaseCommand": + result, err = cmd.RunDatabaseCommand(ctx, dbFile, os.Getenv("SQLITE3_ARGS")) if err == nil { err = saveWorkspaceDB(ctx, g, dbWorkspacePath, dbFile, initialDBData) } - case "runDatabaseQuery": - result, err = cmd.RunDatabaseQuery(ctx, db, os.Getenv("QUERY")) case "databaseContext": - result, err = cmd.DatabaseContext(ctx, db) + result, err = cmd.DatabaseContext(ctx, dbFile) default: err = fmt.Errorf("unknown command: %s", command) } diff --git a/database/pkg/cmd/command.go b/database/pkg/cmd/command.go new file mode 100644 index 00000000..e6cf7280 --- /dev/null +++ b/database/pkg/cmd/command.go @@ -0,0 +1,39 @@ +package cmd + +import ( + "bytes" + "context" + "fmt" + "os" + "os/exec" + "strings" + + "github.com/google/shlex" +) + +// RunDatabaseCommand runs a sqlite3 command against the database and returns the output from the sqlite3 cli. +func RunDatabaseCommand(ctx context.Context, dbFile *os.File, sqlite3Args string) (string, error) { + // Remove the "sqlite3" prefix and trim whitespace + sqlite3Args = strings.TrimPrefix(strings.TrimSpace(sqlite3Args), "sqlite3") + + // Split the arguments + args, err := shlex.Split(sqlite3Args) + if err != nil { + return "", fmt.Errorf("error parsing sqlite3 args: %w", err) + } + + // Build the sqlite3 command + cmd := exec.CommandContext(ctx, "sqlite3", append([]string{dbFile.Name()}, args...)...) + + // Redirect the command output + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + // Run the command + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("error executing sqlite3: %w, stderr: %s", err, stderr.String()) + } + + return stdout.String(), nil +} diff --git a/database/pkg/cmd/context.go b/database/pkg/cmd/context.go index 740e02ce..d90b0272 100644 --- a/database/pkg/cmd/context.go +++ b/database/pkg/cmd/context.go @@ -2,69 +2,55 @@ package cmd import ( "context" - "database/sql" "fmt" + "os" "strings" ) -// DatabaseContext returns the context text for the SQLite tools. -// The resulting string contains all of schemas in the database. -func DatabaseContext(ctx context.Context, db *sql.DB) (string, error) { - // Build the markdown output - var out strings.Builder - out.WriteString(`# START INSTRUCTIONS: Database Tools +// DatabaseContext generates a markdown-formatted string with instructions +// and the database's current schemas. +func DatabaseContext(ctx context.Context, dbFile *os.File) (string, error) { + var builder strings.Builder + + // Add usage instructions + builder.WriteString(`# START INSTRUCTIONS: Run Database Command tool You have access to tools for interacting with a SQLite database. -The "List Database Tables" tool returns a list of tables in the database. -The "Exec Database Statement" tool only accepts valid SQLite3 statements. -The "Run Database Query" tool only accepts valid SQLite3 queries. +The "Run Database Command" tool is a wrapper around the sqlite3 CLI, and the "sqlite3_args" argument will be passed to it. +Do not include a database file argument in the "sqlite3_args" argument. The database file is automatically passed to the sqlite3 CLI. +Ensure SQL statements are properly encapsulated in quotes to be recognized as complete inputs by the SQLite interface. Display all results from these tools and their schemas in markdown format. -If the user refers to creating or modifying tables assume they mean a SQLite3 table and not writing a table -in a markdown file. +If the user refers to creating or modifying tables, assume they mean a SQLite3 table and not writing a table in a markdown file. -# END INSTRUCTIONS: Database Tools +# END INSTRUCTIONS: Run Database Command tool `) // Add the schemas section - schemas, err := getSchemas(ctx, db) + schemas, err := getSchemas(ctx, dbFile) if err != nil { return "", fmt.Errorf("failed to retrieve schemas: %w", err) } if schemas != "" { - out.WriteString("# START CURRENT DATABASE SCHEMAS\n") - out.WriteString(schemas) - out.WriteString("\n# END CURRENT DATABASE SCHEMAS\n") + builder.WriteString("# START CURRENT DATABASE SCHEMAS\n") + builder.WriteString(schemas) + builder.WriteString("\n# END CURRENT DATABASE SCHEMAS\n") } else { - out.WriteString("# DATABASE HAS NO TABLES\n") + builder.WriteString("# DATABASE HAS NO TABLES\n") } - return out.String(), nil + return builder.String(), nil } -// getSchemas returns an SQL string containing all schemas in the database. -func getSchemas(ctx context.Context, db *sql.DB) (string, error) { - query := "SELECT sql FROM sqlite_master WHERE type IN ('table', 'index', 'view', 'trigger') AND name NOT LIKE 'sqlite_%' ORDER BY name" +// getSchemas retrieves all schemas from the database using the sqlite3 CLI. +func getSchemas(ctx context.Context, dbFile *os.File) (string, error) { + query := `SELECT sql FROM sqlite_master WHERE type IN ('table', 'index', 'view', 'trigger') AND name NOT LIKE 'sqlite_%' ORDER BY name;` - rows, err := db.QueryContext(ctx, query) + // Execute the query using the RunDatabaseCommand function + output, err := RunDatabaseCommand(ctx, dbFile, fmt.Sprintf("%q", query)) if err != nil { - return "", fmt.Errorf("failed to query sqlite_master: %w", err) - } - defer rows.Close() - - var out strings.Builder - for rows.Next() { - var schema string - if err := rows.Scan(&schema); err != nil { - return "", fmt.Errorf("failed to scan schema: %w", err) - } - if schema != "" { - out.WriteString(fmt.Sprintf("\n%s\n", schema)) - } - } - - if rows.Err() != nil { - return "", fmt.Errorf("error iterating over schemas: %w", rows.Err()) + return "", fmt.Errorf("error querying schemas: %w", err) } - return out.String(), nil + // Return raw output as-is + return strings.TrimSpace(output), nil } diff --git a/database/pkg/cmd/exec.go b/database/pkg/cmd/exec.go deleted file mode 100644 index bfe68dcd..00000000 --- a/database/pkg/cmd/exec.go +++ /dev/null @@ -1,16 +0,0 @@ -package cmd - -import ( - "context" - "database/sql" - "fmt" -) - -// ExecDatabaseStatement executes a SQL statement (e.g., INSERT, UPDATE, DELETE, CREATE) and returns a status message. -func ExecDatabaseStatement(ctx context.Context, db *sql.DB, stmt string) (string, error) { - _, err := db.ExecContext(ctx, stmt) - if err != nil { - return "", fmt.Errorf("error executing SQL: %w", err) - } - return "Command executed successfully.", nil -} diff --git a/database/pkg/cmd/query.go b/database/pkg/cmd/query.go deleted file mode 100644 index 7f2f17f1..00000000 --- a/database/pkg/cmd/query.go +++ /dev/null @@ -1,62 +0,0 @@ -package cmd - -import ( - "context" - "database/sql" - "encoding/json" - "fmt" -) - -type Output struct { - Columns []string `json:"columns"` - Rows []map[string]any `json:"rows"` -} - -// RunDatabaseQuery executes a SQL query (e.g. SELECT) and returns a JSON object containing the results. -func RunDatabaseQuery(ctx context.Context, db *sql.DB, query string) (string, error) { - if query == "" { - return "", fmt.Errorf("empty query") - } - - rows, err := db.QueryContext(ctx, query) - if err != nil { - return "", fmt.Errorf("error executing query: %v", err) - } - defer rows.Close() - - // Retrieve column names - columns, err := rows.Columns() - if err != nil { - return "", fmt.Errorf("error retrieving columns: %v", err) - } - - var output = Output{ - Columns: columns, - Rows: []map[string]any{}, - } - - // Prepare a slice of interface{} for each row's column values - values := make([]interface{}, len(columns)) - valuePointers := make([]interface{}, len(columns)) - for i := range values { - valuePointers[i] = &values[i] - } - - // Fetch rows and write their contents - for rows.Next() { - err := rows.Scan(valuePointers...) - if err != nil { - return "", fmt.Errorf("error scanning row: %w", err) - } - - // Convert values to strings - rowData := map[string]any{} - for i, val := range values { - rowData[columns[i]] = val - } - output.Rows = append(output.Rows, rowData) - } - - content, err := json.Marshal(output) - return string(content), err -} diff --git a/database/pkg/cmd/rows.go b/database/pkg/cmd/rows.go new file mode 100644 index 00000000..8241dd3c --- /dev/null +++ b/database/pkg/cmd/rows.go @@ -0,0 +1,65 @@ +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "os" + "strings" +) + +type Output struct { + Columns []string `json:"columns"` + Rows []map[string]any `json:"rows"` +} + +// ListDatabaseTableRows lists all rows from the specified table using RunDatabaseCommand and returns a JSON object containing the results. +func ListDatabaseTableRows(ctx context.Context, dbFile *os.File, table string) (string, error) { + if table == "" { + return "", fmt.Errorf("table name cannot be empty") + } + + // Build the query to fetch all rows from the table + query := fmt.Sprintf("SELECT * FROM %q;", table) + + // Execute the query using RunDatabaseCommand + rawOutput, err := RunDatabaseCommand(ctx, dbFile, fmt.Sprintf("%q", query)) + if err != nil { + return "", fmt.Errorf("error executing query for table %q: %w", table, err) + } + + // Split raw output into rows + lines := strings.Split(strings.TrimSpace(rawOutput), "\n") + if len(lines) == 0 { + return "", fmt.Errorf("no output from query for table %q", table) + } + + // The first line contains column names + columns := strings.Split(lines[0], "|") + output := Output{ + Columns: columns, + Rows: []map[string]any{}, + } + + // Process the remaining lines as rows + for _, line := range lines[1:] { + values := strings.Split(line, "|") + rowData := map[string]any{} + for i, col := range columns { + if i < len(values) { + rowData[col] = values[i] + } else { + rowData[col] = nil + } + } + output.Rows = append(output.Rows, rowData) + } + + // Marshal the result to JSON + content, err := json.Marshal(output) + if err != nil { + return "", fmt.Errorf("error marshalling output to JSON: %w", err) + } + + return string(content), nil +} diff --git a/database/pkg/cmd/table.go b/database/pkg/cmd/table.go index 45ab94c6..2b537204 100644 --- a/database/pkg/cmd/table.go +++ b/database/pkg/cmd/table.go @@ -2,50 +2,58 @@ package cmd import ( "context" - "database/sql" "encoding/json" "fmt" + "os" + "strings" ) +type tables struct { + Tables []Table `json:"tables"` +} + +type Table struct { + Name string `json:"name,omitempty"` +} + // ListDatabaseTables returns a JSON object containing the list of tables in the database. -func ListDatabaseTables(ctx context.Context, db *sql.DB) (string, error) { - tables, err := listTables(ctx, db) +func ListDatabaseTables(ctx context.Context, dbFile *os.File) (string, error) { + tables, err := listTables(ctx, dbFile) if err != nil { return "", fmt.Errorf("failed to list tables: %w", err) } content, err := json.Marshal(tables) - return string(content), err -} + if err != nil { + return "", fmt.Errorf("failed to marshal tables to JSON: %w", err) + } -type tables struct { - Tables []Table `json:"tables"` + return string(content), nil } -type Table struct { - Name string `json:"name,omitempty"` -} +// listTables retrieves the list of tables in the database using RunDatabaseCommand. +func listTables(ctx context.Context, dbFile *os.File) (tables, error) { + // Query to fetch table names + query := "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';" -func listTables(ctx context.Context, db *sql.DB) (tables, error) { - rows, err := db.QueryContext(ctx, "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';") + // Execute the query using RunDatabaseCommand + rawOutput, err := RunDatabaseCommand(ctx, dbFile, fmt.Sprintf("%q", query)) if err != nil { - return tables{}, fmt.Errorf("failed to query tables: %w", err) + return tables{}, fmt.Errorf("error executing query to list tables: %w", err) } - defer rows.Close() - var tables tables - for rows.Next() { - var tableName string - if err := rows.Scan(&tableName); err != nil { - return tables, fmt.Errorf("failed to scan table name: %w", err) - } - tables.Tables = append(tables.Tables, Table{ - Name: tableName, - }) + // Process the output + lines := strings.Split(strings.TrimSpace(rawOutput), "\n") + if len(lines) == 0 { + return tables{}, nil // No tables found } - if rows.Err() != nil { - return tables, fmt.Errorf("error iterating over table names: %w", rows.Err()) + + var result tables + for _, line := range lines { + if line = strings.TrimSpace(line); line != "" { + result.Tables = append(result.Tables, Table{Name: line}) + } } - return tables, nil + return result, nil } diff --git a/database/tool.gpt b/database/tool.gpt index 7afaa83e..86034e2d 100644 --- a/database/tool.gpt +++ b/database/tool.gpt @@ -3,32 +3,33 @@ Name: Database Description: Tools for interacting with a database Metadata: category: Capability Metadata: icon: https://cdn.jsdelivr.net/npm/@phosphor-icons/core@2/assets/duotone/database-duotone.svg -Share Tools: Run Database Query, Exec Database Statement +Share Tools: Run Database Command --- -Name: List Database Tables -Description: List all tables in the SQLite database and return a JSON object containing the results +Name: Run Database Command +Share Context: Database Context +Description: Run the sqlite3 command with the given arguments against the SQLite database and print the results +Param: sqlite3_args: Arguments to pass to the sqlite3 cli -#!${GPTSCRIPT_TOOL_DIR}/bin/gptscript-go-tool listDatabaseTables +#!${GPTSCRIPT_TOOL_DIR}/bin/gptscript-go-tool runDatabaseCommand --- -Name: Run Database Query -Description: Run a SQL query against the SQLite database and return a JSON object containing the results -Share Context: Database Context -Param: query: SQL query to run +Name: Database Context +Type: context -#!${GPTSCRIPT_TOOL_DIR}/bin/gptscript-go-tool runDatabaseQuery +#!${GPTSCRIPT_TOOL_DIR}/bin/gptscript-go-tool databaseContext --- -Name: Exec Database Statement -Description: Execute a SQL statement against the SQLite database -Share Context: Database Context -Param: statement: SQL statement to execute +Name: list_database_tables +Metadata: index: false +Description: List all tables in the SQLite database and return a JSON object containing the results -#!${GPTSCRIPT_TOOL_DIR}/bin/gptscript-go-tool execDatabaseStatement +#!${GPTSCRIPT_TOOL_DIR}/bin/gptscript-go-tool listDatabaseTables --- -Name: Database Context -Type: context +Name: list_database_table_rows +Metadata: index: false +Description: List all rows from the specified table in the SQLite database and return a JSON object containing the results +Param: table: Name of the table to list rows from -#!${GPTSCRIPT_TOOL_DIR}/bin/gptscript-go-tool databaseContext +#!${GPTSCRIPT_TOOL_DIR}/bin/gptscript-go-tool listDatabaseTableRows