Skip to content

Commit

Permalink
feat(databricks): normalize column names
Browse files Browse the repository at this point in the history
  • Loading branch information
atzoum committed Dec 12, 2024
1 parent a582787 commit 6f01029
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 3 deletions.
6 changes: 6 additions & 0 deletions sqlconnect/internal/databricks/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ type Config struct {
SessionParams map[string]string `json:"sessionParams"`

UseLegacyMappings bool `json:"useLegacyMappings"`
// SkipColumnNormalization disables the normalization of column names in ListColumns and ListColumnsForSqlQuery.
// Databricks returns column names from the information schema with their original casing, even though it treats them as case insensitive.
// As a result, table names come back normalized by Databricks by default, whereas column names do not.
// To avoid this inconsistency, column names are normalized by default.
// Set this flag to true to skip the normalization.
// Note: the JSON key uses the spelling "skipColumnNormalisation".
SkipColumnNormalization bool `json:"skipColumnNormalisation"`
}

func (c *Config) Parse(input json.RawMessage) error {
Expand Down
6 changes: 4 additions & 2 deletions sqlconnect/internal/databricks/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ func NewDB(configJson json.RawMessage) (*DB, error) {
return cmds
}),
),
informationSchema: informationSchema,
informationSchema: informationSchema,
skipColumnNormalization: config.SkipColumnNormalization,
}, nil
}

Expand All @@ -132,7 +133,8 @@ func init() {

// DB is the databricks database handle. It embeds *base.DB and adds
// databricks-specific settings.
type DB struct {
	*base.DB
	// informationSchema is set by NewDB.
	// NOTE(review): presumably indicates whether the information schema can be used — confirm against NewDB.
	informationSchema bool
	// skipColumnNormalization is copied from Config.SkipColumnNormalization by NewDB.
	// When true, ListColumns and ListColumnsForSqlQuery return column names
	// without passing them through NormaliseIdentifier.
	skipColumnNormalization bool
}

func getColumnTypeMapper(config Config) func(base.ColumnType) string {
Expand Down
22 changes: 21 additions & 1 deletion sqlconnect/internal/databricks/tableadmin.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"fmt"
"strings"

"github.com/samber/lo"

"github.com/rudderlabs/sqlconnect-go/sqlconnect"
)

Expand All @@ -19,7 +21,25 @@ func (db *DB) ListColumns(ctx context.Context, relation sqlconnect.RelationRef)
return nil, fmt.Errorf("catalog %s not found", relation.Catalog)
}
}
return db.DB.ListColumns(ctx, relation)
cols, err := db.DB.ListColumns(ctx, relation)
if db.skipColumnNormalization {
return cols, err
}

Check warning on line 27 in sqlconnect/internal/databricks/tableadmin.go

View check run for this annotation

Codecov / codecov/patch

sqlconnect/internal/databricks/tableadmin.go#L26-L27

Added lines #L26 - L27 were not covered by tests
return lo.Map(cols, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef {
col.Name = db.NormaliseIdentifier(col.Name)
return col
}), err
}

func (db *DB) ListColumnsForSqlQuery(ctx context.Context, sql string) ([]sqlconnect.ColumnRef, error) {
cols, err := db.DB.ListColumnsForSqlQuery(ctx, sql)
if db.skipColumnNormalization {
return cols, err
}

Check warning on line 38 in sqlconnect/internal/databricks/tableadmin.go

View check run for this annotation

Codecov / codecov/patch

sqlconnect/internal/databricks/tableadmin.go#L37-L38

Added lines #L37 - L38 were not covered by tests
return lo.Map(cols, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef {
col.Name = db.NormaliseIdentifier(col.Name)
return col
}), err
}

// RenameTable in databricks falls back to MoveTable if rename is not supported
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package integrationtest

import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -609,6 +610,75 @@ func TestDatabaseScenarios(t *testing.T, warehouse string, configJSON json.RawMe
{Name: formatfn("c2"), Type: "string"},
}, "it should return the correct columns")
})

// Verifies that listing columns of a table created with mixed-case column
// names (one unquoted, one quoted) returns the expected normalized names.
t.Run("list columns with mixed case", func(t *testing.T) {
	unquotedColumn := "cOluMnA"
	normalizedUnquotedColumn := db.NormaliseIdentifier(unquotedColumn)
	quotedColumn := db.QuoteIdentifier("QuOted_CoLuMnB")
	normalizedQuotedColumn := db.NormaliseIdentifier(quotedColumn)
	// The quoted column is parsed as a relation reference so that its Name
	// can be used as the expected column name without the quotes.
	parsedRel, err := db.ParseRelationRef(quotedColumn)
	require.NoError(t, err, "it should be able to parse a quoted column")
	normalizedQuotedColumnWithoutQuotes := parsedRel.Name

	tableIdentifier := db.QuoteIdentifier(schema.Name) + "." + db.QuoteIdentifier("table_mixed_case")
	_, err = db.Exec(fmt.Sprintf("CREATE TABLE %[1]s (%[2]s int, %[3]s int)", tableIdentifier, unquotedColumn, quotedColumn))
	require.NoErrorf(t, err, "it should be able to create a quoted table: %s", tableIdentifier)

	table, err := db.ParseRelationRef(tableIdentifier)
	require.NoError(t, err, "it should be able to parse a quoted table")

	// expected holds the columns every subtest should observe once RawType is cleared.
	expected := []sqlconnect.ColumnRef{
		{Name: normalizedUnquotedColumn, Type: "int"},
		{Name: normalizedQuotedColumnWithoutQuotes, Type: "int"},
	}

	// stripRawTypes asserts that every column carries a raw type and then
	// clears it, so that the result can be compared against expected.
	stripRawTypes := func(t *testing.T, cols []sqlconnect.ColumnRef) []sqlconnect.ColumnRef {
		return lo.Map(cols, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef {
			require.NotEmptyf(t, col.RawType, "it should return the raw type for column %q", col.Name)
			col.RawType = ""
			return col
		})
	}

	t.Run("without catalog", func(t *testing.T) {
		columns, err := db.ListColumns(ctx, table)
		// Check the error before inspecting the result.
		require.NoError(t, err, "it should be able to list columns")
		columns = stripRawTypes(t, columns)
		require.Len(t, columns, 2, "it should return the correct number of columns")
		require.ElementsMatch(t, columns, expected, "it should return the correct columns")

		// The table is empty, so a successful query using the normalized
		// column names is expected to yield sql.ErrNoRows.
		var c1, c2 int
		err = db.QueryRow(fmt.Sprintf("SELECT %[1]s, %[2]s FROM %[3]s", normalizedUnquotedColumn, normalizedQuotedColumn, tableIdentifier)).Scan(&c1, &c2)
		require.ErrorIs(t, err, sql.ErrNoRows, "it should get a no rows error (supports normalised column names)")
	})

	t.Run("with catalog", func(t *testing.T) {
		tableWithCatalog := table
		tableWithCatalog.Catalog = currentCatalog
		columns, err := db.ListColumns(ctx, tableWithCatalog)
		require.NoErrorf(t, err, "it should be able to list columns for %s", tableWithCatalog)
		columns = stripRawTypes(t, columns)

		require.Len(t, columns, 2, "it should return the correct number of columns")
		require.ElementsMatch(t, columns, expected, "it should return the correct columns")
	})

	t.Run("for sql query", func(t *testing.T) {
		columns, err := db.ListColumnsForSqlQuery(ctx, "SELECT * FROM "+tableIdentifier)
		// Check the error before inspecting the result.
		require.NoError(t, err, "it should be able to list columns")
		columns = stripRawTypes(t, columns)
		require.Len(t, columns, 2, "it should return the correct number of columns")
		require.ElementsMatch(t, columns, expected, "it should return the correct columns")
	})
})
})

t.Run("list columns for sql query", func(t *testing.T) {
Expand Down

0 comments on commit 6f01029

Please sign in to comment.