Skip to content

Commit

Permalink
feat(databricks): normalize column names (#241)
Browse files Browse the repository at this point in the history
# Description

The default behaviour of databricks is to return column names case
sensitive when one queries the information schema, even though it treats
them as case insensitive in sql statements.
For table names, databricks behaviour is different, it stores them in
lowercase.
So, by default, table names are returned normalized by databricks,
whereas column names are not.
To avoid this inconsistency, we are now normalizing column names as
default behaviour.

Users who wish to avoid this behavioural change, they can set the
`SkipColumnNormalization` configuration flag to `true`.

## Linear Ticket

resolves PRO-3980

## Security

- [x] The code changed/added as part of this pull request won't create
any security issues with how the software is being used.
  • Loading branch information
atzoum authored Dec 13, 2024
1 parent 5757306 commit 8161d20
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 3 deletions.
6 changes: 6 additions & 0 deletions sqlconnect/internal/databricks/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ type Config struct {
SessionParams map[string]string `json:"sessionParams"`

UseLegacyMappings bool `json:"useLegacyMappings"`
// SkipColumnNormalization skips normalizing column names during ListColumns and ListColumnsForSqlQuery.
// Databricks is returning column names case sensitive from information schema, even though it is case insensitive.
// So, by default table names are returned normalized by databricks, whereas column names are not.
// To avoid this inconsistency, we are normalizing column names by default.
// If you want to skip this normalization, set this flag to true.
SkipColumnNormalization bool `json:"skipColumnNormalisation"`
}

func (c *Config) Parse(input json.RawMessage) error {
Expand Down
6 changes: 4 additions & 2 deletions sqlconnect/internal/databricks/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ func NewDB(configJson json.RawMessage) (*DB, error) {
return cmds
}),
),
informationSchema: informationSchema,
informationSchema: informationSchema,
skipColumnNormalization: config.SkipColumnNormalization,
}, nil
}

Expand All @@ -132,7 +133,8 @@ func init() {

type DB struct {
*base.DB
informationSchema bool
informationSchema bool
skipColumnNormalization bool
}

func getColumnTypeMapper(config Config) func(base.ColumnType) string {
Expand Down
22 changes: 21 additions & 1 deletion sqlconnect/internal/databricks/tableadmin.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"fmt"
"strings"

"github.com/samber/lo"

"github.com/rudderlabs/sqlconnect-go/sqlconnect"
)

Expand All @@ -19,7 +21,25 @@ func (db *DB) ListColumns(ctx context.Context, relation sqlconnect.RelationRef)
return nil, fmt.Errorf("catalog %s not found", relation.Catalog)
}
}
return db.DB.ListColumns(ctx, relation)
cols, err := db.DB.ListColumns(ctx, relation)
if db.skipColumnNormalization {
return cols, err
}
return lo.Map(cols, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef {
col.Name = db.NormaliseIdentifier(col.Name)
return col
}), err
}

func (db *DB) ListColumnsForSqlQuery(ctx context.Context, sql string) ([]sqlconnect.ColumnRef, error) {
cols, err := db.DB.ListColumnsForSqlQuery(ctx, sql)
if db.skipColumnNormalization {
return cols, err
}
return lo.Map(cols, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef {
col.Name = db.NormaliseIdentifier(col.Name)
return col
}), err
}

// RenameTable in databricks falls back to MoveTable if rename is not supported
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package integrationtest

import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -609,6 +610,75 @@ func TestDatabaseScenarios(t *testing.T, warehouse string, configJSON json.RawMe
{Name: formatfn("c2"), Type: "string"},
}, "it should return the correct columns")
})

t.Run("list columns with mixed case", func(t *testing.T) {
unquotedColumn := "cOluMnA"
normalizedUnquotedColumn := db.NormaliseIdentifier(unquotedColumn)
quotedColumn := db.QuoteIdentifier("QuOted_CoLuMnB")
normalizedQuotedColumn := db.NormaliseIdentifier(quotedColumn)
parsedRel, err := db.ParseRelationRef(quotedColumn)
require.NoError(t, err, "it should be able to parse a quoted column")
normalizedQuotedColumnWithoutQuotes := parsedRel.Name

tableIdentifier := db.QuoteIdentifier(schema.Name) + "." + db.QuoteIdentifier("table_mixed_case")
_, err = db.Exec(fmt.Sprintf("CREATE TABLE %[1]s (%[2]s int, %[3]s int)", tableIdentifier, unquotedColumn, quotedColumn))
require.NoErrorf(t, err, "it should be able to create a quoted table: %s", tableIdentifier)

table, err := db.ParseRelationRef(tableIdentifier)
require.NoError(t, err, "it should be able to parse a quoted table")

t.Run("without catalog", func(t *testing.T) {
columns, err := db.ListColumns(ctx, table)
columns = lo.Map(columns, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef {
require.NotEmptyf(t, col.RawType, "it should return the raw type for column %q", col.Name)
col.RawType = ""
return col
})
require.NoError(t, err, "it should be able to list columns")
require.Len(t, columns, 2, "it should return the correct number of columns")
require.ElementsMatch(t, columns, []sqlconnect.ColumnRef{
{Name: normalizedUnquotedColumn, Type: "int"},
{Name: normalizedQuotedColumnWithoutQuotes, Type: "int"},
}, "it should return the correct columns")

var c1, c2 int
err = db.QueryRow(fmt.Sprintf("SELECT %[1]s, %[2]s FROM %[3]s", normalizedUnquotedColumn, normalizedQuotedColumn, tableIdentifier)).Scan(&c1, &c2)
require.ErrorIs(t, err, sql.ErrNoRows, "it should get a no rows error (supports normalised column names)")
})

t.Run("with catalog", func(t *testing.T) {
tableWithCatalog := table
tableWithCatalog.Catalog = currentCatalog
columns, err := db.ListColumns(ctx, tableWithCatalog)
require.NoErrorf(t, err, "it should be able to list columns for %s", tableWithCatalog)
columns = lo.Map(columns, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef {
require.NotEmptyf(t, col.RawType, "it should return the raw type for column %q", col.Name)
col.RawType = ""
return col
})

require.Len(t, columns, 2, "it should return the correct number of columns")
require.ElementsMatch(t, columns, []sqlconnect.ColumnRef{
{Name: normalizedUnquotedColumn, Type: "int"},
{Name: normalizedQuotedColumnWithoutQuotes, Type: "int"},
}, "it should return the correct columns")
})

t.Run("for sql query", func(t *testing.T) {
columns, err := db.ListColumnsForSqlQuery(ctx, "SELECT * FROM "+tableIdentifier)
columns = lo.Map(columns, func(col sqlconnect.ColumnRef, _ int) sqlconnect.ColumnRef {
require.NotEmptyf(t, col.RawType, "it should return the raw type for column %q", col.Name)
col.RawType = ""
return col
})
require.NoError(t, err, "it should be able to list columns")
require.Len(t, columns, 2, "it should return the correct number of columns")
require.ElementsMatch(t, columns, []sqlconnect.ColumnRef{
{Name: normalizedUnquotedColumn, Type: "int"},
{Name: normalizedQuotedColumnWithoutQuotes, Type: "int"},
}, "it should return the correct columns")
})
})
})

t.Run("list columns for sql query", func(t *testing.T) {
Expand Down

0 comments on commit 8161d20

Please sign in to comment.