Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: respect catalog restriction in ListColumns #16

Merged
merged 1 commit into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion sqlconnect/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@ import (
"time"
)

var ErrDropOldTablePostCopy = errors.New("move table: dropping old table after copying its contents to the new table")
var (
ErrNotSupported = errors.New("sqconnect: feature not supported")
ErrDropOldTablePostCopy = errors.New("sqlconnect move table: dropping old table after copying its contents to the new table")
)

type DB interface {
sqlDB
// SqlDB returns the underlying *sql.DB
SqlDB() *sql.DB
CatalogAdmin
SchemaAdmin
TableAdmin
JsonRowMapper
Expand Down Expand Up @@ -43,6 +47,12 @@ type sqlDB interface {
Stats() sql.DBStats
}

type CatalogAdmin interface {
// CurrentCatalog returns the current catalog.
// If this operation is not supported by the warehouse [ErrNotSupported] will be returned.
CurrentCatalog(ctx context.Context) (string, error)
}

type SchemaAdmin interface {
// CreateSchema creates a schema
CreateSchema(ctx context.Context, schema SchemaRef) error
Expand Down
15 changes: 15 additions & 0 deletions sqlconnect/internal/base/catalog_admin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package base

import (
"context"
"fmt"
)

// CurrentCatalog returns the current catalog
func (db *DB) CurrentCatalog(ctx context.Context) (string, error) {
var catalog string
if err := db.QueryRowContext(ctx, db.sqlCommands.CurrentCatalog()).Scan(&catalog); err != nil {
return "", fmt.Errorf("getting current catalog: %w", err)
}
return catalog, nil
}
15 changes: 12 additions & 3 deletions sqlconnect/internal/base/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ func NewDB(db *sql.DB, opts ...Option) *DB {
return value
},
sqlCommands: SQLCommands{
CurrentCatalog: func() string {
return "SELECT current_catalog"
},
CreateSchema: func(schema QuotedIdentifier) string {
return fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %[1]s", schema)
},
Expand All @@ -46,8 +49,12 @@ func NewDB(db *sql.DB, opts ...Option) *DB {
TableExists: func(schema, table UnquotedIdentifier) string {
return fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema='%[1]s' and table_name = '%[2]s'", schema, table)
},
ListColumns: func(schema, table UnquotedIdentifier) (string, string, string) {
return fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = '%[1]s' AND table_name = '%[2]s'", schema, table), "column_name", "data_type"
ListColumns: func(catalog, schema, table UnquotedIdentifier) (string, string, string) {
stmt := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = '%[1]s' AND table_name = '%[2]s'", schema, table)
if catalog != "" {
stmt += fmt.Sprintf(" AND table_catalog = '%[1]s'", catalog)
}
return stmt + " ORDER BY ordinal_position ASC", "column_name", "data_type"
},
CountTableRows: func(table QuotedIdentifier) string { return fmt.Sprintf("SELECT COUNT(*) FROM %[1]s", table) },
DropTable: func(table QuotedIdentifier) string { return fmt.Sprintf("DROP TABLE IF EXISTS %[1]s", table) },
Expand Down Expand Up @@ -101,6 +108,8 @@ type (
QuotedIdentifier string // A quoted identifier is a string that is quoted, e.g. "my_table"
UnquotedIdentifier string // An unquoted identifier is a string that is not quoted, e.g. my_table
SQLCommands struct {
// Provides the SQL command to get the current catalog
CurrentCatalog func() string
// Provides the SQL command to create a schema
CreateSchema func(schema QuotedIdentifier) string
// Provides the SQL command to list all schemas
Expand All @@ -118,7 +127,7 @@ type (
// Provides the SQL command to check if a table exists
TableExists func(schema, table UnquotedIdentifier) string
// Provides the SQL command to list all columns in a table along with the column names in the result set that point to the name and type
ListColumns func(schema, table UnquotedIdentifier) (sql, nameCol, typeCol string)
ListColumns func(catalog, schema, table UnquotedIdentifier) (sql, nameCol, typeCol string)
// Provides the SQL command to count the rows in a table
CountTableRows func(table QuotedIdentifier) string
// Provides the SQL command to drop a table
Expand Down
2 changes: 1 addition & 1 deletion sqlconnect/internal/base/tableadmin.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func (db *DB) TableExists(ctx context.Context, relation sqlconnect.RelationRef)
// ListColumns returns a list of columns for the given table
func (db *DB) ListColumns(ctx context.Context, relation sqlconnect.RelationRef) ([]sqlconnect.ColumnRef, error) {
var res []sqlconnect.ColumnRef
stmt, nameCol, typeCol := db.sqlCommands.ListColumns(UnquotedIdentifier(relation.Schema), UnquotedIdentifier(relation.Name))
stmt, nameCol, typeCol := db.sqlCommands.ListColumns(UnquotedIdentifier(relation.Catalog), UnquotedIdentifier(relation.Schema), UnquotedIdentifier(relation.Name))
columns, err := db.QueryContext(ctx, stmt)
if err != nil {
return nil, fmt.Errorf("querying list columns for %s: %w", relation.String(), err)
Expand Down
18 changes: 18 additions & 0 deletions sqlconnect/internal/bigquery/catalogadmin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package bigquery

import (
"context"

"cloud.google.com/go/bigquery"
)

func (db *DB) CurrentCatalog(ctx context.Context) (string, error) {
var catalog string
if err := db.WithBigqueryClient(ctx, func(c *bigquery.Client) error {
catalog = c.Project()
return nil
}); err != nil {
return "", err
}
return catalog, nil
}
8 changes: 6 additions & 2 deletions sqlconnect/internal/bigquery/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,12 @@ func NewDB(configJSON json.RawMessage) (*DB, error) {
cmds.TableExists = func(schema, table base.UnquotedIdentifier) string {
return fmt.Sprintf("SELECT table_name FROM `%[1]s`.INFORMATION_SCHEMA.TABLES WHERE table_name = '%[2]s'", schema, table)
}
cmds.ListColumns = func(schema, table base.UnquotedIdentifier) (string, string, string) {
return fmt.Sprintf("SELECT column_name, data_type FROM `%[1]s`.INFORMATION_SCHEMA.COLUMNS WHERE table_name = '%[2]s'", schema, table), "column_name", "data_type"
cmds.ListColumns = func(catalog, schema, table base.UnquotedIdentifier) (string, string, string) {
stmt := fmt.Sprintf("SELECT column_name, data_type FROM `%[1]s`.INFORMATION_SCHEMA.COLUMNS WHERE table_name = '%[2]s'", schema, table)
if catalog != "" {
stmt += fmt.Sprintf(" AND table_catalog = '%[1]s'", catalog)
}
return stmt, "column_name", "data_type"
}

return cmds
Expand Down
3 changes: 0 additions & 3 deletions sqlconnect/internal/databricks/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,5 @@ func (c *Config) Parse(configJson json.RawMessage) error {
if err != nil {
return err
}
if c.Catalog == "" {
c.Catalog = "hive_metastore" // default catalog
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: rudder-sources now sets this as a default value, the library assumes no default value

}
return nil
}
32 changes: 30 additions & 2 deletions sqlconnect/internal/databricks/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"database/sql"
"encoding/json"
"fmt"
"strings"

databricks "github.com/databricks/databricks-sql-go"
"github.com/samber/lo"
Expand Down Expand Up @@ -44,13 +45,21 @@
db := sql.OpenDB(connector)
db.SetConnMaxIdleTime(config.MaxConnIdleTime)

if _, err = db.Exec("SELECT * FROM INFORMATION_SCHEMA.COLUMNS LIMIT 1"); err != nil && !strings.Contains(err.Error(), "TABLE_OR_VIEW_NOT_FOUND") {
return nil, fmt.Errorf("checking if unity catalog is available: %w", err)

Check warning on line 49 in sqlconnect/internal/databricks/db.go

View check run for this annotation

Codecov / codecov/patch

sqlconnect/internal/databricks/db.go#L49

Added line #L49 was not covered by tests
}
informationSchema := err == nil

return &DB{
DB: base.NewDB(
db,
base.WithDialect(dialect{}),
base.WithColumnTypeMapper(getColumnTypeMapper(config)),
base.WithJsonRowMapper(getJonRowMapper(config)),
base.WithSQLCommandsOverride(func(cmds base.SQLCommands) base.SQLCommands {
cmds.CurrentCatalog = func() string {
return "SELECT current_catalog()"
}
cmds.ListSchemas = func() (string, string) { return "SHOW SCHEMAS", "schema_name" }
cmds.SchemaExists = func(schema base.UnquotedIdentifier) string { return fmt.Sprintf(`SHOW SCHEMAS LIKE '%s'`, schema) }

Expand All @@ -70,12 +79,30 @@
cmds.TableExists = func(schema, table base.UnquotedIdentifier) string {
return fmt.Sprintf("SHOW TABLES IN `%[1]s` LIKE '%[2]s'", schema, table)
}
cmds.ListColumns = func(schema, table base.UnquotedIdentifier) (string, string, string) {
return fmt.Sprintf("DESCRIBE TABLE `%[1]s`.`%[2]s`", schema, table), "col_name", "data_type"
cmds.ListColumns = func(catalog, schema, table base.UnquotedIdentifier) (string, string, string) {
if catalog == "" || !informationSchema {
return fmt.Sprintf("DESCRIBE TABLE `%[1]s`.`%[2]s`", schema, table), "col_name", "data_type"
}
stmt := fmt.Sprintf(`SELECT
column_name,
data_type
FROM information_schema.columns
WHERE table_schema = '%[1]s'
AND table_name = '%[2]s'
AND table_catalog='%[3]s'
ORDER BY ORDINAL_POSITION ASC`,
schema,
table,
catalog)
return stmt, "column_name", "data_type"
}
cmds.RenameTable = func(schema, oldName, newName base.QuotedIdentifier) string {
return fmt.Sprintf("ALTER TABLE %[1]s.%[2]s RENAME TO %[1]s.%[3]s", schema, oldName, newName)
}
return cmds
}),
),
informationSchema: informationSchema,
}, nil
}

Expand All @@ -87,6 +114,7 @@

type DB struct {
*base.DB
informationSchema bool
}

func getColumnTypeMapper(config Config) func(base.ColumnType) string {
Expand Down
24 changes: 23 additions & 1 deletion sqlconnect/internal/databricks/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/tidwall/sjson"

"github.com/rudderlabs/sqlconnect-go/sqlconnect"
"github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/databricks"
integrationtest "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/integration_test"
)
Expand All @@ -26,5 +27,26 @@ func TestDatabricksDB(t *testing.T) {
configJSON, err = sjson.Set(configJSON, "maxRetryWaitTime", 30*time.Second)
require.NoError(t, err, "failed to set maxRetryWaitTime")

integrationtest.TestDatabaseScenarios(t, databricks.DatabaseType, []byte(configJSON), strings.ToLower, integrationtest.Options{LegacySupport: true})
t.Run("with information schema", func(t *testing.T) {
configJSON, err := sjson.Set(configJSON, "catalog", "sqlconnect")
require.NoError(t, err, "failed to set catalog")
db, err := sqlconnect.NewDB(databricks.DatabaseType, []byte(configJSON))
require.NoError(t, err, "failed to create db")
_, err = db.Exec("SELECT * FROM INFORMATION_SCHEMA.COLUMNS LIMIT 1")
require.NoError(t, err, "information schema should be available")

integrationtest.TestDatabaseScenarios(t, databricks.DatabaseType, []byte(configJSON), strings.ToLower, integrationtest.Options{LegacySupport: true})
})

t.Run("without information schema", func(t *testing.T) {
configJSON, err := sjson.Set(configJSON, "catalog", "hive_metastore")
require.NoError(t, err, "failed to set catalog")
db, err := sqlconnect.NewDB(databricks.DatabaseType, []byte(configJSON))
require.NoError(t, err, "failed to create db")
_, err = db.Exec("SELECT * FROM INFORMATION_SCHEMA.COLUMNS LIMIT 1")
require.Error(t, err, "information schema should not be available")
require.ErrorContains(t, err, "TABLE_OR_VIEW_NOT_FOUND", "information schema should not be available")

integrationtest.TestDatabaseScenarios(t, databricks.DatabaseType, []byte(configJSON), strings.ToLower, integrationtest.Options{LegacySupport: true})
})
}
15 changes: 15 additions & 0 deletions sqlconnect/internal/databricks/tableadmin.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,26 @@

import (
"context"
"fmt"
"strings"

"github.com/rudderlabs/sqlconnect-go/sqlconnect"
)

// ListColumns returns a list of columns for the given table
func (db *DB) ListColumns(ctx context.Context, relation sqlconnect.RelationRef) ([]sqlconnect.ColumnRef, error) {
if !db.informationSchema && relation.Catalog != "" {
currentCatalog, err := db.CurrentCatalog(ctx) // make sure the catalog matches the current catalog
if err != nil {
return nil, fmt.Errorf("getting current catalog: %w", err)

Check warning on line 16 in sqlconnect/internal/databricks/tableadmin.go

View check run for this annotation

Codecov / codecov/patch

sqlconnect/internal/databricks/tableadmin.go#L16

Added line #L16 was not covered by tests
}
if relation.Catalog != currentCatalog {
return nil, fmt.Errorf("catalog %s not found", relation.Catalog)
}
}
return db.DB.ListColumns(ctx, relation)
}

// RenameTable in databricks falls back to MoveTable if rename is not supported
func (db *DB) RenameTable(ctx context.Context, oldRef, newRef sqlconnect.RelationRef) error {
if err := db.DB.RenameTable(ctx, oldRef, newRef); err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,23 @@ func TestDatabaseScenarios(t *testing.T, warehouse string, configJSON json.RawMe
})
})

var currentCatalog string
t.Run("catalog admin", func(t *testing.T) {
t.Run("current catalog", func(t *testing.T) {
t.Run("with context cancelled", func(t *testing.T) {
_, err := db.CurrentCatalog(cancelledCtx)
require.Error(t, err, "it should not be able to get the current catalog with a cancelled context")
})

currentCatalog, err = db.CurrentCatalog(ctx)
if errors.Is(err, sqlconnect.ErrNotSupported) {
t.Skipf("skipping test for warehouse %s: %v", warehouse, err)
}
require.NoError(t, err, "it should be able to get the current catalog")
require.NotEmpty(t, currentCatalog, "it should return a non-empty current catalog")
})
})

t.Run("schema admin", func(t *testing.T) {
t.Run("schema doesn't exist", func(t *testing.T) {
exists, err := db.SchemaExists(ctx, schema)
Expand Down Expand Up @@ -178,18 +195,44 @@ func TestDatabaseScenarios(t *testing.T, warehouse string, configJSON json.RawMe
require.Error(t, err, "it should not be able to list columns with a cancelled context")
})

columns, err := db.ListColumns(ctx, table)
columns = lo.Map(columns, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef {
require.NotEmptyf(t, col.RawType, "it should return the raw type for column %q", col.Name)
col.RawType = ""
return col
t.Run("without catalog", func(t *testing.T) {
columns, err := db.ListColumns(ctx, table)
columns = lo.Map(columns, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef {
require.NotEmptyf(t, col.RawType, "it should return the raw type for column %q", col.Name)
col.RawType = ""
return col
})
require.NoError(t, err, "it should be able to list columns")
require.Len(t, columns, 2, "it should return the correct number of columns")
require.ElementsMatch(t, columns, []sqlconnect.ColumnRef{
{Name: formatfn("c1"), Type: "int"},
{Name: formatfn("c2"), Type: "string"},
}, "it should return the correct columns")
})

t.Run("with catalog", func(t *testing.T) {
tableWithCatalog := table
tableWithCatalog.Catalog = currentCatalog
columns, err := db.ListColumns(ctx, tableWithCatalog)
columns = lo.Map(columns, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef {
require.NotEmptyf(t, col.RawType, "it should return the raw type for column %q", col.Name)
col.RawType = ""
return col
})
require.NoError(t, err, "it should be able to list columns")
require.Len(t, columns, 2, "it should return the correct number of columns")
require.ElementsMatch(t, columns, []sqlconnect.ColumnRef{
{Name: formatfn("c1"), Type: "int"},
{Name: formatfn("c2"), Type: "string"},
}, "it should return the correct columns")
})

t.Run("with invalid catalog", func(t *testing.T) {
tableWithInvalidCatalog := table
tableWithInvalidCatalog.Catalog = "invalid"
cols, _ := db.ListColumns(ctx, tableWithInvalidCatalog)
require.Empty(t, cols, "it should return an empty list of columns for an invalid catalog")
})
require.NoError(t, err, "it should be able to list columns")
require.Len(t, columns, 2, "it should return the correct number of columns")
require.ElementsMatch(t, columns, []sqlconnect.ColumnRef{
{Name: formatfn("c1"), Type: "int"},
{Name: formatfn("c2"), Type: "string"},
}, "it should return the correct columns")
})

t.Run("list columns for sql query", func(t *testing.T) {
Expand Down
12 changes: 12 additions & 0 deletions sqlconnect/internal/mysql/catalogadmin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package mysql

import (
"context"

"github.com/rudderlabs/sqlconnect-go/sqlconnect"
)

// CurrentCatalog returns an error because it is not supported by MySQL
func (db *DB) CurrentCatalog(ctx context.Context) (string, error) {
return "", sqlconnect.ErrNotSupported
}
3 changes: 3 additions & 0 deletions sqlconnect/internal/mysql/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
base.WithColumnTypeMapper(getColumnTypeMapper(config)),
base.WithJsonRowMapper(getJonRowMapper(config)),
base.WithSQLCommandsOverride(func(cmds base.SQLCommands) base.SQLCommands {
cmds.CurrentCatalog = func() string {
return "SELECT DATABASE()"

Check warning on line 42 in sqlconnect/internal/mysql/db.go

View check run for this annotation

Codecov / codecov/patch

sqlconnect/internal/mysql/db.go#L42

Added line #L42 was not covered by tests
}
cmds.DropSchema = func(schema base.QuotedIdentifier) string { // mysql does not support CASCADE
return fmt.Sprintf("DROP SCHEMA %[1]s", schema)
}
Expand Down
3 changes: 3 additions & 0 deletions sqlconnect/internal/redshift/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ func NewDB(credentialsJSON json.RawMessage) (*DB, error) {
base.WithColumnTypeMappings(getColumnTypeMappings(config)),
base.WithJsonRowMapper(getJonRowMapper(config)),
base.WithSQLCommandsOverride(func(cmds base.SQLCommands) base.SQLCommands {
cmds.CurrentCatalog = func() string {
return "SELECT current_database()"
}
cmds.ListSchemas = func() (string, string) {
return "SELECT schema_name FROM svv_redshift_schemas", "schema_name"
}
Expand Down
Loading
Loading