Skip to content

Commit

Permalink
fix: dialect quoting and normalisation inconsistencies (#143)
Browse files Browse the repository at this point in the history
# Description

- Escaping quotes in identifiers during `Dialect#QuoteIdentifier` for
all warehouses.
- Identifiers in `redshift`, `trino` and `databricks` are
case-insensitive and should be folded to lowercase regardless if they
are quoted or not.
- Handle case in `redshift` when `enable_case_sensitive_identifier: on`.
- Adding a scenario to verify that all dialects treat quoted tables
properly for all warehouses.

## Security

- [x] The code changed/added as part of this pull request won't create
any security issues with how the software is being used.
  • Loading branch information
atzoum authored Jul 24, 2024
1 parent 6b64477 commit 7a11657
Show file tree
Hide file tree
Showing 22 changed files with 377 additions and 57 deletions.
12 changes: 6 additions & 6 deletions sqlconnect/internal/base/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,29 +32,29 @@ func NewDB(db *sql.DB, tunnelCloser func() error, opts ...Option) *DB {
return "SELECT schema_name FROM information_schema.schemata", "schema_name"
},
SchemaExists: func(schema UnquotedIdentifier) string {
return fmt.Sprintf("SELECT schema_name FROM information_schema.schemata where schema_name = '%[1]s'", schema)
return fmt.Sprintf("SELECT schema_name FROM information_schema.schemata where schema_name = '%[1]s'", EscapeSqlString(schema))
},
DropSchema: func(schema QuotedIdentifier) string { return fmt.Sprintf("DROP SCHEMA %[1]s CASCADE", schema) },
CreateTestTable: func(table QuotedIdentifier) string {
return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %[1]s (c1 INT, c2 VARCHAR(255))", table)
},
ListTables: func(schema UnquotedIdentifier) []lo.Tuple2[string, string] {
return []lo.Tuple2[string, string]{
{A: fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema = '%[1]s'", schema), B: "table_name"},
{A: fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema = '%[1]s'", EscapeSqlString(schema)), B: "table_name"},
}
},
ListTablesWithPrefix: func(schema UnquotedIdentifier, prefix string) []lo.Tuple2[string, string] {
return []lo.Tuple2[string, string]{
{A: fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema='%[1]s' AND table_name LIKE '%[2]s'", schema, prefix+"%"), B: "table_name"},
{A: fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema='%[1]s' AND table_name LIKE '%[2]s'", EscapeSqlString(schema), prefix+"%"), B: "table_name"},
}
},
TableExists: func(schema, table UnquotedIdentifier) string {
return fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema='%[1]s' and table_name = '%[2]s'", schema, table)
return fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema='%[1]s' and table_name = '%[2]s'", EscapeSqlString(schema), EscapeSqlString(table))
},
ListColumns: func(catalog, schema, table UnquotedIdentifier) (string, string, string) {
stmt := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = '%[1]s' AND table_name = '%[2]s'", schema, table)
stmt := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = '%[1]s' AND table_name = '%[2]s'", EscapeSqlString(schema), EscapeSqlString(table))
if catalog != "" {
stmt += fmt.Sprintf(" AND table_catalog = '%[1]s'", catalog)
stmt += fmt.Sprintf(" AND table_catalog = '%[1]s'", EscapeSqlString(catalog))
}
return stmt + " ORDER BY ordinal_position ASC", "column_name", "data_type"
},
Expand Down
7 changes: 6 additions & 1 deletion sqlconnect/internal/base/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func (d dialect) QuoteTable(table sqlconnect.RelationRef) string {

// QuoteIdentifier quotes an identifier, e.g. a column name
func (d dialect) QuoteIdentifier(name string) string {
return fmt.Sprintf(`"%s"`, name)
return fmt.Sprintf(`"%s"`, strings.ReplaceAll(name, `"`, `""`))
}

// FormatTableName formats a table name, typically by lower or upper casing it, depending on the database
Expand Down Expand Up @@ -99,3 +99,8 @@ func doNormaliseIdentifier(identifier string, quote rune, normF func(string) str
}
return result.String()
}

// EscapeSqlString escapes a string for use in SQL, e.g. by doubling single quotes
func EscapeSqlString(value UnquotedIdentifier) string {
return strings.ReplaceAll(string(value), "'", "''")
}
6 changes: 3 additions & 3 deletions sqlconnect/internal/bigquery/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ func NewDB(configJSON json.RawMessage) (*DB, error) {
}
}
cmds.TableExists = func(schema, table base.UnquotedIdentifier) string {
return fmt.Sprintf("SELECT table_name FROM `%[1]s`.INFORMATION_SCHEMA.TABLES WHERE table_name = '%[2]s'", schema, table)
return fmt.Sprintf("SELECT table_name FROM `%[1]s`.INFORMATION_SCHEMA.TABLES WHERE table_name = '%[2]s'", schema, base.EscapeSqlString(table))
}
cmds.ListColumns = func(catalog, schema, table base.UnquotedIdentifier) (string, string, string) {
stmt := fmt.Sprintf("SELECT column_name, data_type FROM `%[1]s`.INFORMATION_SCHEMA.COLUMNS WHERE table_name = '%[2]s'", schema, table)
stmt := fmt.Sprintf("SELECT column_name, data_type FROM `%[1]s`.INFORMATION_SCHEMA.COLUMNS WHERE table_name = '%[2]s'", schema, base.EscapeSqlString(table))
if catalog != "" {
stmt += fmt.Sprintf(" AND table_catalog = '%[1]s'", catalog)
stmt += fmt.Sprintf(" AND table_catalog = '%[1]s'", base.EscapeSqlString(catalog))
}
return stmt, "column_name", "data_type"
}
Expand Down
26 changes: 23 additions & 3 deletions sqlconnect/internal/bigquery/dialect.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package bigquery

import (
"regexp"
"strings"

"github.com/rudderlabs/sqlconnect-go/sqlconnect"
Expand All @@ -9,6 +10,11 @@ import (

type dialect struct{}

var (
escape = regexp.MustCompile("('|\"|`)")
unescape = regexp.MustCompile("\\\\('|\")")
)

// QuoteTable quotes a table name
func (d dialect) QuoteTable(table sqlconnect.RelationRef) string {
if table.Schema != "" {
Expand All @@ -19,7 +25,7 @@ func (d dialect) QuoteTable(table sqlconnect.RelationRef) string {

// QuoteIdentifier quotes an identifier, e.g. a column name
func (d dialect) QuoteIdentifier(name string) string {
return "`" + name + "`"
return "`" + escape.ReplaceAllString(name, "\\$1") + "`"
}

// FormatTableName formats a table name, typically by lower or upper casing it, depending on the database
Expand All @@ -31,11 +37,25 @@ var identityFn = func(s string) string { return s }

// NormaliseIdentifier normalises identifier parts that are unquoted, typically by lower or upper casing them, depending on the database
func (d dialect) NormaliseIdentifier(identifier string) string {
return base.NormaliseIdentifier(identifier, '`', identityFn)
return escapeSpecial(base.NormaliseIdentifier(unescapeSpecial(identifier), '`', identityFn))
}

// ParseRelationRef parses a string into a RelationRef after normalising the identifier and stripping out surrounding quotes.
// The result is a RelationRef with case-sensitive fields, i.e. it can be safely quoted (see [QuoteTable] and, for instance, used for matching against the database's information schema.
func (d dialect) ParseRelationRef(identifier string) (sqlconnect.RelationRef, error) {
return base.ParseRelationRef(identifier, '`', identityFn)
return base.ParseRelationRef(unescapeSpecial(identifier), '`', identityFn)
}

// unescapeSpecial unescapes special characters in an identifier and replaces escaped backticks with a double backtick
func unescapeSpecial(identifier string) string {
identifier = strings.ReplaceAll(identifier, "\\`", "``")
return unescape.ReplaceAllString(identifier, "$1")
}

// escapeSpecial escapes special characters in an identifier and replaces double backticks with an escaped backtick
func escapeSpecial(identifier string) string {
identifier = strings.ReplaceAll(identifier, "``", "\\`")
identifier = strings.ReplaceAll(identifier, "'", "\\'")
identifier = strings.ReplaceAll(identifier, "\"", "\\\"")
return identifier
}
10 changes: 6 additions & 4 deletions sqlconnect/internal/bigquery/dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ func TestDialect(t *testing.T) {
t.Run("quote identifier", func(t *testing.T) {
quoted := d.QuoteIdentifier("column")
require.Equal(t, "`column`", quoted, "column name should be quoted with backticks")

require.Equal(t, "`col\\`umn`", d.QuoteIdentifier("col`umn"), "column name with backtick should be escaped")
})

t.Run("quote table", func(t *testing.T) {
Expand All @@ -41,8 +43,8 @@ func TestDialect(t *testing.T) {
normalised = d.NormaliseIdentifier("TaBle.`ColUmn`")
require.Equal(t, "TaBle.`ColUmn`", normalised, "non quoted parts should be normalised")

normalised = d.NormaliseIdentifier("`Sh``EmA`.TABLE.`ColUmn`")
require.Equal(t, "`Sh``EmA`.TABLE.`ColUmn`", normalised, "non quoted parts should be normalised")
normalised = d.NormaliseIdentifier("`Sh\\`EmA`.TABLE.`Co\\'lUmn`")
require.Equal(t, "`Sh\\`EmA`.TABLE.`Co\\'lUmn`", normalised, "non quoted parts should be normalised")
})

t.Run("parse relation", func(t *testing.T) {
Expand All @@ -62,8 +64,8 @@ func TestDialect(t *testing.T) {
require.NoError(t, err)
require.Equal(t, sqlconnect.RelationRef{Schema: "ScHeMA", Name: "TaBle"}, parsed)

parsed, err = d.ParseRelationRef("`CaTa``LoG`.ScHeMA.`TaBle`")
parsed, err = d.ParseRelationRef("`CaTa``LoG`.ScHeMA.`TaB\\`\\\"\\'le`")
require.NoError(t, err)
require.Equal(t, sqlconnect.RelationRef{Catalog: "CaTa`LoG", Schema: "ScHeMA", Name: "TaBle"}, parsed)
require.Equal(t, sqlconnect.RelationRef{Catalog: "CaTa`LoG", Schema: "ScHeMA", Name: "TaB`\"'le"}, parsed)
})
}
3 changes: 2 additions & 1 deletion sqlconnect/internal/bigquery/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ func TestBigqueryDB(t *testing.T) {
[]byte(configJSON),
strings.ToLower,
integrationtest.Options{
LegacySupport: true,
LegacySupport: true,
SpecialCharactersInQuotedTable: "-",
},
)
}
8 changes: 5 additions & 3 deletions sqlconnect/internal/databricks/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,16 @@ func NewDB(configJson json.RawMessage) (*DB, error) {
return "SELECT current_catalog()"
}
cmds.ListSchemas = func() (string, string) { return "SHOW SCHEMAS", "schema_name" }
cmds.SchemaExists = func(schema base.UnquotedIdentifier) string { return fmt.Sprintf(`SHOW SCHEMAS LIKE '%s'`, schema) }
cmds.SchemaExists = func(schema base.UnquotedIdentifier) string {
return fmt.Sprintf(`SHOW SCHEMAS LIKE '%s'`, base.EscapeSqlString(schema))
}

cmds.CreateTestTable = func(table base.QuotedIdentifier) string {
return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %[1]s (c1 INT, c2 STRING)", table)
}
cmds.ListTables = func(schema base.UnquotedIdentifier) []lo.Tuple2[string, string] {
return []lo.Tuple2[string, string]{
{A: fmt.Sprintf("SHOW TABLES IN `%s`", schema), B: "tableName"},
{A: fmt.Sprintf("SHOW TABLES IN `%s`", base.EscapeSqlString(schema)), B: "tableName"},
}
}
cmds.ListTablesWithPrefix = func(schema base.UnquotedIdentifier, prefix string) []lo.Tuple2[string, string] {
Expand All @@ -93,7 +95,7 @@ func NewDB(configJson json.RawMessage) (*DB, error) {
}
}
cmds.TableExists = func(schema, table base.UnquotedIdentifier) string {
return fmt.Sprintf("SHOW TABLES IN `%[1]s` LIKE '%[2]s'", schema, table)
return fmt.Sprintf("SHOW TABLES IN `%[1]s` LIKE '%[2]s'", schema, base.EscapeSqlString(table))
}
cmds.ListColumns = func(catalog, schema, table base.UnquotedIdentifier) (string, string, string) {
if catalog == "" || !informationSchema {
Expand Down
12 changes: 8 additions & 4 deletions sqlconnect/internal/databricks/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,25 @@ func (d dialect) QuoteTable(table sqlconnect.RelationRef) string {

// QuoteIdentifier quotes an identifier, e.g. a column name
func (d dialect) QuoteIdentifier(name string) string {
return "`" + name + "`"
return "`" + strings.ReplaceAll(name, "`", "``") + "`"
}

// FormatTableName formats a table name, typically by lower or upper casing it, depending on the database
func (d dialect) FormatTableName(name string) string {
return strings.ToLower(name)
}

// NormaliseIdentifier normalises identifier parts that are unquoted, typically by lower or upper casing them, depending on the database
// NormaliseIdentifier normalises all identifier parts by lower casing them.
func (d dialect) NormaliseIdentifier(identifier string) string {
return base.NormaliseIdentifier(identifier, '`', strings.ToLower)
// Identifiers are case-insensitive
// https://docs.databricks.com/en/sql/language-manual/sql-ref-identifiers.html#:~:text=Identifiers%20are%20case%2Dinsensitive
// Unity Catalog stores all object names as lowercase
// https://docs.databricks.com/en/sql/language-manual/sql-ref-names.html#:~:text=Unity%20Catalog%20stores%20all%20object%20names%20as%20lowercase
return strings.ToLower(identifier)
}

// ParseRelationRef parses a string into a RelationRef after normalising the identifier and stripping out surrounding quotes.
// The result is a RelationRef with case-sensitive fields, i.e. it can be safely quoted (see [QuoteTable] and, for instance, used for matching against the database's information schema.
func (d dialect) ParseRelationRef(identifier string) (sqlconnect.RelationRef, error) {
return base.ParseRelationRef(identifier, '`', strings.ToLower)
return base.ParseRelationRef(strings.ToLower(identifier), '`', strings.ToLower)
}
12 changes: 6 additions & 6 deletions sqlconnect/internal/databricks/dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ func TestDialect(t *testing.T) {
require.Equal(t, "column", normalised, "column name should be normalised to lowercase")

normalised = d.NormaliseIdentifier("`ColUmn`")
require.Equal(t, "`ColUmn`", normalised, "quoted column name should not be normalised")
require.Equal(t, "`column`", normalised, "quoted column name should be normalised to lowercase")

normalised = d.NormaliseIdentifier("TaBle.`ColUmn`")
require.Equal(t, "table.`ColUmn`", normalised, "non quoted parts should be normalised")
require.Equal(t, "table.`column`", normalised, "all parts should be normalised")

normalised = d.NormaliseIdentifier("`Sh``EmA`.TABLE.`ColUmn`")
require.Equal(t, "`Sh``EmA`.table.`ColUmn`", normalised, "non quoted parts should be normalised")
require.Equal(t, "`sh``ema`.table.`column`", normalised, "all parts should be normalised to lowercase")
})

t.Run("parse relation", func(t *testing.T) {
Expand All @@ -56,14 +56,14 @@ func TestDialect(t *testing.T) {

parsed, err = d.ParseRelationRef("`TaBle`")
require.NoError(t, err)
require.Equal(t, sqlconnect.RelationRef{Name: "TaBle"}, parsed)
require.Equal(t, sqlconnect.RelationRef{Name: "table"}, parsed)

parsed, err = d.ParseRelationRef("ScHeMA.`TaBle`")
require.NoError(t, err)
require.Equal(t, sqlconnect.RelationRef{Schema: "schema", Name: "TaBle"}, parsed)
require.Equal(t, sqlconnect.RelationRef{Schema: "schema", Name: "table"}, parsed)

parsed, err = d.ParseRelationRef("`CaTa``LoG`.ScHeMA.`TaBle`")
require.NoError(t, err)
require.Equal(t, sqlconnect.RelationRef{Catalog: "CaTa`LoG", Schema: "schema", Name: "TaBle"}, parsed)
require.Equal(t, sqlconnect.RelationRef{Catalog: "cata`log", Schema: "schema", Name: "table"}, parsed)
})
}
6 changes: 4 additions & 2 deletions sqlconnect/internal/databricks/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ func TestDatabricksDB(t *testing.T) {
[]byte(configJSON),
strings.ToLower,
integrationtest.Options{
LegacySupport: true,
LegacySupport: true,
SpecialCharactersInQuotedTable: "`-",
},
)
})
Expand All @@ -63,7 +64,8 @@ func TestDatabricksDB(t *testing.T) {
[]byte(configJSON),
strings.ToLower,
integrationtest.Options{
LegacySupport: true,
LegacySupport: true,
SpecialCharactersInQuotedTable: "_A", // No special characters allowed
},
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ type Options struct {

IncludesViewsInListTables bool

SpecialCharactersInQuotedTable string // special characters to test in quoted table identifiers (default: <space>,",',`")

ExtraTests func(t *testing.T, db sqlconnect.DB)
}

Expand Down Expand Up @@ -144,17 +146,58 @@ func TestDatabaseScenarios(t *testing.T, warehouse string, configJSON json.RawMe
})

t.Run("dialect", func(t *testing.T) {
// Create an unquoted table
unquotedTable := "UnQuoted_TablE"
identifier := db.QuoteIdentifier(schema.Name) + "." + unquotedTable
_, err := db.Exec("CREATE TABLE " + identifier + " (c1 int)")
require.NoError(t, err, "it should be able to create an unquoted table")

table, err := db.ParseRelationRef(identifier)
require.NoError(t, err, "it should be able to parse an unquoted table")
exists, err := db.TableExists(ctx, table)
require.NoError(t, err, "it should be able to check if a table exists")
require.True(t, exists, "it should return true for a table that exists")
t.Run("with unquoted table", func(t *testing.T) {
identifier := db.QuoteIdentifier(schema.Name) + "." + "UnQuoted_TablE"
_, err := db.Exec("CREATE TABLE " + identifier + " (c1 int)")
require.NoError(t, err, "it should be able to create an unquoted table")

table, err := db.ParseRelationRef(identifier)
require.NoError(t, err, "it should be able to parse an unquoted table")

alltables, err := db.ListTables(ctx, schema)
require.NoError(t, err, "it should be able to list tables")

exists, err := db.TableExists(ctx, table)
require.NoErrorf(t, err, "it should be able to check if a table exists: %s allTables: %+v", table, alltables)
require.Truef(t, exists, "it should return true for a table that exists: %s allTables: %+v", table, alltables)
})

t.Run("with quoted table", func(t *testing.T) {
identifier := db.QuoteIdentifier(schema.Name) + "." + db.QuoteIdentifier("Quoted_TablE")
_, err := db.Exec("CREATE TABLE " + identifier + " (c1 int)")
require.NoErrorf(t, err, "it should be able to create a quoted table: %s", identifier)

table, err := db.ParseRelationRef(identifier)
require.NoError(t, err, "it should be able to parse a quoted table")

alltables, err := db.ListTables(ctx, schema)
require.NoError(t, err, "it should be able to list tables")

exists, err := db.TableExists(ctx, table)
require.NoErrorf(t, err, "it should be able to check if a table exists: %s allTables: %+v", table, alltables)
require.Truef(t, exists, "it should return true for a table that exists: %s allTables: %+v", table, alltables)
})

t.Run("with quoted table and special characters", func(t *testing.T) {
specialCharacters := " \"`'"
if len(opts.SpecialCharactersInQuotedTable) > 0 {
specialCharacters = opts.SpecialCharactersInQuotedTable
}

identifier := db.QuoteIdentifier(schema.Name) + "." + db.QuoteIdentifier("Quoted_TablE"+specialCharacters)
_, err := db.Exec("CREATE TABLE " + identifier + " (c1 int)")
require.NoErrorf(t, err, "it should be able to create a quoted table: %s", identifier)

table, err := db.ParseRelationRef(identifier)
require.NoError(t, err, "it should be able to parse a quoted table")

alltables, err := db.ListTables(ctx, schema)
require.NoError(t, err, "it should be able to list tables")

exists, err := db.TableExists(ctx, table)
require.NoErrorf(t, err, "it should be able to check if a table exists: %s allTables: %+v", table, alltables)
require.Truef(t, exists, "it should return true for a table that exists: %s allTables: %+v", table, alltables)
})
})

t.Run("table admin", func(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion sqlconnect/internal/mysql/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func (d dialect) QuoteTable(table sqlconnect.RelationRef) string {

// QuoteIdentifier quotes an identifier, e.g. a column name
func (d dialect) QuoteIdentifier(name string) string {
return "`" + name + "`"
return "`" + strings.ReplaceAll(name, "`", "``") + "`"
}

// FormatTableName formats a table name, typically by lower or upper casing it, depending on the database
Expand Down
Loading

0 comments on commit 7a11657

Please sign in to comment.