From 11bf6b2efc9e566ecaba4780b4a47282ae7f5cb7 Mon Sep 17 00:00:00 2001 From: Aris Tzoumas Date: Thu, 21 Mar 2024 12:27:48 +0200 Subject: [PATCH] chore: respect catalog parameter in DB.ListColumns (#16) --- sqlconnect/db.go | 12 +++- sqlconnect/internal/base/catalog_admin.go | 15 +++++ sqlconnect/internal/base/db.go | 15 ++++- sqlconnect/internal/base/tableadmin.go | 2 +- sqlconnect/internal/bigquery/catalogadmin.go | 18 +++++ sqlconnect/internal/bigquery/db.go | 8 ++- sqlconnect/internal/databricks/config.go | 3 - sqlconnect/internal/databricks/db.go | 32 ++++++++- .../internal/databricks/integration_test.go | 24 ++++++- sqlconnect/internal/databricks/tableadmin.go | 15 +++++ .../db_integration_test_scenario.go | 65 +++++++++++++++---- sqlconnect/internal/mysql/catalogadmin.go | 12 ++++ sqlconnect/internal/mysql/db.go | 3 + sqlconnect/internal/redshift/db.go | 3 + sqlconnect/internal/snowflake/db.go | 8 ++- 15 files changed, 210 insertions(+), 25 deletions(-) create mode 100644 sqlconnect/internal/base/catalog_admin.go create mode 100644 sqlconnect/internal/bigquery/catalogadmin.go create mode 100644 sqlconnect/internal/mysql/catalogadmin.go diff --git a/sqlconnect/db.go b/sqlconnect/db.go index d63aeb8..aac2ddf 100644 --- a/sqlconnect/db.go +++ b/sqlconnect/db.go @@ -8,12 +8,16 @@ import ( "time" ) -var ErrDropOldTablePostCopy = errors.New("move table: dropping old table after copying its contents to the new table") +var ( + ErrNotSupported = errors.New("sqlconnect: feature not supported") + ErrDropOldTablePostCopy = errors.New("sqlconnect move table: dropping old table after copying its contents to the new table") +) type DB interface { sqlDB // SqlDB returns the underlying *sql.DB SqlDB() *sql.DB + CatalogAdmin SchemaAdmin TableAdmin JsonRowMapper @@ -43,6 +47,12 @@ type sqlDB interface { Stats() sql.DBStats } +type CatalogAdmin interface { + // CurrentCatalog returns the current catalog.
+ // If this operation is not supported by the warehouse [ErrNotSupported] will be returned. + CurrentCatalog(ctx context.Context) (string, error) +} + type SchemaAdmin interface { // CreateSchema creates a schema CreateSchema(ctx context.Context, schema SchemaRef) error diff --git a/sqlconnect/internal/base/catalog_admin.go b/sqlconnect/internal/base/catalog_admin.go new file mode 100644 index 0000000..9c230f3 --- /dev/null +++ b/sqlconnect/internal/base/catalog_admin.go @@ -0,0 +1,15 @@ +package base + +import ( + "context" + "fmt" +) + +// CurrentCatalog returns the current catalog +func (db *DB) CurrentCatalog(ctx context.Context) (string, error) { + var catalog string + if err := db.QueryRowContext(ctx, db.sqlCommands.CurrentCatalog()).Scan(&catalog); err != nil { + return "", fmt.Errorf("getting current catalog: %w", err) + } + return catalog, nil +} diff --git a/sqlconnect/internal/base/db.go b/sqlconnect/internal/base/db.go index e074099..dbc4ed0 100644 --- a/sqlconnect/internal/base/db.go +++ b/sqlconnect/internal/base/db.go @@ -20,6 +20,9 @@ func NewDB(db *sql.DB, opts ...Option) *DB { return value }, sqlCommands: SQLCommands{ + CurrentCatalog: func() string { + return "SELECT current_catalog" + }, CreateSchema: func(schema QuotedIdentifier) string { return fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %[1]s", schema) }, @@ -46,8 +49,12 @@ func NewDB(db *sql.DB, opts ...Option) *DB { TableExists: func(schema, table UnquotedIdentifier) string { return fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema='%[1]s' and table_name = '%[2]s'", schema, table) }, - ListColumns: func(schema, table UnquotedIdentifier) (string, string, string) { - return fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = '%[1]s' AND table_name = '%[2]s'", schema, table), "column_name", "data_type" + ListColumns: func(catalog, schema, table UnquotedIdentifier) (string, string, string) { + stmt := fmt.Sprintf("SELECT 
column_name, data_type FROM information_schema.columns WHERE table_schema = '%[1]s' AND table_name = '%[2]s'", schema, table) + if catalog != "" { + stmt += fmt.Sprintf(" AND table_catalog = '%[1]s'", catalog) + } + return stmt + " ORDER BY ordinal_position ASC", "column_name", "data_type" }, CountTableRows: func(table QuotedIdentifier) string { return fmt.Sprintf("SELECT COUNT(*) FROM %[1]s", table) }, DropTable: func(table QuotedIdentifier) string { return fmt.Sprintf("DROP TABLE IF EXISTS %[1]s", table) }, @@ -101,6 +108,8 @@ type ( QuotedIdentifier string // A quoted identifier is a string that is quoted, e.g. "my_table" UnquotedIdentifier string // An unquoted identifier is a string that is not quoted, e.g. my_table SQLCommands struct { + // Provides the SQL command to get the current catalog + CurrentCatalog func() string // Provides the SQL command to create a schema CreateSchema func(schema QuotedIdentifier) string // Provides the SQL command to list all schemas @@ -118,7 +127,7 @@ type ( // Provides the SQL command to check if a table exists TableExists func(schema, table UnquotedIdentifier) string // Provides the SQL command to list all columns in a table along with the column names in the result set that point to the name and type - ListColumns func(schema, table UnquotedIdentifier) (sql, nameCol, typeCol string) + ListColumns func(catalog, schema, table UnquotedIdentifier) (sql, nameCol, typeCol string) // Provides the SQL command to count the rows in a table CountTableRows func(table QuotedIdentifier) string // Provides the SQL command to drop a table diff --git a/sqlconnect/internal/base/tableadmin.go b/sqlconnect/internal/base/tableadmin.go index e1fa645..41ba257 100644 --- a/sqlconnect/internal/base/tableadmin.go +++ b/sqlconnect/internal/base/tableadmin.go @@ -133,7 +133,7 @@ func (db *DB) TableExists(ctx context.Context, relation sqlconnect.RelationRef) // ListColumns returns a list of columns for the given table func (db *DB) ListColumns(ctx 
context.Context, relation sqlconnect.RelationRef) ([]sqlconnect.ColumnRef, error) { var res []sqlconnect.ColumnRef - stmt, nameCol, typeCol := db.sqlCommands.ListColumns(UnquotedIdentifier(relation.Schema), UnquotedIdentifier(relation.Name)) + stmt, nameCol, typeCol := db.sqlCommands.ListColumns(UnquotedIdentifier(relation.Catalog), UnquotedIdentifier(relation.Schema), UnquotedIdentifier(relation.Name)) columns, err := db.QueryContext(ctx, stmt) if err != nil { return nil, fmt.Errorf("querying list columns for %s: %w", relation.String(), err) diff --git a/sqlconnect/internal/bigquery/catalogadmin.go b/sqlconnect/internal/bigquery/catalogadmin.go new file mode 100644 index 0000000..2a4ee11 --- /dev/null +++ b/sqlconnect/internal/bigquery/catalogadmin.go @@ -0,0 +1,18 @@ +package bigquery + +import ( + "context" + + "cloud.google.com/go/bigquery" +) + +func (db *DB) CurrentCatalog(ctx context.Context) (string, error) { + var catalog string + if err := db.WithBigqueryClient(ctx, func(c *bigquery.Client) error { + catalog = c.Project() + return nil + }); err != nil { + return "", err + } + return catalog, nil +} diff --git a/sqlconnect/internal/bigquery/db.go b/sqlconnect/internal/bigquery/db.go index 11a79a2..2660379 100644 --- a/sqlconnect/internal/bigquery/db.go +++ b/sqlconnect/internal/bigquery/db.go @@ -52,8 +52,12 @@ func NewDB(configJSON json.RawMessage) (*DB, error) { cmds.TableExists = func(schema, table base.UnquotedIdentifier) string { return fmt.Sprintf("SELECT table_name FROM `%[1]s`.INFORMATION_SCHEMA.TABLES WHERE table_name = '%[2]s'", schema, table) } - cmds.ListColumns = func(schema, table base.UnquotedIdentifier) (string, string, string) { - return fmt.Sprintf("SELECT column_name, data_type FROM `%[1]s`.INFORMATION_SCHEMA.COLUMNS WHERE table_name = '%[2]s'", schema, table), "column_name", "data_type" + cmds.ListColumns = func(catalog, schema, table base.UnquotedIdentifier) (string, string, string) { + stmt := fmt.Sprintf("SELECT column_name, 
data_type FROM `%[1]s`.INFORMATION_SCHEMA.COLUMNS WHERE table_name = '%[2]s'", schema, table) + if catalog != "" { + stmt += fmt.Sprintf(" AND table_catalog = '%[1]s'", catalog) + } + return stmt, "column_name", "data_type" } return cmds diff --git a/sqlconnect/internal/databricks/config.go b/sqlconnect/internal/databricks/config.go index db233e3..c259417 100644 --- a/sqlconnect/internal/databricks/config.go +++ b/sqlconnect/internal/databricks/config.go @@ -25,8 +25,5 @@ func (c *Config) Parse(configJson json.RawMessage) error { if err != nil { return err } - if c.Catalog == "" { - c.Catalog = "hive_metastore" // default catalog - } return nil } diff --git a/sqlconnect/internal/databricks/db.go b/sqlconnect/internal/databricks/db.go index 2311632..c3fa8b5 100644 --- a/sqlconnect/internal/databricks/db.go +++ b/sqlconnect/internal/databricks/db.go @@ -4,6 +4,7 @@ import ( "database/sql" "encoding/json" "fmt" + "strings" databricks "github.com/databricks/databricks-sql-go" "github.com/samber/lo" @@ -44,6 +45,11 @@ func NewDB(configJson json.RawMessage) (*DB, error) { db := sql.OpenDB(connector) db.SetConnMaxIdleTime(config.MaxConnIdleTime) + if _, err = db.Exec("SELECT * FROM INFORMATION_SCHEMA.COLUMNS LIMIT 1"); err != nil && !strings.Contains(err.Error(), "TABLE_OR_VIEW_NOT_FOUND") { + return nil, fmt.Errorf("checking if unity catalog is available: %w", err) + } + informationSchema := err == nil + return &DB{ DB: base.NewDB( db, @@ -51,6 +57,9 @@ func NewDB(configJson json.RawMessage) (*DB, error) { base.WithColumnTypeMapper(getColumnTypeMapper(config)), base.WithJsonRowMapper(getJonRowMapper(config)), base.WithSQLCommandsOverride(func(cmds base.SQLCommands) base.SQLCommands { + cmds.CurrentCatalog = func() string { + return "SELECT current_catalog()" + } cmds.ListSchemas = func() (string, string) { return "SHOW SCHEMAS", "schema_name" } cmds.SchemaExists = func(schema base.UnquotedIdentifier) string { return fmt.Sprintf(`SHOW SCHEMAS LIKE '%s'`, schema) } @@ 
-70,12 +79,30 @@ func NewDB(configJson json.RawMessage) (*DB, error) { cmds.TableExists = func(schema, table base.UnquotedIdentifier) string { return fmt.Sprintf("SHOW TABLES IN `%[1]s` LIKE '%[2]s'", schema, table) } - cmds.ListColumns = func(schema, table base.UnquotedIdentifier) (string, string, string) { - return fmt.Sprintf("DESCRIBE TABLE `%[1]s`.`%[2]s`", schema, table), "col_name", "data_type" + cmds.ListColumns = func(catalog, schema, table base.UnquotedIdentifier) (string, string, string) { + if catalog == "" || !informationSchema { + return fmt.Sprintf("DESCRIBE TABLE `%[1]s`.`%[2]s`", schema, table), "col_name", "data_type" + } + stmt := fmt.Sprintf(`SELECT + column_name, + data_type + FROM information_schema.columns + WHERE table_schema = '%[1]s' + AND table_name = '%[2]s' + AND table_catalog='%[3]s' + ORDER BY ORDINAL_POSITION ASC`, + schema, + table, + catalog) + return stmt, "column_name", "data_type" + } + cmds.RenameTable = func(schema, oldName, newName base.QuotedIdentifier) string { + return fmt.Sprintf("ALTER TABLE %[1]s.%[2]s RENAME TO %[1]s.%[3]s", schema, oldName, newName) } return cmds }), ), + informationSchema: informationSchema, }, nil } @@ -87,6 +114,7 @@ func init() { type DB struct { *base.DB + informationSchema bool } func getColumnTypeMapper(config Config) func(base.ColumnType) string { diff --git a/sqlconnect/internal/databricks/integration_test.go b/sqlconnect/internal/databricks/integration_test.go index a4a43e1..c792f33 100644 --- a/sqlconnect/internal/databricks/integration_test.go +++ b/sqlconnect/internal/databricks/integration_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/tidwall/sjson" + "github.com/rudderlabs/sqlconnect-go/sqlconnect" "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/databricks" integrationtest "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/integration_test" ) @@ -26,5 +27,26 @@ func TestDatabricksDB(t *testing.T) { configJSON, err = 
sjson.Set(configJSON, "maxRetryWaitTime", 30*time.Second) require.NoError(t, err, "failed to set maxRetryWaitTime") - integrationtest.TestDatabaseScenarios(t, databricks.DatabaseType, []byte(configJSON), strings.ToLower, integrationtest.Options{LegacySupport: true}) + t.Run("with information schema", func(t *testing.T) { + configJSON, err := sjson.Set(configJSON, "catalog", "sqlconnect") + require.NoError(t, err, "failed to set catalog") + db, err := sqlconnect.NewDB(databricks.DatabaseType, []byte(configJSON)) + require.NoError(t, err, "failed to create db") + _, err = db.Exec("SELECT * FROM INFORMATION_SCHEMA.COLUMNS LIMIT 1") + require.NoError(t, err, "information schema should be available") + + integrationtest.TestDatabaseScenarios(t, databricks.DatabaseType, []byte(configJSON), strings.ToLower, integrationtest.Options{LegacySupport: true}) + }) + + t.Run("without information schema", func(t *testing.T) { + configJSON, err := sjson.Set(configJSON, "catalog", "hive_metastore") + require.NoError(t, err, "failed to set catalog") + db, err := sqlconnect.NewDB(databricks.DatabaseType, []byte(configJSON)) + require.NoError(t, err, "failed to create db") + _, err = db.Exec("SELECT * FROM INFORMATION_SCHEMA.COLUMNS LIMIT 1") + require.Error(t, err, "information schema should not be available") + require.ErrorContains(t, err, "TABLE_OR_VIEW_NOT_FOUND", "information schema should not be available") + + integrationtest.TestDatabaseScenarios(t, databricks.DatabaseType, []byte(configJSON), strings.ToLower, integrationtest.Options{LegacySupport: true}) + }) } diff --git a/sqlconnect/internal/databricks/tableadmin.go b/sqlconnect/internal/databricks/tableadmin.go index 6b130e7..d43c050 100644 --- a/sqlconnect/internal/databricks/tableadmin.go +++ b/sqlconnect/internal/databricks/tableadmin.go @@ -2,11 +2,26 @@ package databricks import ( "context" + "fmt" "strings" "github.com/rudderlabs/sqlconnect-go/sqlconnect" ) +// ListColumns returns a list of columns for the given 
table +func (db *DB) ListColumns(ctx context.Context, relation sqlconnect.RelationRef) ([]sqlconnect.ColumnRef, error) { + if !db.informationSchema && relation.Catalog != "" { + currentCatalog, err := db.CurrentCatalog(ctx) // make sure the catalog matches the current catalog + if err != nil { + return nil, fmt.Errorf("getting current catalog: %w", err) + } + if relation.Catalog != currentCatalog { + return nil, fmt.Errorf("catalog %s not found", relation.Catalog) + } + } + return db.DB.ListColumns(ctx, relation) +} + // RenameTable in databricks falls back to MoveTable if rename is not supported func (db *DB) RenameTable(ctx context.Context, oldRef, newRef sqlconnect.RelationRef) error { if err := db.DB.RenameTable(ctx, oldRef, newRef); err != nil { diff --git a/sqlconnect/internal/integration_test/db_integration_test_scenario.go b/sqlconnect/internal/integration_test/db_integration_test_scenario.go index 11e8aa5..434c6c4 100644 --- a/sqlconnect/internal/integration_test/db_integration_test_scenario.go +++ b/sqlconnect/internal/integration_test/db_integration_test_scenario.go @@ -58,6 +58,23 @@ func TestDatabaseScenarios(t *testing.T, warehouse string, configJSON json.RawMe }) }) + var currentCatalog string + t.Run("catalog admin", func(t *testing.T) { + t.Run("current catalog", func(t *testing.T) { + t.Run("with context cancelled", func(t *testing.T) { + _, err := db.CurrentCatalog(cancelledCtx) + require.Error(t, err, "it should not be able to get the current catalog with a cancelled context") + }) + + currentCatalog, err = db.CurrentCatalog(ctx) + if errors.Is(err, sqlconnect.ErrNotSupported) { + t.Skipf("skipping test for warehouse %s: %v", warehouse, err) + } + require.NoError(t, err, "it should be able to get the current catalog") + require.NotEmpty(t, currentCatalog, "it should return a non-empty current catalog") + }) + }) + t.Run("schema admin", func(t *testing.T) { t.Run("schema doesn't exist", func(t *testing.T) { exists, err := db.SchemaExists(ctx, 
schema) @@ -178,18 +195,44 @@ func TestDatabaseScenarios(t *testing.T, warehouse string, configJSON json.RawMe require.Error(t, err, "it should not be able to list columns with a cancelled context") }) - columns, err := db.ListColumns(ctx, table) - columns = lo.Map(columns, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef { - require.NotEmptyf(t, col.RawType, "it should return the raw type for column %q", col.Name) - col.RawType = "" - return col + t.Run("without catalog", func(t *testing.T) { + columns, err := db.ListColumns(ctx, table) + columns = lo.Map(columns, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef { + require.NotEmptyf(t, col.RawType, "it should return the raw type for column %q", col.Name) + col.RawType = "" + return col + }) + require.NoError(t, err, "it should be able to list columns") + require.Len(t, columns, 2, "it should return the correct number of columns") + require.ElementsMatch(t, columns, []sqlconnect.ColumnRef{ + {Name: formatfn("c1"), Type: "int"}, + {Name: formatfn("c2"), Type: "string"}, + }, "it should return the correct columns") + }) + + t.Run("with catalog", func(t *testing.T) { + tableWithCatalog := table + tableWithCatalog.Catalog = currentCatalog + columns, err := db.ListColumns(ctx, tableWithCatalog) + columns = lo.Map(columns, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef { + require.NotEmptyf(t, col.RawType, "it should return the raw type for column %q", col.Name) + col.RawType = "" + return col + }) + require.NoError(t, err, "it should be able to list columns") + require.Len(t, columns, 2, "it should return the correct number of columns") + require.ElementsMatch(t, columns, []sqlconnect.ColumnRef{ + {Name: formatfn("c1"), Type: "int"}, + {Name: formatfn("c2"), Type: "string"}, + }, "it should return the correct columns") + }) + + t.Run("with invalid catalog", func(t *testing.T) { + tableWithInvalidCatalog := table + tableWithInvalidCatalog.Catalog = "invalid" + cols, _ := 
db.ListColumns(ctx, tableWithInvalidCatalog) + require.Empty(t, cols, "it should return an empty list of columns for an invalid catalog") }) - require.NoError(t, err, "it should be able to list columns") - require.Len(t, columns, 2, "it should return the correct number of columns") - require.ElementsMatch(t, columns, []sqlconnect.ColumnRef{ - {Name: formatfn("c1"), Type: "int"}, - {Name: formatfn("c2"), Type: "string"}, - }, "it should return the correct columns") }) t.Run("list columns for sql query", func(t *testing.T) { diff --git a/sqlconnect/internal/mysql/catalogadmin.go b/sqlconnect/internal/mysql/catalogadmin.go new file mode 100644 index 0000000..a2b10f6 --- /dev/null +++ b/sqlconnect/internal/mysql/catalogadmin.go @@ -0,0 +1,12 @@ +package mysql + +import ( + "context" + + "github.com/rudderlabs/sqlconnect-go/sqlconnect" +) + +// CurrentCatalog returns an error because it is not supported by MySQL +func (db *DB) CurrentCatalog(ctx context.Context) (string, error) { + return "", sqlconnect.ErrNotSupported +} diff --git a/sqlconnect/internal/mysql/db.go b/sqlconnect/internal/mysql/db.go index 1c7de04..92a1f79 100644 --- a/sqlconnect/internal/mysql/db.go +++ b/sqlconnect/internal/mysql/db.go @@ -38,6 +38,9 @@ func NewDB(configJSON json.RawMessage) (*DB, error) { base.WithColumnTypeMapper(getColumnTypeMapper(config)), base.WithJsonRowMapper(getJonRowMapper(config)), base.WithSQLCommandsOverride(func(cmds base.SQLCommands) base.SQLCommands { + cmds.CurrentCatalog = func() string { + return "SELECT DATABASE()" + } cmds.DropSchema = func(schema base.QuotedIdentifier) string { // mysql does not support CASCADE return fmt.Sprintf("DROP SCHEMA %[1]s", schema) } diff --git a/sqlconnect/internal/redshift/db.go b/sqlconnect/internal/redshift/db.go index 76cc833..36dc18e 100644 --- a/sqlconnect/internal/redshift/db.go +++ b/sqlconnect/internal/redshift/db.go @@ -35,6 +35,9 @@ func NewDB(credentialsJSON json.RawMessage) (*DB, error) { 
base.WithColumnTypeMappings(getColumnTypeMappings(config)), base.WithJsonRowMapper(getJonRowMapper(config)), base.WithSQLCommandsOverride(func(cmds base.SQLCommands) base.SQLCommands { + cmds.CurrentCatalog = func() string { + return "SELECT current_database()" + } cmds.ListSchemas = func() (string, string) { return "SELECT schema_name FROM svv_redshift_schemas", "schema_name" } diff --git a/sqlconnect/internal/snowflake/db.go b/sqlconnect/internal/snowflake/db.go index c7f6422..373333d 100644 --- a/sqlconnect/internal/snowflake/db.go +++ b/sqlconnect/internal/snowflake/db.go @@ -40,6 +40,9 @@ func NewDB(configJSON json.RawMessage) (*DB, error) { base.WithColumnTypeMapper(getColumnTypeMapper(config)), base.WithJsonRowMapper(getJonRowMapper(config)), base.WithSQLCommandsOverride(func(cmds base.SQLCommands) base.SQLCommands { + cmds.CurrentCatalog = func() string { + return "SELECT current_database()" + } cmds.ListSchemas = func() (string, string) { return "SHOW TERSE SCHEMAS", "name" } cmds.SchemaExists = func(schema base.UnquotedIdentifier) string { return fmt.Sprintf("SHOW TERSE SCHEMAS LIKE '%[1]s'", schema) @@ -57,7 +60,10 @@ func NewDB(configJSON json.RawMessage) (*DB, error) { cmds.TableExists = func(schema, table base.UnquotedIdentifier) string { return fmt.Sprintf("SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = '%[1]s' AND TABLE_NAME = '%[2]s'", schema, table) } - cmds.ListColumns = func(schema, table base.UnquotedIdentifier) (string, string, string) { + cmds.ListColumns = func(catalog, schema, table base.UnquotedIdentifier) (string, string, string) { + if catalog != "" { + return fmt.Sprintf(`DESCRIBE TABLE "%[1]s"."%[2]s"."%[3]s"`, catalog, schema, table), "name", "type" + } return fmt.Sprintf(`DESCRIBE TABLE "%[1]s"."%[2]s"`, schema, table), "name", "type" } cmds.RenameTable = func(schema, oldName, newName base.QuotedIdentifier) string {