From 8161d2062f304d1ef3b37821dd5eea830816c8cf Mon Sep 17 00:00:00 2001 From: Aris Tzoumas Date: Fri, 13 Dec 2024 12:57:03 +0200 Subject: [PATCH] feat(databricks): normalize column names (#241) # Description The default behaviour of databricks is to return column names case sensitive when one queries the information schema, even though it treats them as case insensitive in sql statements. For table names, databricks behaviour is different, it stores them in lowercase. So, by default, table names are returned normalized by databricks, whereas column names are not. To avoid this inconsistency, we are now normalizing column names as default behaviour. Users who wish to avoid this behavioural change, they can set the `SkipColumnNormalization` configuration flag to `true`. ## Linear Ticket resolves PRO-3980 ## Security - [x] The code changed/added as part of this pull request won't create any security issues with how the software is being used. --- sqlconnect/internal/databricks/config.go | 6 ++ sqlconnect/internal/databricks/db.go | 6 +- sqlconnect/internal/databricks/tableadmin.go | 22 +++++- .../db_integration_test_scenario.go | 70 +++++++++++++++++++ 4 files changed, 101 insertions(+), 3 deletions(-) diff --git a/sqlconnect/internal/databricks/config.go b/sqlconnect/internal/databricks/config.go index 4b10ac7..09c872e 100644 --- a/sqlconnect/internal/databricks/config.go +++ b/sqlconnect/internal/databricks/config.go @@ -27,6 +27,12 @@ type Config struct { SessionParams map[string]string `json:"sessionParams"` UseLegacyMappings bool `json:"useLegacyMappings"` + // SkipColumnNormalization skips normalizing column names during ListColumns and ListColumnsForSqlQuery. + // Databricks is returning column names case sensitive from information schema, even though it is case insensitive. + // So, by default table names are returned normalized by databricks, whereas column names are not. + // To avoid this inconsistency, we are normalizing column names by default. + // If you want to skip this normalization, set this flag to true. + SkipColumnNormalization bool `json:"skipColumnNormalisation"` } func (c *Config) Parse(input json.RawMessage) error { diff --git a/sqlconnect/internal/databricks/db.go b/sqlconnect/internal/databricks/db.go index cd781f6..d1349c5 100644 --- a/sqlconnect/internal/databricks/db.go +++ b/sqlconnect/internal/databricks/db.go @@ -120,7 +120,8 @@ func NewDB(configJson json.RawMessage) (*DB, error) { return cmds }), ), - informationSchema: informationSchema, + informationSchema: informationSchema, + skipColumnNormalization: config.SkipColumnNormalization, }, nil } @@ -132,7 +133,8 @@ func init() { type DB struct { *base.DB - informationSchema bool + informationSchema bool + skipColumnNormalization bool } func getColumnTypeMapper(config Config) func(base.ColumnType) string { diff --git a/sqlconnect/internal/databricks/tableadmin.go b/sqlconnect/internal/databricks/tableadmin.go index d43c050..1fabf6d 100644 --- a/sqlconnect/internal/databricks/tableadmin.go +++ b/sqlconnect/internal/databricks/tableadmin.go @@ -5,6 +5,8 @@ import ( "fmt" "strings" + "github.com/samber/lo" + "github.com/rudderlabs/sqlconnect-go/sqlconnect" ) @@ -19,7 +21,25 @@ func (db *DB) ListColumns(ctx context.Context, relation sqlconnect.RelationRef) return nil, fmt.Errorf("catalog %s not found", relation.Catalog) } } - return db.DB.ListColumns(ctx, relation) + cols, err := db.DB.ListColumns(ctx, relation) + if db.skipColumnNormalization { + return cols, err + } + return lo.Map(cols, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef { + col.Name = db.NormaliseIdentifier(col.Name) + return col + }), err +} + +func (db *DB) ListColumnsForSqlQuery(ctx context.Context, sql string) ([]sqlconnect.ColumnRef, error) { + cols, err := db.DB.ListColumnsForSqlQuery(ctx, sql) + if db.skipColumnNormalization { + return cols, err + } + return lo.Map(cols, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef { + col.Name = db.NormaliseIdentifier(col.Name) + return col + }), err } // RenameTable in databricks falls back to MoveTable if rename is not supported diff --git a/sqlconnect/internal/integration_test/db_integration_test_scenario.go b/sqlconnect/internal/integration_test/db_integration_test_scenario.go index 036d847..1cc5c52 100644 --- a/sqlconnect/internal/integration_test/db_integration_test_scenario.go +++ b/sqlconnect/internal/integration_test/db_integration_test_scenario.go @@ -2,6 +2,7 @@ package integrationtest import ( "context" + "database/sql" "encoding/json" "errors" "fmt" @@ -609,6 +610,75 @@ func TestDatabaseScenarios(t *testing.T, warehouse string, configJSON json.RawMe {Name: formatfn("c2"), Type: "string"}, }, "it should return the correct columns") }) + + t.Run("list columns with mixed case", func(t *testing.T) { + unquotedColumn := "cOluMnA" + normalizedUnquotedColumn := db.NormaliseIdentifier(unquotedColumn) + quotedColumn := db.QuoteIdentifier("QuOted_CoLuMnB") + normalizedQuotedColumn := db.NormaliseIdentifier(quotedColumn) + parsedRel, err := db.ParseRelationRef(quotedColumn) + require.NoError(t, err, "it should be able to parse a quoted column") + normalizedQuotedColumnWithoutQuotes := parsedRel.Name + + tableIdentifier := db.QuoteIdentifier(schema.Name) + "." + db.QuoteIdentifier("table_mixed_case") + _, err = db.Exec(fmt.Sprintf("CREATE TABLE %[1]s (%[2]s int, %[3]s int)", tableIdentifier, unquotedColumn, quotedColumn)) + require.NoErrorf(t, err, "it should be able to create a quoted table: %s", tableIdentifier) + + table, err := db.ParseRelationRef(tableIdentifier) + require.NoError(t, err, "it should be able to parse a quoted table") + + t.Run("without catalog", func(t *testing.T) { + columns, err := db.ListColumns(ctx, table) + columns = lo.Map(columns, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef { + require.NotEmptyf(t, col.RawType, "it should return the raw type for column %q", col.Name) + col.RawType = "" + return col + }) + require.NoError(t, err, "it should be able to list columns") + require.Len(t, columns, 2, "it should return the correct number of columns") + require.ElementsMatch(t, columns, []sqlconnect.ColumnRef{ + {Name: normalizedUnquotedColumn, Type: "int"}, + {Name: normalizedQuotedColumnWithoutQuotes, Type: "int"}, + }, "it should return the correct columns") + + var c1, c2 int + err = db.QueryRow(fmt.Sprintf("SELECT %[1]s, %[2]s FROM %[3]s", normalizedUnquotedColumn, normalizedQuotedColumn, tableIdentifier)).Scan(&c1, &c2) + require.ErrorIs(t, err, sql.ErrNoRows, "it should get a no rows error (supports normalised column names)") + }) + + t.Run("with catalog", func(t *testing.T) { + tableWithCatalog := table + tableWithCatalog.Catalog = currentCatalog + columns, err := db.ListColumns(ctx, tableWithCatalog) + require.NoErrorf(t, err, "it should be able to list columns for %s", tableWithCatalog) + columns = lo.Map(columns, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef { + require.NotEmptyf(t, col.RawType, "it should return the raw type for column %q", col.Name) + col.RawType = "" + return col + }) + + require.Len(t, columns, 2, "it should return the correct number of columns") + require.ElementsMatch(t, columns, []sqlconnect.ColumnRef{ + {Name: normalizedUnquotedColumn, Type: "int"}, + {Name: normalizedQuotedColumnWithoutQuotes, Type: "int"}, + }, "it should return the correct columns") + }) + + t.Run("for sql query", func(t *testing.T) { + columns, err := db.ListColumnsForSqlQuery(ctx, "SELECT * FROM "+tableIdentifier) + columns = lo.Map(columns, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef { + require.NotEmptyf(t, col.RawType, "it should return the raw type for column %q", col.Name) + col.RawType = "" + return col + }) + require.NoError(t, err, "it should be able to list columns") + require.Len(t, columns, 2, "it should return the correct number of columns") + require.ElementsMatch(t, columns, []sqlconnect.ColumnRef{ + {Name: normalizedUnquotedColumn, Type: "int"}, + {Name: normalizedQuotedColumnWithoutQuotes, Type: "int"}, + }, "it should return the correct columns") + }) + }) }) t.Run("list columns for sql query", func(t *testing.T) {