Skip to content

Commit

Permalink
chore: respect catalog parameter in DB.ListColumns (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
atzoum authored Mar 21, 2024
1 parent 4039fcc commit 11bf6b2
Show file tree
Hide file tree
Showing 15 changed files with 210 additions and 25 deletions.
12 changes: 11 additions & 1 deletion sqlconnect/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@ import (
"time"
)

var ErrDropOldTablePostCopy = errors.New("move table: dropping old table after copying its contents to the new table")
var (
ErrNotSupported = errors.New("sqconnect: feature not supported")
ErrDropOldTablePostCopy = errors.New("sqlconnect move table: dropping old table after copying its contents to the new table")
)

type DB interface {
sqlDB
// SqlDB returns the underlying *sql.DB
SqlDB() *sql.DB
CatalogAdmin
SchemaAdmin
TableAdmin
JsonRowMapper
Expand Down Expand Up @@ -43,6 +47,12 @@ type sqlDB interface {
Stats() sql.DBStats
}

type CatalogAdmin interface {
// CurrentCatalog returns the current catalog.
// If this operation is not supported by the warehouse [ErrNotSupported] will be returned.
CurrentCatalog(ctx context.Context) (string, error)
}

type SchemaAdmin interface {
// CreateSchema creates a schema
CreateSchema(ctx context.Context, schema SchemaRef) error
Expand Down
15 changes: 15 additions & 0 deletions sqlconnect/internal/base/catalog_admin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package base

import (
"context"
"fmt"
)

// CurrentCatalog returns the current catalog
func (db *DB) CurrentCatalog(ctx context.Context) (string, error) {
var catalog string
if err := db.QueryRowContext(ctx, db.sqlCommands.CurrentCatalog()).Scan(&catalog); err != nil {
return "", fmt.Errorf("getting current catalog: %w", err)
}
return catalog, nil
}
15 changes: 12 additions & 3 deletions sqlconnect/internal/base/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ func NewDB(db *sql.DB, opts ...Option) *DB {
return value
},
sqlCommands: SQLCommands{
CurrentCatalog: func() string {
return "SELECT current_catalog"
},
CreateSchema: func(schema QuotedIdentifier) string {
return fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %[1]s", schema)
},
Expand All @@ -46,8 +49,12 @@ func NewDB(db *sql.DB, opts ...Option) *DB {
TableExists: func(schema, table UnquotedIdentifier) string {
return fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema='%[1]s' and table_name = '%[2]s'", schema, table)
},
ListColumns: func(schema, table UnquotedIdentifier) (string, string, string) {
return fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = '%[1]s' AND table_name = '%[2]s'", schema, table), "column_name", "data_type"
ListColumns: func(catalog, schema, table UnquotedIdentifier) (string, string, string) {
stmt := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = '%[1]s' AND table_name = '%[2]s'", schema, table)
if catalog != "" {
stmt += fmt.Sprintf(" AND table_catalog = '%[1]s'", catalog)
}
return stmt + " ORDER BY ordinal_position ASC", "column_name", "data_type"
},
CountTableRows: func(table QuotedIdentifier) string { return fmt.Sprintf("SELECT COUNT(*) FROM %[1]s", table) },
DropTable: func(table QuotedIdentifier) string { return fmt.Sprintf("DROP TABLE IF EXISTS %[1]s", table) },
Expand Down Expand Up @@ -101,6 +108,8 @@ type (
QuotedIdentifier string // A quoted identifier is a string that is quoted, e.g. "my_table"
UnquotedIdentifier string // An unquoted identifier is a string that is not quoted, e.g. my_table
SQLCommands struct {
// Provides the SQL command to get the current catalog
CurrentCatalog func() string
// Provides the SQL command to create a schema
CreateSchema func(schema QuotedIdentifier) string
// Provides the SQL command to list all schemas
Expand All @@ -118,7 +127,7 @@ type (
// Provides the SQL command to check if a table exists
TableExists func(schema, table UnquotedIdentifier) string
// Provides the SQL command to list all columns in a table along with the column names in the result set that point to the name and type
ListColumns func(schema, table UnquotedIdentifier) (sql, nameCol, typeCol string)
ListColumns func(catalog, schema, table UnquotedIdentifier) (sql, nameCol, typeCol string)
// Provides the SQL command to count the rows in a table
CountTableRows func(table QuotedIdentifier) string
// Provides the SQL command to drop a table
Expand Down
2 changes: 1 addition & 1 deletion sqlconnect/internal/base/tableadmin.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func (db *DB) TableExists(ctx context.Context, relation sqlconnect.RelationRef)
// ListColumns returns a list of columns for the given table
func (db *DB) ListColumns(ctx context.Context, relation sqlconnect.RelationRef) ([]sqlconnect.ColumnRef, error) {
var res []sqlconnect.ColumnRef
stmt, nameCol, typeCol := db.sqlCommands.ListColumns(UnquotedIdentifier(relation.Schema), UnquotedIdentifier(relation.Name))
stmt, nameCol, typeCol := db.sqlCommands.ListColumns(UnquotedIdentifier(relation.Catalog), UnquotedIdentifier(relation.Schema), UnquotedIdentifier(relation.Name))
columns, err := db.QueryContext(ctx, stmt)
if err != nil {
return nil, fmt.Errorf("querying list columns for %s: %w", relation.String(), err)
Expand Down
18 changes: 18 additions & 0 deletions sqlconnect/internal/bigquery/catalogadmin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package bigquery

import (
"context"

"cloud.google.com/go/bigquery"
)

func (db *DB) CurrentCatalog(ctx context.Context) (string, error) {
var catalog string
if err := db.WithBigqueryClient(ctx, func(c *bigquery.Client) error {
catalog = c.Project()
return nil
}); err != nil {
return "", err
}
return catalog, nil
}
8 changes: 6 additions & 2 deletions sqlconnect/internal/bigquery/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,12 @@ func NewDB(configJSON json.RawMessage) (*DB, error) {
cmds.TableExists = func(schema, table base.UnquotedIdentifier) string {
return fmt.Sprintf("SELECT table_name FROM `%[1]s`.INFORMATION_SCHEMA.TABLES WHERE table_name = '%[2]s'", schema, table)
}
cmds.ListColumns = func(schema, table base.UnquotedIdentifier) (string, string, string) {
return fmt.Sprintf("SELECT column_name, data_type FROM `%[1]s`.INFORMATION_SCHEMA.COLUMNS WHERE table_name = '%[2]s'", schema, table), "column_name", "data_type"
cmds.ListColumns = func(catalog, schema, table base.UnquotedIdentifier) (string, string, string) {
stmt := fmt.Sprintf("SELECT column_name, data_type FROM `%[1]s`.INFORMATION_SCHEMA.COLUMNS WHERE table_name = '%[2]s'", schema, table)
if catalog != "" {
stmt += fmt.Sprintf(" AND table_catalog = '%[1]s'", catalog)
}
return stmt, "column_name", "data_type"
}

return cmds
Expand Down
3 changes: 0 additions & 3 deletions sqlconnect/internal/databricks/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,5 @@ func (c *Config) Parse(configJson json.RawMessage) error {
if err != nil {
return err
}
if c.Catalog == "" {
c.Catalog = "hive_metastore" // default catalog
}
return nil
}
32 changes: 30 additions & 2 deletions sqlconnect/internal/databricks/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql"
"encoding/json"
"fmt"
"strings"

databricks "github.com/databricks/databricks-sql-go"
"github.com/samber/lo"
Expand Down Expand Up @@ -44,13 +45,21 @@ func NewDB(configJson json.RawMessage) (*DB, error) {
db := sql.OpenDB(connector)
db.SetConnMaxIdleTime(config.MaxConnIdleTime)

if _, err = db.Exec("SELECT * FROM INFORMATION_SCHEMA.COLUMNS LIMIT 1"); err != nil && !strings.Contains(err.Error(), "TABLE_OR_VIEW_NOT_FOUND") {
return nil, fmt.Errorf("checking if unity catalog is available: %w", err)
}
informationSchema := err == nil

return &DB{
DB: base.NewDB(
db,
base.WithDialect(dialect{}),
base.WithColumnTypeMapper(getColumnTypeMapper(config)),
base.WithJsonRowMapper(getJonRowMapper(config)),
base.WithSQLCommandsOverride(func(cmds base.SQLCommands) base.SQLCommands {
cmds.CurrentCatalog = func() string {
return "SELECT current_catalog()"
}
cmds.ListSchemas = func() (string, string) { return "SHOW SCHEMAS", "schema_name" }
cmds.SchemaExists = func(schema base.UnquotedIdentifier) string { return fmt.Sprintf(`SHOW SCHEMAS LIKE '%s'`, schema) }

Expand All @@ -70,12 +79,30 @@ func NewDB(configJson json.RawMessage) (*DB, error) {
cmds.TableExists = func(schema, table base.UnquotedIdentifier) string {
return fmt.Sprintf("SHOW TABLES IN `%[1]s` LIKE '%[2]s'", schema, table)
}
cmds.ListColumns = func(schema, table base.UnquotedIdentifier) (string, string, string) {
return fmt.Sprintf("DESCRIBE TABLE `%[1]s`.`%[2]s`", schema, table), "col_name", "data_type"
cmds.ListColumns = func(catalog, schema, table base.UnquotedIdentifier) (string, string, string) {
if catalog == "" || !informationSchema {
return fmt.Sprintf("DESCRIBE TABLE `%[1]s`.`%[2]s`", schema, table), "col_name", "data_type"
}
stmt := fmt.Sprintf(`SELECT
column_name,
data_type
FROM information_schema.columns
WHERE table_schema = '%[1]s'
AND table_name = '%[2]s'
AND table_catalog='%[3]s'
ORDER BY ORDINAL_POSITION ASC`,
schema,
table,
catalog)
return stmt, "column_name", "data_type"
}
cmds.RenameTable = func(schema, oldName, newName base.QuotedIdentifier) string {
return fmt.Sprintf("ALTER TABLE %[1]s.%[2]s RENAME TO %[1]s.%[3]s", schema, oldName, newName)
}
return cmds
}),
),
informationSchema: informationSchema,
}, nil
}

Expand All @@ -87,6 +114,7 @@ func init() {

type DB struct {
*base.DB
informationSchema bool
}

func getColumnTypeMapper(config Config) func(base.ColumnType) string {
Expand Down
24 changes: 23 additions & 1 deletion sqlconnect/internal/databricks/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/tidwall/sjson"

"github.com/rudderlabs/sqlconnect-go/sqlconnect"
"github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/databricks"
integrationtest "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/integration_test"
)
Expand All @@ -26,5 +27,26 @@ func TestDatabricksDB(t *testing.T) {
configJSON, err = sjson.Set(configJSON, "maxRetryWaitTime", 30*time.Second)
require.NoError(t, err, "failed to set maxRetryWaitTime")

integrationtest.TestDatabaseScenarios(t, databricks.DatabaseType, []byte(configJSON), strings.ToLower, integrationtest.Options{LegacySupport: true})
t.Run("with information schema", func(t *testing.T) {
configJSON, err := sjson.Set(configJSON, "catalog", "sqlconnect")
require.NoError(t, err, "failed to set catalog")
db, err := sqlconnect.NewDB(databricks.DatabaseType, []byte(configJSON))
require.NoError(t, err, "failed to create db")
_, err = db.Exec("SELECT * FROM INFORMATION_SCHEMA.COLUMNS LIMIT 1")
require.NoError(t, err, "information schema should be available")

integrationtest.TestDatabaseScenarios(t, databricks.DatabaseType, []byte(configJSON), strings.ToLower, integrationtest.Options{LegacySupport: true})
})

t.Run("without information schema", func(t *testing.T) {
configJSON, err := sjson.Set(configJSON, "catalog", "hive_metastore")
require.NoError(t, err, "failed to set catalog")
db, err := sqlconnect.NewDB(databricks.DatabaseType, []byte(configJSON))
require.NoError(t, err, "failed to create db")
_, err = db.Exec("SELECT * FROM INFORMATION_SCHEMA.COLUMNS LIMIT 1")
require.Error(t, err, "information schema should not be available")
require.ErrorContains(t, err, "TABLE_OR_VIEW_NOT_FOUND", "information schema should not be available")

integrationtest.TestDatabaseScenarios(t, databricks.DatabaseType, []byte(configJSON), strings.ToLower, integrationtest.Options{LegacySupport: true})
})
}
15 changes: 15 additions & 0 deletions sqlconnect/internal/databricks/tableadmin.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,26 @@ package databricks

import (
"context"
"fmt"
"strings"

"github.com/rudderlabs/sqlconnect-go/sqlconnect"
)

// ListColumns returns a list of columns for the given table
func (db *DB) ListColumns(ctx context.Context, relation sqlconnect.RelationRef) ([]sqlconnect.ColumnRef, error) {
if !db.informationSchema && relation.Catalog != "" {
currentCatalog, err := db.CurrentCatalog(ctx) // make sure the catalog matches the current catalog
if err != nil {
return nil, fmt.Errorf("getting current catalog: %w", err)
}
if relation.Catalog != currentCatalog {
return nil, fmt.Errorf("catalog %s not found", relation.Catalog)
}
}
return db.DB.ListColumns(ctx, relation)
}

// RenameTable in databricks falls back to MoveTable if rename is not supported
func (db *DB) RenameTable(ctx context.Context, oldRef, newRef sqlconnect.RelationRef) error {
if err := db.DB.RenameTable(ctx, oldRef, newRef); err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,23 @@ func TestDatabaseScenarios(t *testing.T, warehouse string, configJSON json.RawMe
})
})

var currentCatalog string
t.Run("catalog admin", func(t *testing.T) {
t.Run("current catalog", func(t *testing.T) {
t.Run("with context cancelled", func(t *testing.T) {
_, err := db.CurrentCatalog(cancelledCtx)
require.Error(t, err, "it should not be able to get the current catalog with a cancelled context")
})

currentCatalog, err = db.CurrentCatalog(ctx)
if errors.Is(err, sqlconnect.ErrNotSupported) {
t.Skipf("skipping test for warehouse %s: %v", warehouse, err)
}
require.NoError(t, err, "it should be able to get the current catalog")
require.NotEmpty(t, currentCatalog, "it should return a non-empty current catalog")
})
})

t.Run("schema admin", func(t *testing.T) {
t.Run("schema doesn't exist", func(t *testing.T) {
exists, err := db.SchemaExists(ctx, schema)
Expand Down Expand Up @@ -178,18 +195,44 @@ func TestDatabaseScenarios(t *testing.T, warehouse string, configJSON json.RawMe
require.Error(t, err, "it should not be able to list columns with a cancelled context")
})

columns, err := db.ListColumns(ctx, table)
columns = lo.Map(columns, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef {
require.NotEmptyf(t, col.RawType, "it should return the raw type for column %q", col.Name)
col.RawType = ""
return col
t.Run("without catalog", func(t *testing.T) {
columns, err := db.ListColumns(ctx, table)
columns = lo.Map(columns, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef {
require.NotEmptyf(t, col.RawType, "it should return the raw type for column %q", col.Name)
col.RawType = ""
return col
})
require.NoError(t, err, "it should be able to list columns")
require.Len(t, columns, 2, "it should return the correct number of columns")
require.ElementsMatch(t, columns, []sqlconnect.ColumnRef{
{Name: formatfn("c1"), Type: "int"},
{Name: formatfn("c2"), Type: "string"},
}, "it should return the correct columns")
})

t.Run("with catalog", func(t *testing.T) {
tableWithCatalog := table
tableWithCatalog.Catalog = currentCatalog
columns, err := db.ListColumns(ctx, tableWithCatalog)
columns = lo.Map(columns, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef {
require.NotEmptyf(t, col.RawType, "it should return the raw type for column %q", col.Name)
col.RawType = ""
return col
})
require.NoError(t, err, "it should be able to list columns")
require.Len(t, columns, 2, "it should return the correct number of columns")
require.ElementsMatch(t, columns, []sqlconnect.ColumnRef{
{Name: formatfn("c1"), Type: "int"},
{Name: formatfn("c2"), Type: "string"},
}, "it should return the correct columns")
})

t.Run("with invalid catalog", func(t *testing.T) {
tableWithInvalidCatalog := table
tableWithInvalidCatalog.Catalog = "invalid"
cols, _ := db.ListColumns(ctx, tableWithInvalidCatalog)
require.Empty(t, cols, "it should return an empty list of columns for an invalid catalog")
})
require.NoError(t, err, "it should be able to list columns")
require.Len(t, columns, 2, "it should return the correct number of columns")
require.ElementsMatch(t, columns, []sqlconnect.ColumnRef{
{Name: formatfn("c1"), Type: "int"},
{Name: formatfn("c2"), Type: "string"},
}, "it should return the correct columns")
})

t.Run("list columns for sql query", func(t *testing.T) {
Expand Down
12 changes: 12 additions & 0 deletions sqlconnect/internal/mysql/catalogadmin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package mysql

import (
"context"

"github.com/rudderlabs/sqlconnect-go/sqlconnect"
)

// CurrentCatalog returns an error because it is not supported by MySQL
func (db *DB) CurrentCatalog(ctx context.Context) (string, error) {
return "", sqlconnect.ErrNotSupported
}
3 changes: 3 additions & 0 deletions sqlconnect/internal/mysql/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ func NewDB(configJSON json.RawMessage) (*DB, error) {
base.WithColumnTypeMapper(getColumnTypeMapper(config)),
base.WithJsonRowMapper(getJonRowMapper(config)),
base.WithSQLCommandsOverride(func(cmds base.SQLCommands) base.SQLCommands {
cmds.CurrentCatalog = func() string {
return "SELECT DATABASE()"
}
cmds.DropSchema = func(schema base.QuotedIdentifier) string { // mysql does not support CASCADE
return fmt.Sprintf("DROP SCHEMA %[1]s", schema)
}
Expand Down
3 changes: 3 additions & 0 deletions sqlconnect/internal/redshift/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ func NewDB(credentialsJSON json.RawMessage) (*DB, error) {
base.WithColumnTypeMappings(getColumnTypeMappings(config)),
base.WithJsonRowMapper(getJonRowMapper(config)),
base.WithSQLCommandsOverride(func(cmds base.SQLCommands) base.SQLCommands {
cmds.CurrentCatalog = func() string {
return "SELECT current_database()"
}
cmds.ListSchemas = func() (string, string) {
return "SELECT schema_name FROM svv_redshift_schemas", "schema_name"
}
Expand Down
Loading

0 comments on commit 11bf6b2

Please sign in to comment.