From eccbbffdc30c65853c35f648fe3d442971049d9a Mon Sep 17 00:00:00 2001 From: Aris Tzoumas Date: Fri, 16 Feb 2024 11:16:01 +0200 Subject: [PATCH] fixup! feat: sqlconnect library --- .github/workflows/test.yaml | 2 +- Makefile | 35 +--- sqlconnect/async.go | 1 + sqlconnect/{ref_column.go => columnref.go} | 0 sqlconnect/db.go | 2 + sqlconnect/internal/base/db.go | 35 +++- sqlconnect/internal/base/dbopts.go | 17 +- sqlconnect/internal/base/dialect_test.go | 30 ++++ sqlconnect/internal/base/mapper.go | 6 +- sqlconnect/internal/base/schemaadmin.go | 16 +- sqlconnect/internal/base/tableadmin.go | 30 ++-- sqlconnect/internal/bigquery/db.go | 31 +++- sqlconnect/internal/bigquery/dialect_test.go | 30 ++++ .../internal/bigquery/driver/columns.go | 41 ++--- sqlconnect/internal/bigquery/driver/rows.go | 5 + .../internal/bigquery/integration_test.go | 18 ++ sqlconnect/internal/bigquery/mappings.go | 85 ++++++--- sqlconnect/internal/bigquery/schemaadmin.go | 57 ++++++ .../testdata/column-mapping-test-columns.json | 74 ++++++++ .../testdata/column-mapping-test-rows.json | 62 +++++++ .../testdata/column-mapping-test-seed.sql | 27 +++ sqlconnect/internal/databricks/config.go | 1 + sqlconnect/internal/databricks/db.go | 17 +- .../internal/databricks/dialect_test.go | 30 ++++ .../internal/databricks/integration_test.go | 30 ++++ .../db_integration_test_scenario.go | 169 +++++++++++++++++- sqlconnect/internal/mysql/db.go | 4 + sqlconnect/internal/mysql/dialect_test.go | 30 ++++ sqlconnect/internal/mysql/integration_test.go | 3 +- sqlconnect/internal/mysql/mappings.go | 6 + .../internal/postgres/integration_test.go | 3 +- sqlconnect/internal/redshift/db.go | 5 +- .../internal/redshift/integration_test.go | 19 ++ sqlconnect/internal/snowflake/db.go | 11 +- sqlconnect/internal/snowflake/dialect_test.go | 30 ++++ .../internal/snowflake/integration_test.go | 19 ++ sqlconnect/internal/snowflake/mappings.go | 21 ++- sqlconnect/internal/trino/db.go | 8 +- sqlconnect/internal/trino/integration_test.go | 19 ++ sqlconnect/internal/trino/mappings.go | 16 +- sqlconnect/internal/util/validatehost.go | 1 + sqlconnect/internal/util/validatehost_test.go | 26 +++ sqlconnect/{def_query.go => querydef.go} | 0 sqlconnect/querydef_test.go | 64 +++++++ .../{ref_relation.go => relationref.go} | 0 ...ef_relationopts.go => relationref_opts.go} | 0 sqlconnect/relationref_test.go | 60 +++++++ sqlconnect/{ref_schema.go => schemaref.go} | 0 sqlconnect/schemaref_test.go | 16 ++ 49 files changed, 1081 insertions(+), 131 deletions(-) rename sqlconnect/{ref_column.go => columnref.go} (100%) create mode 100644 sqlconnect/internal/base/dialect_test.go create mode 100644 sqlconnect/internal/bigquery/dialect_test.go create mode 100644 sqlconnect/internal/bigquery/integration_test.go create mode 100644 sqlconnect/internal/bigquery/schemaadmin.go create mode 100644 sqlconnect/internal/bigquery/testdata/column-mapping-test-columns.json create mode 100644 sqlconnect/internal/bigquery/testdata/column-mapping-test-rows.json create mode 100644 sqlconnect/internal/bigquery/testdata/column-mapping-test-seed.sql create mode 100644 sqlconnect/internal/databricks/dialect_test.go create mode 100644 sqlconnect/internal/databricks/integration_test.go create mode 100644 sqlconnect/internal/mysql/dialect_test.go create mode 100644 sqlconnect/internal/redshift/integration_test.go create mode 100644 sqlconnect/internal/snowflake/dialect_test.go create mode 100644 sqlconnect/internal/snowflake/integration_test.go create mode 100644 sqlconnect/internal/trino/integration_test.go create mode 100644 sqlconnect/internal/util/validatehost_test.go rename sqlconnect/{def_query.go => querydef.go} (100%) create mode 100644 sqlconnect/querydef_test.go rename sqlconnect/{ref_relation.go => relationref.go} (100%) rename sqlconnect/{ref_relationopts.go => relationref_opts.go} (100%) create mode 100644 sqlconnect/relationref_test.go rename sqlconnect/{ref_schema.go => schemaref.go} (100%) create mode 100644 sqlconnect/schemaref_test.go diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index c1de80c..6236071 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -82,7 +82,7 @@ jobs: run: | go install github.com/wadey/gocovmerge@latest gocovmerge */profile.out > profile.out - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v4 with: fail_ci_if_error: true files: ./profile.out diff --git a/Makefile b/Makefile index 19e20bf..7b74e26 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: help default test test-run test-teardown generate lint fmt +.PHONY: help default test test-run generate lint fmt GO=go LDFLAGS?=-s -w @@ -9,47 +9,30 @@ default: lint generate: install-tools $(GO) generate ./... -test: install-tools test-run test-teardown +test: install-tools test-run test-run: ## Run all unit tests ifeq ($(filter 1,$(debug) $(RUNNER_DEBUG)),) - $(eval TEST_CMD = SLOW=0 gotestsum --format pkgname-and-test-fails --) - $(eval TEST_OPTIONS = -p=1 -v -failfast -shuffle=on -coverprofile=profile.out -covermode=count -coverpkg=./... -vet=all --timeout=15m) + $(eval TEST_CMD = gotestsum --format pkgname-and-test-fails --) + $(eval TEST_OPTIONS = -p=1 -v -failfast -shuffle=on -coverprofile=profile.out -covermode=atomic -coverpkg=./... -vet=all --timeout=30m) else $(eval TEST_CMD = SLOW=0 go test) - $(eval TEST_OPTIONS = -p=1 -v -failfast -shuffle=on -coverprofile=profile.out -covermode=count -coverpkg=./... -vet=all --timeout=15m) + $(eval TEST_OPTIONS = -p=1 -v -failfast -shuffle=on -coverprofile=profile.out -covermode=atomic -coverpkg=./... -vet=all --timeout=30m) endif ifdef package ifdef exclude $(eval FILES = `go list ./$(package)/... | egrep -iv '$(exclude)'`) - $(TEST_CMD) -count=1 $(TEST_OPTIONS) $(FILES) && touch $(TESTFILE) || true + $(TEST_CMD) -count=1 $(TEST_OPTIONS) $(FILES) && touch $(TESTFILE) else - $(TEST_CMD) $(TEST_OPTIONS) ./$(package)/... && touch $(TESTFILE) || true + $(TEST_CMD) $(TEST_OPTIONS) ./$(package)/... && touch $(TESTFILE) endif else ifdef exclude $(eval FILES = `go list ./... | egrep -iv '$(exclude)'`) - $(TEST_CMD) -count=1 $(TEST_OPTIONS) $(FILES) && touch $(TESTFILE) || true + $(TEST_CMD) -count=1 $(TEST_OPTIONS) $(FILES) && touch $(TESTFILE) else - $(TEST_CMD) -count=1 $(TEST_OPTIONS) ./... && touch $(TESTFILE) || true + $(TEST_CMD) -count=1 $(TEST_OPTIONS) ./... && touch $(TESTFILE) endif -test-teardown: - @if [ -f "$(TESTFILE)" ]; then \ - echo "Tests passed, tearing down..." ;\ - rm -f $(TESTFILE) ;\ - echo "mode: atomic" > coverage.txt ;\ - find . -name "profile.out" | while read file; do grep -v 'mode: atomic' $${file} >> coverage.txt; rm -f $${file}; done ;\ - else \ - rm -f coverage.txt coverage.html ; find . -name "profile.out" | xargs rm -f ;\ - echo "Tests failed :-(" ;\ - exit 1 ;\ - fi - -coverage: - go tool cover -html=coverage.txt -o coverage.html - -test-with-coverage: test coverage - help: ## Show the available commands @grep -E '^[0-9a-zA-Z_-]+:.*?## .*$$' ./Makefile | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' diff --git a/sqlconnect/async.go b/sqlconnect/async.go index 2f3270d..51187c9 100644 --- a/sqlconnect/async.go +++ b/sqlconnect/async.go @@ -21,6 +21,7 @@ func QueryAsync[T any](ctx context.Context, db DB, mapper RowMapper[T], query st s := &async.Sender[ValueOrError[T]]{} ctx, ch, leave = s.Begin(ctx) go func() { + defer s.Close() rows, err := db.QueryContext(ctx, query, params...) if err != nil { s.Send(ValueOrError[T]{Err: fmt.Errorf("executing query: %w", err)}) diff --git a/sqlconnect/ref_column.go b/sqlconnect/columnref.go similarity index 100% rename from sqlconnect/ref_column.go rename to sqlconnect/columnref.go diff --git a/sqlconnect/db.go b/sqlconnect/db.go index adf6f38..3aa7623 100644 --- a/sqlconnect/db.go +++ b/sqlconnect/db.go @@ -90,3 +90,5 @@ type Dialect interface { // FormatTableName formats a table name, typically by lower or upper casing it, depending on the database FormatTableName(name string) string } + +// var ErrNotSupported = errors.New("sqlconnect: operation not supported") diff --git a/sqlconnect/internal/base/db.go b/sqlconnect/internal/base/db.go index 04542d2..9e25697 100644 --- a/sqlconnect/internal/base/db.go +++ b/sqlconnect/internal/base/db.go @@ -13,8 +13,8 @@ func NewDB(db *sql.DB, rudderSchema string, opts ...Option) *DB { d := &DB{ DB: db, Dialect: dialect{}, - columnTypeMapper: func(databaseTypeName string) string { - return databaseTypeName + columnTypeMapper: func(c ColumnType) string { + return c.DatabaseTypeName() }, jsonRowMapper: func(databaseTypeName string, value any) any { return value @@ -26,7 +26,7 @@ func NewDB(db *sql.DB, rudderSchema string, opts ...Option) *DB { return "SELECT schema_name FROM information_schema.schemata", "schema_name" }, SchemaExists: func(schema string) string { - return fmt.Sprintf("SELECT EXISTS (SELECT schema_name FROM information_schema.schemata where schema_name = '%[1]s')", schema) + return fmt.Sprintf("SELECT schema_name FROM information_schema.schemata where schema_name = '%[1]s'", schema) }, DropSchema: func(schema string) string { return fmt.Sprintf("DROP SCHEMA %[1]s CASCADE", schema) }, CreateTestTable: func(table string) string { @@ -34,7 +34,7 @@ func NewDB(db *sql.DB, rudderSchema string, opts ...Option) *DB { }, ListTables: func(schema string) []lo.Tuple2[string, string] { return []lo.Tuple2[string, string]{ - {A: fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema = %[1]s", schema), B: "table_name"}, + {A: fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema = '%[1]s'", schema), B: "table_name"}, } }, ListTablesWithPrefix: func(schema, prefix string) []lo.Tuple2[string, string] { @@ -43,7 +43,7 @@ func NewDB(db *sql.DB, rudderSchema string, opts ...Option) *DB { } }, TableExists: func(schema, table string) string { - return fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema='%[1]s' and table_name = '%[1]s'", schema, table) + return fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema='%[1]s' and table_name = '%[2]s'", schema, table) }, ListColumns: func(schema, table string) (string, string, string) { return fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = '%[1]s' AND table_name = '%[2]s'", schema, table), "column_name", "data_type" @@ -51,8 +51,8 @@ func NewDB(db *sql.DB, rudderSchema string, opts ...Option) *DB { CountTableRows: func(table string) string { return fmt.Sprintf("SELECT COUNT(*) FROM %[1]s", table) }, DropTable: func(table string) string { return fmt.Sprintf("DROP TABLE IF EXISTS %[1]s", table) }, TruncateTable: func(table string) string { return fmt.Sprintf("TRUNCATE TABLE %[1]s", table) }, - RenameTable: func(oldName, newName string) string { - return fmt.Sprintf("ALTER TABLE %[1]s RENAME TO %[2]s", oldName, newName) + RenameTable: func(schema, oldName, newName string) string { + return fmt.Sprintf("ALTER TABLE %[1]s.%[2]s RENAME TO %[3]s", schema, oldName, newName) }, }, } @@ -67,11 +67,28 @@ type DB struct { sqlconnect.Dialect rudderSchema string - columnTypeMapper func(string) string // map from database type to rudder type + columnTypeMapper func(ColumnType) string // map from database type to rudder type jsonRowMapper func(databaseTypeName string, value any) any sqlCommands SQLCommands } +type ColumnType interface { + DatabaseTypeName() string + DecimalSize() (precision, scale int64, ok bool) +} + +type colRefTypeAdapter struct { + sqlconnect.ColumnRef +} + +func (c colRefTypeAdapter) DatabaseTypeName() string { + return c.Type +} + +func (c colRefTypeAdapter) DecimalSize() (precision, scale int64, ok bool) { + return 0, 0, false +} + // SqlDB returns the underlying *sql.DB func (db *DB) SqlDB() *sql.DB { return db.DB @@ -103,5 +120,5 @@ type SQLCommands struct { // Provides the SQL command to truncate a table TruncateTable func(table string) string // Provides the SQL command to rename a table - RenameTable func(oldName, newName string) string + RenameTable func(schema, oldName, newName string) string } diff --git a/sqlconnect/internal/base/dbopts.go b/sqlconnect/internal/base/dbopts.go index 48d63cf..3614b36 100644 --- a/sqlconnect/internal/base/dbopts.go +++ b/sqlconnect/internal/base/dbopts.go @@ -1,23 +1,30 @@ package base -import "github.com/rudderlabs/sqlconnect-go/sqlconnect" +import ( + "strings" + + "github.com/rudderlabs/sqlconnect-go/sqlconnect" +) type Option func(*DB) // WithColumnTypeMappings sets the column type mappings for the client func WithColumnTypeMappings(columnTypeMappings map[string]string) Option { return func(db *DB) { - db.columnTypeMapper = func(dbType string) string { - if mappedType, ok := columnTypeMappings[dbType]; ok { + db.columnTypeMapper = func(c ColumnType) string { + if mappedType, ok := columnTypeMappings[strings.ToLower(c.DatabaseTypeName())]; ok { + return mappedType + } + if mappedType, ok := columnTypeMappings[strings.ToUpper(c.DatabaseTypeName())]; ok { return mappedType } - return dbType + return c.DatabaseTypeName() } } } // WithColumnTypeMapper sets the column type mapper for the client -func WithColumnTypeMapper(columnTypeMapper func(string) string) Option { +func WithColumnTypeMapper(columnTypeMapper func(ColumnType) string) Option { return func(db *DB) { db.columnTypeMapper = columnTypeMapper } diff --git a/sqlconnect/internal/base/dialect_test.go b/sqlconnect/internal/base/dialect_test.go new file mode 100644 index 0000000..c9a97af --- /dev/null +++ b/sqlconnect/internal/base/dialect_test.go @@ -0,0 +1,30 @@ +package base + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/sqlconnect-go/sqlconnect" +) + +func TestDialect(t *testing.T) { + var d dialect + t.Run("format table", func(t *testing.T) { + formatted := d.FormatTableName("TaBle") + require.Equal(t, "table", formatted, "table name should be lowercased") + }) + + t.Run("quote identifier", func(t *testing.T) { + quoted := d.QuoteIdentifier("column") + require.Equal(t, `"column"`, quoted, "column name should be quoted with double quotes") + }) + + t.Run("quote table", func(t *testing.T) { + quoted := d.QuoteTable(sqlconnect.NewRelationRef("table")) + require.Equal(t, `"table"`, quoted, "table name should be quoted with double quotes") + + quoted = d.QuoteTable(sqlconnect.NewRelationRef("table", sqlconnect.WithSchema("schema"))) + require.Equal(t, `"schema"."table"`, quoted, "schema and table name should be quoted with double quotes") + }) +} diff --git a/sqlconnect/internal/base/mapper.go b/sqlconnect/internal/base/mapper.go index 6c03360..d88f139 100644 --- a/sqlconnect/internal/base/mapper.go +++ b/sqlconnect/internal/base/mapper.go @@ -22,8 +22,12 @@ func (db *DB) JSONRowMapper() sqlconnect.RowMapper[json.RawMessage] { o := map[string]any{} for i := range values { v := values[i].(*NilAny) + var val any + if v != nil { + val = v.Value + } col := cols[i] - o[col.Name()] = db.jsonRowMapper(col.DatabaseTypeName(), v) + o[col.Name()] = db.jsonRowMapper(col.DatabaseTypeName(), val) } b, err := json.Marshal(o) if err != nil { diff --git a/sqlconnect/internal/base/schemaadmin.go b/sqlconnect/internal/base/schemaadmin.go index 91432df..33ac3b2 100644 --- a/sqlconnect/internal/base/schemaadmin.go +++ b/sqlconnect/internal/base/schemaadmin.go @@ -3,6 +3,7 @@ package base import ( "context" "fmt" + "strings" "github.com/samber/lo" @@ -36,20 +37,22 @@ func (db *DB) ListSchemas(ctx context.Context) ([]sqlconnect.SchemaRef, error) { if err != nil { return nil, fmt.Errorf("getting columns in list schemas: %w", err) } + cols = lo.Map(cols, func(col string, _ int) string { return strings.ToLower(col) }) var schema sqlconnect.SchemaRef scanValues := make([]any, len(cols)) if len(cols) == 1 { scanValues[0] = &schema.Name } else { - tableNameColIdx := lo.IndexOf(cols, colName) + tableNameColIdx := lo.IndexOf(cols, strings.ToLower(colName)) if tableNameColIdx == -1 { return nil, fmt.Errorf("column %s not found in result set: %+v", colName, cols) } + var otherCol NilAny for i := 0; i < len(cols); i++ { if i == tableNameColIdx { scanValues[i] = &schema.Name } else { - scanValues[i] = new(NilAny) + scanValues[i] = &otherCol } } } @@ -68,10 +71,15 @@ func (db *DB) ListSchemas(ctx context.Context) ([]sqlconnect.SchemaRef, error) { // SchemaExists returns true if the schema exists func (db *DB) SchemaExists(ctx context.Context, schemaRef sqlconnect.SchemaRef) (bool, error) { - var exists bool - if err := db.QueryRowContext(ctx, db.sqlCommands.SchemaExists(schemaRef.Name)).Scan(&exists); err != nil { + rows, err := db.QueryContext(ctx, db.sqlCommands.SchemaExists(schemaRef.Name)) + if err != nil { return false, fmt.Errorf("querying schema exists: %w", err) } + defer func() { _ = rows.Close() }() + exists := rows.Next() + if err := rows.Err(); err != nil { + return false, fmt.Errorf("iterating schema exists: %w", err) + } return exists, nil } diff --git a/sqlconnect/internal/base/tableadmin.go b/sqlconnect/internal/base/tableadmin.go index cc17fb7..153ded5 100644 --- a/sqlconnect/internal/base/tableadmin.go +++ b/sqlconnect/internal/base/tableadmin.go @@ -3,6 +3,7 @@ package base import ( "context" "fmt" + "strings" "github.com/samber/lo" @@ -30,20 +31,22 @@ func (db *DB) ListTables(ctx context.Context, schema sqlconnect.SchemaRef) ([]sq if err != nil { return nil, fmt.Errorf("getting columns in list tables for schema %s: %w", schema, err) } + cols = lo.Map(cols, func(col string, _ int) string { return strings.ToLower(col) }) var name string scanValues := make([]any, len(cols)) if len(cols) == 1 { scanValues[0] = &name } else { - tableNameColIdx := lo.IndexOf(cols, colName) + tableNameColIdx := lo.IndexOf(cols, strings.ToLower(colName)) if tableNameColIdx == -1 { return nil, fmt.Errorf("column %s not found in result set: %+v", colName, cols) } + var otherCol NilAny for i := 0; i < len(cols); i++ { if i == tableNameColIdx { scanValues[i] = &name } else { - scanValues[i] = new(NilAny) + scanValues[i] = &otherCol } } } @@ -77,20 +80,22 @@ func (db *DB) ListTablesWithPrefix(ctx context.Context, schema sqlconnect.Schema if err != nil { return nil, fmt.Errorf("getting columns in list tables for schema %s with prefix %s: %w", schema, prefix, err) } + cols = lo.Map(cols, func(col string, _ int) string { return strings.ToLower(col) }) var name string scanValues := make([]any, len(cols)) if len(cols) == 1 { scanValues[0] = &name } else { - tableNameColIdx := lo.IndexOf(cols, colName) + tableNameColIdx := lo.IndexOf(cols, strings.ToLower(colName)) if tableNameColIdx == -1 { return nil, fmt.Errorf("column %s not found in result set: %+v", colName, cols) } + var otherCol NilAny for i := 0; i < len(cols); i++ { if i == tableNameColIdx { scanValues[i] = &name } else { - scanValues[i] = new(NilAny) + scanValues[i] = &otherCol } } } @@ -110,7 +115,8 @@ func (db *DB) ListTablesWithPrefix(ctx context.Context, schema sqlconnect.Schema // TableExists returns true if the table exists func (db *DB) TableExists(ctx context.Context, relation sqlconnect.RelationRef) (bool, error) { - rows, err := db.QueryContext(ctx, db.sqlCommands.TableExists(relation.Schema, relation.Name)) + stmt := db.sqlCommands.TableExists(relation.Schema, relation.Name) + rows, err := db.QueryContext(ctx, stmt) if err != nil { return false, fmt.Errorf("querying table %s exists: %w", relation, err) } @@ -137,24 +143,26 @@ func (db *DB) ListColumns(ctx context.Context, relation sqlconnect.RelationRef) if err != nil { return nil, fmt.Errorf("getting columns in list columns for %s: %w", relation.String(), err) } + cols = lo.Map(cols, func(col string, _ int) string { return strings.ToLower(col) }) var column sqlconnect.ColumnRef scanValues := make([]any, len(cols)) - nameColIdx := lo.IndexOf(cols, nameCol) + nameColIdx := lo.IndexOf(cols, strings.ToLower(nameCol)) if nameColIdx == -1 { return nil, fmt.Errorf("column %s not found in result set: %+v", nameCol, cols) } - typeColIdx := lo.IndexOf(cols, typeCol) + typeColIdx := lo.IndexOf(cols, strings.ToLower(typeCol)) if typeColIdx == -1 { return nil, fmt.Errorf("column %s not found in result set: %+v", typeCol, cols) } + var otherCol NilAny for i := 0; i < len(cols); i++ { if i == nameColIdx { scanValues[i] = &column.Name } else if i == typeColIdx { scanValues[i] = &column.Type } else { - scanValues[i] = new(NilAny) + scanValues[i] = &otherCol } } @@ -162,7 +170,7 @@ func (db *DB) ListColumns(ctx context.Context, relation sqlconnect.RelationRef) if err := columns.Scan(scanValues...); err != nil { return nil, fmt.Errorf("scanning list columns for %s: %w", relation.String(), err) } - column.Type = db.columnTypeMapper(column.Type) + column.Type = db.columnTypeMapper(colRefTypeAdapter{column}) res = append(res, column) } @@ -188,7 +196,7 @@ func (db *DB) ListColumnsForSqlQuery(ctx context.Context, sql string) ([]sqlconn for _, col := range colTypes { res = append(res, sqlconnect.ColumnRef{ Name: col.Name(), - Type: db.columnTypeMapper(col.DatabaseTypeName()), + Type: db.columnTypeMapper(col), }) } return res, nil @@ -224,7 +232,7 @@ func (db *DB) RenameTable(ctx context.Context, oldRef, newRef sqlconnect.Relatio if oldRef.Schema != newRef.Schema { return fmt.Errorf("moving table to another schema not supported, oldRef: %s newRef: %s", oldRef, newRef) } - if _, err := db.ExecContext(ctx, db.sqlCommands.RenameTable(db.QuoteTable(oldRef), db.QuoteIdentifier(newRef.Name))); err != nil { + if _, err := db.ExecContext(ctx, db.sqlCommands.RenameTable(db.QuoteIdentifier(oldRef.Schema), db.QuoteIdentifier(oldRef.Name), db.QuoteIdentifier(newRef.Name))); err != nil { return fmt.Errorf("renaming table %s to %s: %w", oldRef.String(), newRef.String(), err) } return nil diff --git a/sqlconnect/internal/bigquery/db.go b/sqlconnect/internal/bigquery/db.go index 2d585b5..c148b47 100644 --- a/sqlconnect/internal/bigquery/db.go +++ b/sqlconnect/internal/bigquery/db.go @@ -1,10 +1,12 @@ package bigquery import ( + "context" "database/sql" "encoding/json" "fmt" + "cloud.google.com/go/bigquery" "github.com/samber/lo" "google.golang.org/api/option" @@ -32,11 +34,11 @@ func NewDB(configJSON json.RawMessage) (*DB, error) { db, lo.Ternary(config.RudderSchema != "", config.RudderSchema, defaultRudderSchema), base.WithDialect(dialect{}), - base.WithColumnTypeMappings(columnTypeMappings), + base.WithColumnTypeMapper(columnTypeMapper), base.WithJsonRowMapper(jsonRowMapper), base.WithSQLCommandsOverride(func(cmds base.SQLCommands) base.SQLCommands { - cmds.ListColumns = func(schema, table string) (string, string, string) { - return fmt.Sprintf("SELECT column_name, data_type FROM `%[1]s`.INFORMATION_SCHEMA.COLUMNS WHERE table_name = '%[2]s'", schema, table), "column_name", "data_type" + cmds.CreateTestTable = func(table string) string { + return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %[1]s (c1 INT, c2 STRING)", table) } cmds.ListTables = func(schema string) []lo.Tuple2[string, string] { return []lo.Tuple2[string, string]{ @@ -49,7 +51,10 @@ func NewDB(configJSON json.RawMessage) (*DB, error) { } } cmds.TableExists = func(schema, table string) string { - return fmt.Sprintf("SELECT EXISTS (SELECT table_name FROM `%[1]s`.INFORMATION_SCHEMA.TABLES WHERE table_name = '%[1]s'", schema, table) + return fmt.Sprintf("SELECT table_name FROM `%[1]s`.INFORMATION_SCHEMA.TABLES WHERE table_name = '%[2]s'", schema, table) + } + cmds.ListColumns = func(schema, table string) (string, string, string) { + return fmt.Sprintf("SELECT column_name, data_type FROM `%[1]s`.INFORMATION_SCHEMA.COLUMNS WHERE table_name = '%[2]s'", schema, table), "column_name", "data_type" } return cmds @@ -67,3 +72,21 @@ func init() { type DB struct { *base.DB } + +func (db *DB) WithBigqueryClient(ctx context.Context, f func(*bigquery.Client) error) error { + sqlconn, err := db.Conn(ctx) + if err != nil { + return err + } + defer func() { _ = sqlconn.Close() }() + return sqlconn.Raw(func(driverConn any) error { + if c, ok := driverConn.(bqclient); ok { + return f(c.BigqueryClient()) + } + return fmt.Errorf("invalid driver connection") + }) +} + +type bqclient interface { + BigqueryClient() *bigquery.Client +} diff --git a/sqlconnect/internal/bigquery/dialect_test.go b/sqlconnect/internal/bigquery/dialect_test.go new file mode 100644 index 0000000..3e245e9 --- /dev/null +++ b/sqlconnect/internal/bigquery/dialect_test.go @@ -0,0 +1,30 @@ +package bigquery + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/sqlconnect-go/sqlconnect" +) + +func TestDialect(t *testing.T) { + var d dialect + t.Run("format table", func(t *testing.T) { + formatted := d.FormatTableName("TaBle") + require.Equal(t, "table", formatted, "table name should be lowercased") + }) + + t.Run("quote identifier", func(t *testing.T) { + quoted := d.QuoteIdentifier("column") + require.Equal(t, "`column`", quoted, "column name should be quoted with backticks") + }) + + t.Run("quote table", func(t *testing.T) { + quoted := d.QuoteTable(sqlconnect.NewRelationRef("table")) + require.Equal(t, "`table`", quoted, "table name should be quoted with backticks") + + quoted = d.QuoteTable(sqlconnect.NewRelationRef("table", sqlconnect.WithSchema("schema"))) + require.Equal(t, "`schema.table`", quoted, "schema and table name should be quoted with backticks") + }) +} diff --git a/sqlconnect/internal/bigquery/driver/columns.go b/sqlconnect/internal/bigquery/driver/columns.go index ee5c31c..d6e8c0c 100644 --- a/sqlconnect/internal/bigquery/driver/columns.go +++ b/sqlconnect/internal/bigquery/driver/columns.go @@ -2,7 +2,6 @@ package driver import ( "database/sql/driver" - "encoding/json" "cloud.google.com/go/bigquery" ) @@ -10,6 +9,7 @@ import ( type bigQuerySchema interface { ColumnNames() []string ConvertColumnValue(index int, value bigquery.Value) (driver.Value, error) + ColumnTypeDatabaseTypeName(index int) string } type bigQueryColumns struct { @@ -30,37 +30,24 @@ func (columns bigQueryColumns) ColumnNames() []string { return columns.names } -type bigQueryReroutedColumn struct { - values []bigquery.Value - schema bigquery.Schema -} +func (columns bigQueryColumns) ColumnTypeDatabaseTypeName(index int) string { + if index > -1 && len(columns.columns) > index { + column := columns.columns[index] + if column.FieldSchema.Repeated { + return "ARRAY" + } + return string(column.FieldSchema.Type) + } -func (c bigQueryReroutedColumn) MarshalJSON() ([]byte, error) { - return json.Marshal(c.values) + return "" } type bigQueryColumn struct { - Name string - Schema bigquery.Schema + Name string + FieldSchema *bigquery.FieldSchema } func (column bigQueryColumn) ConvertValue(value bigquery.Value) (driver.Value, error) { - if len(column.Schema) == 0 { - return value, nil - } - - values, ok := value.([]bigquery.Value) - if ok { - - if len(values) > 0 { - if _, isRows := values[0].([]bigquery.Value); !isRows { - values = []bigquery.Value{values} - } - } - - value = bigQueryReroutedColumn{values: values, schema: column.Schema} - } - return value, nil } @@ -73,8 +60,8 @@ func createBigQuerySchema(schema bigquery.Schema) bigQuerySchema { names = append(names, name) columns = append(columns, bigQueryColumn{ - Name: name, - Schema: column.Schema, + Name: name, + FieldSchema: column, }) } return &bigQueryColumns{ diff --git a/sqlconnect/internal/bigquery/driver/rows.go b/sqlconnect/internal/bigquery/driver/rows.go index 7f825ff..a50d7c4 100644 --- a/sqlconnect/internal/bigquery/driver/rows.go +++ b/sqlconnect/internal/bigquery/driver/rows.go @@ -51,3 +51,8 @@ func (rows *bigQueryRows) Next(dest []driver.Value) error { return nil } + +func (rows *bigQueryRows) ColumnTypeDatabaseTypeName(index int) string { + rows.ensureSchema() + return rows.schema.ColumnTypeDatabaseTypeName(index) +} diff --git a/sqlconnect/internal/bigquery/integration_test.go b/sqlconnect/internal/bigquery/integration_test.go new file mode 100644 index 0000000..5ec2ada --- /dev/null +++ b/sqlconnect/internal/bigquery/integration_test.go @@ -0,0 +1,18 @@ +package bigquery_test + +import ( + "os" + "strings" + "testing" + + "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/bigquery" + integrationtest "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/integration_test" +) + +func TestBigqueryDB(t *testing.T) { + configJSON, ok := os.LookupEnv("BIGQUERY_TEST_ENVIRONMENT_CREDENTIALS") + if !ok { + t.Skip("skipping bigquery integration test due to lack of a test environment") + } + integrationtest.TestDatabaseScenarios(t, bigquery.DatabaseType, []byte(configJSON), strings.ToLower) +} diff --git a/sqlconnect/internal/bigquery/mappings.go b/sqlconnect/internal/bigquery/mappings.go index bb3b8da..abe6bc7 100644 --- a/sqlconnect/internal/bigquery/mappings.go +++ b/sqlconnect/internal/bigquery/mappings.go @@ -1,33 +1,66 @@ package bigquery -import "math/big" +import ( + "encoding/json" + "math/big" + "regexp" + "strings" + + "cloud.google.com/go/bigquery" + "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/base" +) // mapping of database column types to rudder types var columnTypeMappings = map[string]string{ - "BOOLEAN": "boolean", - "BOOL": "boolean", - "INTEGER": "int", - "INT64": "int", - "INT": "int", - "SMALLINT": "int", - "TINYINT": "int", - "BIGINT": "int", - "NUMERIC": "float", - "BIGNUMERIC": "float", - "FLOAT": "float", - "FLOAT64": "float", - "DECIMAL": "float", + "BOOLEAN": "boolean", + "BOOL": "boolean", + + "INT64": "int", // INT64 and aliases + "INT": "int", + "SMALLINT": "int", + "INTEGER": "int", + "BIGINT": "int", + "TINYINT": "int", + "BYTEINT": "int", + + "INTERVAL": "int", + + "NUMERIC": "float", // NUMERIC and aliases + "DECIMAL": "float", + + "BIGNUMERIC": "float", // BIGNUMERIC and aliases "BIGDECIMAL": "float", - "STRING": "string", - "BYTES": "string", - "DATE": "datetime", - "DATETIME": "datetime", - "TIME": "datetime", - "TIMESTAMP": "datetime", + + "FLOAT": "float", + "FLOAT64": "float", + + "STRING": "string", + "BYTES": "string", + "GEOGRAPHY": "string", + + "DATE": "datetime", + "DATETIME": "datetime", + "TIME": "datetime", + "TIMESTAMP": "datetime", + + "JSON": "json", + "ARRAY": "json", + "STRUCT": "json", // STRUCT and RECORD are represented as an array of json objects + "RECORD": "json", +} + +var re = regexp.MustCompile(`(\([^)]+\)|<[^>]+>)`) + +func columnTypeMapper(columnType base.ColumnType) string { + databaseTypeName := strings.ToUpper(re.ReplaceAllString(columnType.DatabaseTypeName(), "")) + if mappedType, ok := columnTypeMappings[strings.ToUpper(databaseTypeName)]; ok { + return mappedType + } + return databaseTypeName } // jsonRowMapper maps a row's scanned column to a json object's field -func jsonRowMapper(_ string, value any) any { +func jsonRowMapper(databaseTypeName string, value any) any { switch v := (value).(type) { case *big.Rat: // Handle big.Rat values @@ -37,6 +70,16 @@ func jsonRowMapper(_ string, value any) any { } else { return v.Num().Int64() } + case *bigquery.IntervalValue: + return v.ToDuration() + case []uint8: + return string(v) + case string: + switch databaseTypeName { + case "JSON": + return json.RawMessage(v) + } + return v default: // Handle other data types as is return v diff --git a/sqlconnect/internal/bigquery/schemaadmin.go b/sqlconnect/internal/bigquery/schemaadmin.go new file mode 100644 index 0000000..b10b2b8 --- /dev/null +++ b/sqlconnect/internal/bigquery/schemaadmin.go @@ -0,0 +1,57 @@ +package bigquery + +import ( + "context" + "errors" + + "cloud.google.com/go/bigquery" + "google.golang.org/api/googleapi" + "google.golang.org/api/iterator" + + "github.com/rudderlabs/sqlconnect-go/sqlconnect" +) + +// SchemaExists uses the bigquery client instead of [INFORMATION_SCHEMA.SCHEMATA] due to absence of a region qualifier +// https://cloud.google.com/bigquery/docs/information-schema-datasets-schemata#scope_and_syntax +func (db *DB) SchemaExists(ctx context.Context, schemaRef sqlconnect.SchemaRef) (bool, error) { + var exists bool + if err := db.WithBigqueryClient(ctx, func(c *bigquery.Client) error { + if _, err := c.Dataset(schemaRef.Name).Metadata(ctx); err != nil { + var e *googleapi.Error + if ok := errors.As(err, &e); ok { + if e.Code == 404 { // not found + return nil + } + } + return err + } + exists = true + return nil + }); err != nil { + return false, err + } + return exists, nil +} + +// ListSchemas uses the bigquery client instead of [INFORMATION_SCHEMA.SCHEMATA] due to absence of a region qualifier +// https://cloud.google.com/bigquery/docs/information-schema-datasets-schemata#scope_and_syntax +func (db *DB) ListSchemas(ctx context.Context) ([]sqlconnect.SchemaRef, error) { + var schemas []sqlconnect.SchemaRef + if err := db.WithBigqueryClient(ctx, func(c *bigquery.Client) error { + datasets := c.Datasets(ctx) + for { + var dataset *bigquery.Dataset + dataset, err := datasets.Next() + if err != nil { + if err == iterator.Done { + return nil + } + return err + } + schemas = append(schemas, sqlconnect.SchemaRef{Name: dataset.DatasetID}) + } + }); err != nil { + return nil, err + } + return schemas, nil +} diff --git a/sqlconnect/internal/bigquery/testdata/column-mapping-test-columns.json b/sqlconnect/internal/bigquery/testdata/column-mapping-test-columns.json new file mode 100644 index 0000000..e4d3d3f --- /dev/null +++ b/sqlconnect/internal/bigquery/testdata/column-mapping-test-columns.json @@ -0,0 +1,74 @@ +[ + { + "name": "_order", + "type": "int" + }, + { + "name": "_array", + "type": "json" + }, + { + "name": "_bignumeric", + "type": "float" + }, + { + "name": "_bignumericnoscale", + "type": "float" + }, + { + "name": "_bool", + "type": "boolean" + }, + { + "name": "_bytes", + "type": "string" + }, + { + "name": "_date", + "type": "datetime" + }, + { + "name": "_datetime", + "type": "datetime" + }, + { + "name": "_float64", + "type": "float" + }, + { + "name": "_geo", + "type": "string" + }, + { + "name": "_int64", + "type": "int" + }, + { + "name": "_interval", + "type": "int" + }, + { + "name": "_json", + "type": "json" + }, + { + "name": "_numeric", + "type": "float" + }, + { + "name": "_string", + "type": "string" + }, + { + "name": "_struct", + "type": "json" + }, + { + "name": "_time", + "type": "datetime" + }, + { + "name": "_timestamp", + "type": "datetime" + } +] \ No newline at end of file diff --git a/sqlconnect/internal/bigquery/testdata/column-mapping-test-rows.json b/sqlconnect/internal/bigquery/testdata/column-mapping-test-rows.json new file mode 100644 index 0000000..cbe76d1 --- /dev/null +++ b/sqlconnect/internal/bigquery/testdata/column-mapping-test-rows.json @@ -0,0 +1,62 @@ +[ + { + "_order": 1, + "_array": ["ONE"], + "_bignumeric": 1.1, + "_bignumericnoscale": 1, + "_bool": true, + "_bytes": "abc", + "_date": "2014-09-27", + "_datetime": "2014-09-27T12:30:00.450000000", + "_float64": 1.1, + "_geo": "POINT(32 90)", + "_int64": 1, + "_interval": 31104000000000000, + "_json": {"key": "value"}, + "_numeric": 1, + "_string": "string", + "_struct": ["string", 1], + "_time": "12:30:00.450000000", + "_timestamp": "2014-09-27T20:30:00.45Z" + }, + { + "_order": 2, + "_array": null, + "_bignumeric": 0, + "_bignumericnoscale": 0, + "_bool": false, + "_bytes": "", + "_date": "2014-09-27", + "_datetime": "2014-09-27T12:30:00.450000000", + "_float64": 0, + "_geo": "GEOMETRYCOLLECTION EMPTY", + "_int64": 0, + "_interval": 31104000000000000, + "_json": {}, + "_numeric": 0, + "_string": "", + "_struct": ["",0], + "_time": "12:30:00.450000000", + "_timestamp": "2014-09-27T20:30:00.45Z" + }, + { + "_order": 3, + "_array": null, + "_bignumeric": null, + "_bignumericnoscale": null, + "_bool": null, + "_bytes": null, + "_date": null, + "_datetime": null, + "_float64": null, + "_geo": null, + "_int64": null, + "_interval": null, + "_json": null, + "_numeric": null, + "_string": null, + "_struct": null, + "_time": null, + "_timestamp": null + } +] \ No newline at end of file diff --git a/sqlconnect/internal/bigquery/testdata/column-mapping-test-seed.sql b/sqlconnect/internal/bigquery/testdata/column-mapping-test-seed.sql new file mode 100644 index 0000000..27f9532 --- /dev/null +++ b/sqlconnect/internal/bigquery/testdata/column-mapping-test-seed.sql @@ -0,0 +1,27 @@ +CREATE TABLE `{{.schema}}`.`column_mappings_test` ( + _order INT64, + _array ARRAY, + _bignumeric BIGNUMERIC(2,1), + _bignumericnoscale BIGNUMERIC(1,0), + _bool BOOL, + _bytes BYTES, + _date DATE, + _datetime DATETIME, + _float64 FLOAT64, + _geo GEOGRAPHY, + _int64 INT64, + _interval INTERVAL, + _json JSON, + _numeric NUMERIC, + _string STRING(10), + _struct STRUCT, + _time TIME, + _timestamp TIMESTAMP, +); + +INSERT INTO `{{.schema}}`.`column_mappings_test` + (_order, _array, _bignumeric, _bignumericnoscale, _bool, _bytes, _date, _datetime, _float64, _geo, _int64, _interval, _json, _numeric, _string, _struct, _time, _timestamp) +VALUES + (1, ['ONE'], 1.1, 1, TRUE, B"abc", '2014-09-27', '2014-09-27 12:30:00.45', 1.1, ST_GEOGFROMTEXT('POINT(32 90)'), 1, INTERVAL 1 YEAR, JSON '{"key": "value"}', 1, 'string', ('string', 1), '12:30:00.45', '2014-09-27 12:30:00.45-08'), + (2, [], 0.0, 0, FALSE, B"", '2014-09-27', '2014-09-27 12:30:00.45', 0.0, ST_GEOGFROMTEXT('POINT EMPTY'), 0, INTERVAL 1 YEAR, JSON '{}', 0, '', ('', 0), '12:30:00.45', '2014-09-27 12:30:00.45-08'), + (3, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL); \ No newline at end of file diff --git a/sqlconnect/internal/databricks/config.go b/sqlconnect/internal/databricks/config.go index ca7b869..2a61cb7 100644 --- a/sqlconnect/internal/databricks/config.go +++ b/sqlconnect/internal/databricks/config.go @@ -15,6 +15,7 @@ type Config struct { RetryAttempts int `json:"retryAttempts"` MinRetryWaitTime time.Duration `json:"minRetryWaitTime"` MaxRetryWaitTime time.Duration `json:"maxRetryWaitTime"` + MaxConnIdleTime time.Duration `json:"maxConnIdleTime"` // RudderSchema is used to override the default rudder schema name during tests RudderSchema string `json:"rudderSchema"` diff --git a/sqlconnect/internal/databricks/db.go b/sqlconnect/internal/databricks/db.go index 374673c..d685bd5 100644 --- a/sqlconnect/internal/databricks/db.go +++ b/sqlconnect/internal/databricks/db.go @@ -3,6 +3,7 @@ package databricks import ( "database/sql" "encoding/json" + "fmt" databricks "github.com/databricks/databricks-sql-go" "github.com/samber/lo" @@ -45,6 +46,7 @@ func NewDB(configJson json.RawMessage) (*DB, error) { if err != nil { return nil, err } + db.SetConnMaxIdleTime(config.MaxConnIdleTime) return &DB{ DB: base.NewDB( @@ -55,16 +57,27 @@ func NewDB(configJson json.RawMessage) (*DB, error) { base.WithJsonRowMapper(jsonRowMapper), base.WithSQLCommandsOverride(func(cmds base.SQLCommands) base.SQLCommands { cmds.ListSchemas = func() (string, string) { return "SHOW SCHEMAS", "schema_name" } + cmds.SchemaExists = func(schema string) string { return fmt.Sprintf(`SHOW SCHEMAS LIKE '%s'`, schema) } + + cmds.CreateTestTable = func(table string) string { + return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %[1]s (c1 INT, c2 STRING)", table) + } cmds.ListTables = func(schema string) []lo.Tuple2[string, string] { return []lo.Tuple2[string, string]{ - {A: "SHOW TABLES IN " + schema, B: "tableName"}, + {A: fmt.Sprintf("SHOW TABLES IN %s", schema), B: "tableName"}, } } cmds.ListTablesWithPrefix = func(schema, prefix string) []lo.Tuple2[string, string] { return []lo.Tuple2[string, string]{ - {A: "SHOW TABLES IN " + schema + " LIKE '" + prefix + "*'", B: "tableName"}, + {A: fmt.Sprintf("SHOW TABLES IN %[1]s LIKE '%[2]s'", schema, prefix+"*"), B: "tableName"}, } } + cmds.TableExists = func(schema, table string) string { + return fmt.Sprintf("SHOW TABLES IN %[1]s LIKE '%[2]s'", schema, table) + } + cmds.ListColumns = func(schema, table string) (string, string, string) { + return fmt.Sprintf("DESCRIBE TABLE `%[1]s`.`%[2]s`", schema, table), "col_name", "data_type" + } return cmds }), ), diff --git a/sqlconnect/internal/databricks/dialect_test.go b/sqlconnect/internal/databricks/dialect_test.go new file mode 100644 index 0000000..33f9866 --- /dev/null +++ b/sqlconnect/internal/databricks/dialect_test.go @@ -0,0 +1,30 @@ +package databricks + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/sqlconnect-go/sqlconnect" +) + +func TestDialect(t *testing.T) { + var d dialect + t.Run("format table", func(t *testing.T) { + formatted := d.FormatTableName("TaBle") + require.Equal(t, "table", formatted, "table name should be lowercased") + }) + + t.Run("quote identifier", func(t *testing.T) { + quoted := d.QuoteIdentifier("column") + require.Equal(t, "`column`", quoted, "column name should be quoted with backticks") + }) + + t.Run("quote table", func(t *testing.T) { + quoted := d.QuoteTable(sqlconnect.NewRelationRef("table")) + require.Equal(t, "`table`", quoted, "table name should be quoted with backticks") + + quoted = d.QuoteTable(sqlconnect.NewRelationRef("table", sqlconnect.WithSchema("schema"))) + require.Equal(t, "`schema`.`table`", quoted, "schema and table name should be quoted with backticks") + }) +} diff --git a/sqlconnect/internal/databricks/integration_test.go b/sqlconnect/internal/databricks/integration_test.go new file mode 100644 index 0000000..317e255 --- /dev/null +++ b/sqlconnect/internal/databricks/integration_test.go @@ -0,0 +1,30 @@ +package databricks_test + +import ( + "os" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/tidwall/sjson" + + "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/databricks" + integrationtest "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/integration_test" +) + +func TestDatabricksDB(t *testing.T) { + configJSON, ok := os.LookupEnv("DATABRICKS_TEST_ENVIRONMENT_CREDENTIALS") + if !ok { + t.Skip("skipping databricks integration test due to lack of a test environment") + } + + configJSON, err := sjson.Set(configJSON, "retryAttempts", 1) + require.NoError(t, err, "failed to set retryAttempts") + configJSON, err = sjson.Set(configJSON, "minRetryWaitTime", time.Second) + require.NoError(t, err, "failed to set minRetryWaitTime") + configJSON, err = sjson.Set(configJSON, "maxRetryWaitTime", time.Minute) + require.NoError(t, err, "failed to set maxRetryWaitTime") + + integrationtest.TestDatabaseScenarios(t, databricks.DatabaseType, []byte(configJSON), strings.ToLower) +} diff --git a/sqlconnect/internal/integration_test/db_integration_test_scenario.go b/sqlconnect/internal/integration_test/db_integration_test_scenario.go index 9a1469d..a9bcfe9 100644 --- a/sqlconnect/internal/integration_test/db_integration_test_scenario.go +++ b/sqlconnect/internal/integration_test/db_integration_test_scenario.go @@ -14,8 +14,8 @@ import ( "github.com/rudderlabs/sqlconnect-go/sqlconnect" ) -func TestDatabaseScenarios(t *testing.T, warehouse string, configJSON json.RawMessage) { - schema := sqlconnect.SchemaRef{Name: GenerateTestSchema()} +func TestDatabaseScenarios(t *testing.T, warehouse string, configJSON json.RawMessage, formatfn func(string) string) { + schema := sqlconnect.SchemaRef{Name: GenerateTestSchema(formatfn)} configJSON, err := sjson.SetBytes(configJSON, "rudderSchema", schema.Name) require.NoError(t, err, "it should be able to set the rudder schema") db, err := sqlconnect.NewDB(warehouse, configJSON) @@ -97,7 +97,7 @@ func TestDatabaseScenarios(t *testing.T, warehouse string, configJSON json.RawMe }) t.Run("normal operation", func(t *testing.T) { - otherSchema := sqlconnect.SchemaRef{Name: GenerateTestSchema()} + otherSchema := sqlconnect.SchemaRef{Name: GenerateTestSchema(formatfn)} err := db.CreateSchema(ctx, otherSchema) require.NoError(t, err, "it should be able to create a schema") err = db.DropSchema(ctx, otherSchema) @@ -110,8 +110,167 @@ func TestDatabaseScenarios(t *testing.T, warehouse string, configJSON json.RawMe }) }) }) + + t.Run("table admin", func(t *testing.T) { + table := sqlconnect.NewRelationRef(formatfn("test_table"), sqlconnect.WithSchema(schema.Name)) + + t.Run("table doesn't exist", func(t *testing.T) { + t.Run("with context cancelled", func(t *testing.T) { + _, err := db.TableExists(cancelledCtx, table) + require.Error(t, err, "it should not be able to check if a table exists with a cancelled context") + }) + + exists, err := db.TableExists(ctx, table) + require.NoError(t, err, "it should be able to check if a table exists") + require.False(t, exists, "it should return false for a table that doesn't exist") + }) + + t.Run("create test table", func(t *testing.T) { + t.Run("with context cancelled", func(t *testing.T) { + err := db.CreateTestTable(cancelledCtx, table) + require.Error(t, err, "it should not be able to create a test table with a cancelled context") + }) + + err := db.CreateTestTable(ctx, table) + require.NoError(t, err, "it should be able to create a test table") + exists, err := db.TableExists(ctx, table) + require.NoError(t, err, "it should be able to check if a table exists") + require.True(t, exists, "it should return true for a table that was just created") + }) + + t.Run("list tables", func(t *testing.T) { + t.Run("with context cancelled", func(t *testing.T) { + _, err := db.ListTables(cancelledCtx, schema) + require.Error(t, err, "it should not be able to list tables with a cancelled context") + }) + + tables, err := db.ListTables(ctx, schema) + require.NoError(t, err, "it should be able to list tables") + require.Contains(t, tables, table, "it should contain the created table") + }) + + t.Run("list tables with prefix", func(t *testing.T) { + t.Run("with context cancelled", func(t *testing.T) { + _, err := db.ListTablesWithPrefix(cancelledCtx, schema, formatfn("test")) + require.Error(t, err, "it should not be able to list tables with a prefix with a cancelled context") + }) + + tables, err := db.ListTablesWithPrefix(ctx, schema, formatfn("test")) + require.NoError(t, err, "it should be able to list tables with a prefix") + require.Contains(t, tables, table, "it should contain the created table") + }) + + t.Run("list columns", func(t *testing.T) { + t.Run("with context cancelled", func(t *testing.T) { + _, err := db.ListColumns(cancelledCtx, table) + require.Error(t, err, "it should not be able to list columns with a cancelled context") + }) + + columns, err := db.ListColumns(ctx, table) + require.NoError(t, err, "it should be able to list columns") + require.Len(t, columns, 2, "it should return the correct number of columns") + require.ElementsMatch(t, columns, []sqlconnect.ColumnRef{ + {Name: formatfn("c1"), Type: "int"}, + {Name: formatfn("c2"), Type: "string"}, + }, "it should return the correct columns") + }) + + t.Run("list columns for sql query", func(t *testing.T) { + q := sqlconnect.QueryDef{ + Table: &table, + Columns: []string{formatfn("c1")}, + } + stmt := q.ToSQL(db) + + t.Run("with context cancelled", func(t *testing.T) { + _, err := db.ListColumnsForSqlQuery(cancelledCtx, stmt) + require.Error(t, err, "it should not be able to list columns for a sql query with a cancelled context") + }) + + columns, err := db.ListColumnsForSqlQuery(ctx, stmt) + require.NoError(t, err, "it should be able to list columns for a sql query") + require.Len(t, columns, 1, "it should return the correct number of columns") + require.ElementsMatch(t, columns, []sqlconnect.ColumnRef{ + {Name: formatfn("c1"), Type: "int"}, + }, "it should return the correct columns") + }) + + t.Run("count table rows", func(t *testing.T) { + t.Run("with context cancelled", func(t *testing.T) { + _, err := db.CountTableRows(cancelledCtx, table) + require.Error(t, err, "it should not be able to count table rows with a cancelled context") + }) + + count, err := db.CountTableRows(ctx, table) + require.NoError(t, err, "it should be able to count table rows") + require.Equal(t, 0, count, "it should return 0 for a table with no rows") + + // add a row + _, err = db.ExecContext(ctx, fmt.Sprintf("INSERT INTO %s (c1, c2) VALUES (1, '1')", db.QuoteTable(table))) + require.NoError(t, err, "it should be able to insert a row") + + count, err = db.CountTableRows(ctx, table) + require.NoError(t, err, "it should be able to count table rows") + require.Equal(t, 1, count, "it should return 1 for a table with one row") + }) + + t.Run("truncate table", func(t *testing.T) { + t.Run("with context cancelled", func(t *testing.T) { + err := db.TruncateTable(cancelledCtx, table) + require.Error(t, err, "it should not be able to truncate a table with a cancelled context") + }) + + err := db.TruncateTable(ctx, table) + require.NoError(t, err, "it should be able to truncate a table") + count, err := db.CountTableRows(ctx, table) + require.NoError(t, err, "it should be able to count table rows") + require.Equal(t, 0, count, "it should return 0 for a table with no rows") + }) + + t.Run("rename table", func(t *testing.T) { + newTable := sqlconnect.NewRelationRef(formatfn("test_table_renamed"), sqlconnect.WithSchema(schema.Name)) + + t.Run("with context cancelled", func(t *testing.T) { + err := db.RenameTable(cancelledCtx, table, newTable) + require.Error(t, err, "it should not be able to rename a table with a cancelled context") + }) + + t.Run("using different schemas", func(t *testing.T) { + newTableWithDifferentSchema := newTable + newTableWithDifferentSchema.Schema = newTableWithDifferentSchema.Schema + "_other" + err := db.RenameTable(ctx, table, newTableWithDifferentSchema) + require.Error(t, err, "it should not be able to rename a table to a different schema") + }) + + t.Run("normal operation", func(t *testing.T) { + err := db.RenameTable(ctx, table, newTable) + require.NoError(t, err, "it should be able to rename a table") + + exists, err := db.TableExists(ctx, newTable) + require.NoError(t, err, "it should be able to check if a table exists") + require.True(t, exists, "it should return true for a table that was just renamed") + + exists, err = db.TableExists(ctx, table) + require.NoError(t, err, "it should be able to check if the old table exists") + require.False(t, exists, "it should return false for the old table which was just renamed") + }) + }) + + t.Run("drop table", func(t *testing.T) { + t.Run("with context cancelled", func(t *testing.T) { + err := db.DropTable(cancelledCtx, table) + require.Error(t, err, "it should not be able to drop a table with a cancelled context") + }) + + err := db.DropTable(ctx, table) + require.NoError(t, err, "it should be able to drop a table") + exists, err := db.TableExists(ctx, table) + require.NoError(t, err, "it should be able to check if a table exists") + require.False(t, exists, "it should return false for a table that was just dropped") + }) + }) } -func GenerateTestSchema() string { - return fmt.Sprintf("tsqlcon_%s_%d", rand.String(12), time.Now().Unix()) +func GenerateTestSchema(formatfn func(string) string) string { + return formatfn(fmt.Sprintf("tsqlcon_%s_%d", rand.String(12), time.Now().Unix())) } diff --git a/sqlconnect/internal/mysql/db.go b/sqlconnect/internal/mysql/db.go index 2721c49..d8d4c8e 100644 --- a/sqlconnect/internal/mysql/db.go +++ b/sqlconnect/internal/mysql/db.go @@ -37,11 +37,15 @@ func NewDB(configJSON json.RawMessage) (*DB, error) { DB: base.NewDB( db, lo.Ternary(config.RudderSchema != "", config.RudderSchema, defaultRudderSchema), + base.WithColumnTypeMappings(columnTypeMappings), base.WithJsonRowMapper(jsonRowMapper), base.WithSQLCommandsOverride(func(cmds base.SQLCommands) base.SQLCommands { cmds.DropSchema = func(schema string) string { // mysql does not support CASCADE return fmt.Sprintf("DROP SCHEMA %[1]s", schema) } + cmds.RenameTable = func(schema, oldName, newName string) string { + return fmt.Sprintf("RENAME TABLE %[1]s.%[2]s TO %[1]s.%[3]s", schema, oldName, newName) + } return cmds }), base.WithDialect(dialect{}), diff --git a/sqlconnect/internal/mysql/dialect_test.go b/sqlconnect/internal/mysql/dialect_test.go new file mode 100644 index 0000000..28ffc84 --- /dev/null +++ b/sqlconnect/internal/mysql/dialect_test.go @@ -0,0 +1,30 @@ +package mysql + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/sqlconnect-go/sqlconnect" +) + +func TestDialect(t *testing.T) { + var d dialect + t.Run("format table", func(t *testing.T) { + formatted := d.FormatTableName("TaBle") + require.Equal(t, "table", formatted, "table name should be lowercased") + }) + + t.Run("quote identifier", func(t *testing.T) { + quoted := d.QuoteIdentifier("column") + require.Equal(t, "`column`", quoted, "column name should be quoted with backticks") + }) + + t.Run("quote table", func(t *testing.T) { + quoted := d.QuoteTable(sqlconnect.NewRelationRef("table")) + require.Equal(t, "`table`", quoted, "table name should be quoted with backticks") + + quoted = d.QuoteTable(sqlconnect.NewRelationRef("table", sqlconnect.WithSchema("schema"))) + require.Equal(t, "`schema`.`table`", quoted, "schema and table name should be quoted with backticks") + }) +} diff --git a/sqlconnect/internal/mysql/integration_test.go b/sqlconnect/internal/mysql/integration_test.go index 770987e..066d12f 100644 --- a/sqlconnect/internal/mysql/integration_test.go +++ b/sqlconnect/internal/mysql/integration_test.go @@ -3,6 +3,7 @@ package mysql_test import ( "encoding/json" "strconv" + "strings" "testing" "github.com/ory/dockertest/v3" @@ -33,5 +34,5 @@ func TestMysqlDB(t *testing.T) { configJSON, err := json.Marshal(config) require.NoError(t, err, "it should be able to marshal config to json") - integrationtest.TestDatabaseScenarios(t, mysql.DatabaseType, configJSON) + integrationtest.TestDatabaseScenarios(t, mysql.DatabaseType, configJSON, strings.ToLower) } diff --git a/sqlconnect/internal/mysql/mappings.go b/sqlconnect/internal/mysql/mappings.go index a2f30c9..7411467 100644 --- a/sqlconnect/internal/mysql/mappings.go +++ b/sqlconnect/internal/mysql/mappings.go @@ -33,6 +33,12 @@ const ( bigint = "BIGINT" ) +// mapping of database column types to rudder types +var columnTypeMappings = map[string]string{ + varchar: "string", + integer: "int", +} + func jsonRowMapper(databaseTypeName string, value interface{}) interface{} { if value == nil { return nil diff --git a/sqlconnect/internal/postgres/integration_test.go b/sqlconnect/internal/postgres/integration_test.go index e916510..75ce8ab 100644 --- a/sqlconnect/internal/postgres/integration_test.go +++ b/sqlconnect/internal/postgres/integration_test.go @@ -3,6 +3,7 @@ package postgres_test import ( "encoding/json" "strconv" + "strings" "testing" "github.com/ory/dockertest/v3" @@ -33,5 +34,5 @@ func TestPostgresDB(t *testing.T) { configJSON, err := json.Marshal(config) require.NoError(t, err, "it should be able to marshal config to json") - integrationtest.TestDatabaseScenarios(t, postgres.DatabaseType, configJSON) + integrationtest.TestDatabaseScenarios(t, postgres.DatabaseType, configJSON, strings.ToLower) } diff --git a/sqlconnect/internal/redshift/db.go b/sqlconnect/internal/redshift/db.go index 042284f..1bf3ab7 100644 --- a/sqlconnect/internal/redshift/db.go +++ b/sqlconnect/internal/redshift/db.go @@ -35,8 +35,11 @@ func NewDB(credentialsJSON json.RawMessage) (*DB, error) { lo.Ternary(config.RudderSchema != "", config.RudderSchema, defaultRudderSchema), base.WithColumnTypeMappings(columnTypeMappings), base.WithSQLCommandsOverride(func(cmds base.SQLCommands) base.SQLCommands { + cmds.ListSchemas = func() (string, string) { + return "SELECT schema_name FROM svv_redshift_schemas", "schema_name" + } cmds.SchemaExists = func(schema string) string { - return fmt.Sprintf("SELECT has_schema_privilege((SELECT current_user), '%[1]s', 'usage')", schema) + return fmt.Sprintf("SELECT schema_name FROM svv_redshift_schemas WHERE schema_name = '%[1]s'", schema) } return cmds }), diff --git a/sqlconnect/internal/redshift/integration_test.go b/sqlconnect/internal/redshift/integration_test.go new file mode 100644 index 0000000..3ad3ea6 --- /dev/null +++ b/sqlconnect/internal/redshift/integration_test.go @@ -0,0 +1,19 @@ +package redshift_test + +import ( + "os" + "strings" + "testing" + + integrationtest "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/integration_test" + "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/redshift" +) + +func TestRedshiftDB(t *testing.T) { + configJSON, ok := os.LookupEnv("REDSHIFT_TEST_ENVIRONMENT_CREDENTIALS") + if !ok { + t.Skip("skipping redshift integration test due to lack of a test environment") + } + + integrationtest.TestDatabaseScenarios(t, redshift.DatabaseType, []byte(configJSON), strings.ToLower) +} diff --git a/sqlconnect/internal/snowflake/db.go b/sqlconnect/internal/snowflake/db.go index 6662e8b..2b9e042 100644 --- a/sqlconnect/internal/snowflake/db.go +++ b/sqlconnect/internal/snowflake/db.go @@ -38,21 +38,21 @@ func NewDB(configJSON json.RawMessage) (*DB, error) { db, lo.Ternary(config.RudderSchema != "", config.RudderSchema, defaultRudderSchema), base.WithDialect(dialect{}), - base.WithColumnTypeMappings(columnTypeMappings), + base.WithColumnTypeMapper(columnTypeMapper), base.WithJsonRowMapper(jsonRowMapper), base.WithSQLCommandsOverride(func(cmds base.SQLCommands) base.SQLCommands { + cmds.ListSchemas = func() (string, string) { return "SHOW TERSE SCHEMAS", "name" } cmds.SchemaExists = func(schema string) string { return fmt.Sprintf("SHOW TERSE SCHEMAS LIKE '%[1]s'", schema) } - cmds.ListSchemas = func() (string, string) { return "SHOW TERSE SCHEMAS", "schema_name" } cmds.ListTables = func(schema string) []lo.Tuple2[string, string] { return []lo.Tuple2[string, string]{ - {A: fmt.Sprintf("SHOW TERSE TABLES IN SCHEMA %[1]s", schema), B: "name"}, + {A: fmt.Sprintf(`SHOW TERSE TABLES IN SCHEMA "%[1]s"`, schema), B: "name"}, } } cmds.ListTablesWithPrefix = func(schema, prefix string) []lo.Tuple2[string, string] { return []lo.Tuple2[string, string]{ - {A: fmt.Sprintf("SHOW TERSE TABLES IN SCHEMA %[1]s LIKE '%[2]s'", schema, prefix+"%"), B: "name"}, + {A: fmt.Sprintf(`SHOW TERSE TABLES LIKE '%[2]s' IN SCHEMA "%[1]s"`, schema, prefix+"%"), B: "name"}, } } cmds.TableExists = func(schema, table string) string { @@ -61,6 +61,9 @@ func NewDB(configJSON json.RawMessage) (*DB, error) { cmds.ListColumns = func(schema, table string) (string, string, string) { return fmt.Sprintf(`DESCRIBE TABLE "%[1]s"."%[2]s"`, schema, table), "name", "type" } + cmds.RenameTable = func(schema, oldName, newName string) string { + return fmt.Sprintf(`ALTER TABLE %[1]s.%[2]s RENAME TO %[1]s.%[3]s`, schema, oldName, newName) + } return cmds }), ), diff --git a/sqlconnect/internal/snowflake/dialect_test.go b/sqlconnect/internal/snowflake/dialect_test.go new file mode 100644 index 0000000..ec66800 --- /dev/null +++ b/sqlconnect/internal/snowflake/dialect_test.go @@ -0,0 +1,30 @@ +package snowflake + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/sqlconnect-go/sqlconnect" +) + +func TestDialect(t *testing.T) { + var d dialect + t.Run("format table", func(t *testing.T) { + formatted := d.FormatTableName("TaBle") + require.Equal(t, "TABLE", formatted, "table name should be uppercased") + }) + + t.Run("quote identifier", func(t *testing.T) { + quoted := d.QuoteIdentifier("column") + require.Equal(t, `"column"`, quoted, "column name should be quoted with double quotes") + }) + + t.Run("quote table", func(t *testing.T) { + quoted := d.QuoteTable(sqlconnect.NewRelationRef("table")) + require.Equal(t, `"table"`, quoted, "table name should be quoted with double quotes") + + quoted = d.QuoteTable(sqlconnect.NewRelationRef("table", sqlconnect.WithSchema("schema"))) + require.Equal(t, `"schema"."table"`, quoted, "schema and table name should be quoted with double quotes") + }) +} diff --git a/sqlconnect/internal/snowflake/integration_test.go b/sqlconnect/internal/snowflake/integration_test.go new file mode 100644 index 0000000..8c45744 --- /dev/null +++ b/sqlconnect/internal/snowflake/integration_test.go @@ -0,0 +1,19 @@ +package snowflake_test + +import ( + "os" + "strings" + "testing" + + integrationtest "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/integration_test" + "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/snowflake" +) + +func TestSnowflakeDB(t *testing.T) { + configJSON, ok := os.LookupEnv("SNOWFLAKE_TEST_ENVIRONMENT_CREDENTIALS") + if !ok { + t.Skip("skipping snowflake integration test due to lack of a test environment") + } + + integrationtest.TestDatabaseScenarios(t, snowflake.DatabaseType, []byte(configJSON), strings.ToUpper) +} diff --git a/sqlconnect/internal/snowflake/mappings.go b/sqlconnect/internal/snowflake/mappings.go index 027be28..000b537 100644 --- a/sqlconnect/internal/snowflake/mappings.go +++ b/sqlconnect/internal/snowflake/mappings.go @@ -3,8 +3,12 @@ package snowflake import ( "encoding/json" "fmt" + "regexp" "strconv" + "strings" "time" + + "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/base" ) // mapping of database column types to rudder types @@ -17,7 +21,6 @@ var columnTypeMappings = map[string]string{ "BIGINT": "int", "SMALLINT": "int", "TINYINT": "int", - "FIXED": "float", "FLOAT": "float", "FLOAT4": "float", "FLOAT8": "float", @@ -42,6 +45,22 @@ var columnTypeMappings = map[string]string{ "VARIANT": "json", } +var re = regexp.MustCompile(`([^(]+) ?\(.*`) + +func columnTypeMapper(columnType base.ColumnType) string { + databaseTypeName := strings.ToUpper(re.ReplaceAllString(columnType.DatabaseTypeName(), "$1")) + if mappedType, ok := columnTypeMappings[strings.ToUpper(databaseTypeName)]; ok { + return mappedType + } + if databaseTypeName == "FIXED" { + if _, decimals, ok := columnType.DecimalSize(); ok && decimals > 0 { + return "float" + } + return "int" + } + return databaseTypeName +} + // check https://godoc.org/github.com/snowflakedb/gosnowflake#hdr-Supported_Data_Types for handling snowflake data types func jsonRowMapper(databaseTypeName string, value interface{}) interface{} { if value == nil { diff --git a/sqlconnect/internal/trino/db.go b/sqlconnect/internal/trino/db.go index 1b72c31..a663e97 100644 --- a/sqlconnect/internal/trino/db.go +++ b/sqlconnect/internal/trino/db.go @@ -38,7 +38,7 @@ func NewDB(configJSON json.RawMessage) (*DB, error) { base.WithSQLCommandsOverride(func(cmds base.SQLCommands) base.SQLCommands { cmds.ListTables = func(schema string) []lo.Tuple2[string, string] { return []lo.Tuple2[string, string]{ - {A: fmt.Sprintf("SHOW TABLES FROM %s", schema), B: "tableName"}, + {A: fmt.Sprintf("SHOW TABLES FROM %[1]s", schema), B: "tableName"}, } } cmds.ListTablesWithPrefix = func(schema, prefix string) []lo.Tuple2[string, string] { @@ -46,6 +46,12 @@ func NewDB(configJSON json.RawMessage) (*DB, error) { {A: fmt.Sprintf("SHOW TABLES FROM %[1]s LIKE '%[2]s'", schema, prefix+"%"), B: "tableName"}, } } + cmds.TableExists = func(schema, table string) string { + return fmt.Sprintf("SHOW TABLES FROM %[1]s LIKE '%[2]s'", schema, table) + } + cmds.TruncateTable = func(table string) string { + return fmt.Sprintf("DELETE FROM %[1]s", table) + } return cmds }), ), diff --git a/sqlconnect/internal/trino/integration_test.go b/sqlconnect/internal/trino/integration_test.go new file mode 100644 index 0000000..69c65b8 --- /dev/null +++ b/sqlconnect/internal/trino/integration_test.go @@ -0,0 +1,19 @@ +package trino_test + +import ( + "os" + "strings" + "testing" + + integrationtest "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/integration_test" + "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/trino" +) + +func TestTrinoDB(t *testing.T) { + configJSON, ok := os.LookupEnv("TRINO_TEST_ENVIRONMENT_CREDENTIALS") + if !ok { + t.Skip("skipping trino integration test due to lack of a test environment") + } + + integrationtest.TestDatabaseScenarios(t, trino.DatabaseType, []byte(configJSON), strings.ToLower) +} diff --git a/sqlconnect/internal/trino/mappings.go b/sqlconnect/internal/trino/mappings.go index 892d1f2..575f8c0 100644 --- a/sqlconnect/internal/trino/mappings.go +++ b/sqlconnect/internal/trino/mappings.go @@ -1,6 +1,10 @@ package trino -import "strings" +import ( + "strings" + + "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/base" +) // mapping of database column types to rudder types var columnTypeMappings = map[string]string{ @@ -23,16 +27,16 @@ var columnTypeMappings = map[string]string{ "TIMESTAMP": "datetime", } -func columnTypeMapper(databaseTypeName string) string { +func columnTypeMapper(columnType base.ColumnType) string { + databaseTypeName := strings.ToUpper(columnType.DatabaseTypeName()) if mappedType, ok := columnTypeMappings[databaseTypeName]; ok { return mappedType } - databaseTypeNameLower := strings.ToLower(databaseTypeName) - if strings.Contains(databaseTypeNameLower, "char") || strings.Contains(databaseTypeNameLower, "varchar") { + if strings.Contains(databaseTypeName, "CHAR") || strings.Contains(databaseTypeName, "VARCHAR") { return "string" - } else if strings.Contains(databaseTypeNameLower, "timestamp") { + } else if strings.Contains(databaseTypeName, "TIMESTAMP") { return "datetime" - } else if strings.Contains(databaseTypeNameLower, "decimal") { + } else if strings.Contains(databaseTypeName, "DECIMAL") { return "float" } return databaseTypeName diff --git a/sqlconnect/internal/util/validatehost.go b/sqlconnect/internal/util/validatehost.go index d6f79e2..240759c 100644 --- a/sqlconnect/internal/util/validatehost.go +++ b/sqlconnect/internal/util/validatehost.go @@ -5,6 +5,7 @@ import ( "net" ) +// ValidateHost checks if the hostname is resolvable and that it doesn't correspond to localhost. func ValidateHost(hostname string) error { addrs, err := net.LookupHost(hostname) if err != nil { diff --git a/sqlconnect/internal/util/validatehost_test.go b/sqlconnect/internal/util/validatehost_test.go new file mode 100644 index 0000000..fc2b38b --- /dev/null +++ b/sqlconnect/internal/util/validatehost_test.go @@ -0,0 +1,26 @@ +package util_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/util" +) + +func TestValidateHost(t *testing.T) { + t.Run("valid host", func(t *testing.T) { + err := util.ValidateHost("github.com") + require.NoError(t, err) + }) + + t.Run("invalid host", func(t *testing.T) { + err := util.ValidateHost("!@#$.$%^") + require.Error(t, err) + }) + + t.Run("localhost", func(t *testing.T) { + err := util.ValidateHost("localhost") + require.Error(t, err) + }) +} diff --git a/sqlconnect/def_query.go b/sqlconnect/querydef.go similarity index 100% rename from sqlconnect/def_query.go rename to sqlconnect/querydef.go diff --git a/sqlconnect/querydef_test.go b/sqlconnect/querydef_test.go new file mode 100644 index 0000000..b4014d0 --- /dev/null +++ b/sqlconnect/querydef_test.go @@ -0,0 +1,64 @@ +package sqlconnect_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/sqlconnect-go/sqlconnect" +) + +func TestQueryDef(t *testing.T) { + t.Run("with columns", func(t *testing.T) { + table := sqlconnect.NewRelationRef("table") + q := sqlconnect.QueryDef{ + Table: &table, + Columns: []string{"col1", "col2"}, + Conditions: []*sqlconnect.QueryCondition{ + {Column: "col1", Operator: "=", Value: "'1'"}, + {Column: "col2", Operator: ">", Value: "2"}, + }, + OrderBy: &sqlconnect.QueryOrder{ + Column: "col1", + Order: "ASC", + }, + } + + sql := q.ToSQL(testDialect{}) + expected := `SELECT "col1","col2" FROM "table" WHERE "col1" = '1' AND "col2" > 2 ORDER BY "col1" ASC` + require.Equal(t, expected, sql, "query should be formatted correctly") + }) + + t.Run("without columns", func(t *testing.T) { + table := sqlconnect.NewRelationRef("table") + q := sqlconnect.QueryDef{ + Table: &table, + Conditions: []*sqlconnect.QueryCondition{ + {Column: "col1", Operator: "=", Value: "'1'"}, + {Column: "col2", Operator: ">", Value: "2"}, + }, + } + + sql := q.ToSQL(testDialect{}) + expected := `SELECT * FROM "table" WHERE "col1" = '1' AND "col2" > 2` + require.Equal(t, expected, sql, "query should be formatted correctly") + }) +} + +type testDialect struct{} + +func (d testDialect) FormatTableName(name string) string { + return name +} + +func (d testDialect) QuoteIdentifier(name string) string { + return fmt.Sprintf(`"%s"`, name) +} + +func (d testDialect) QuoteTable(relation sqlconnect.RelationRef) string { + if relation.Schema != "" { + return fmt.Sprintf(`"%s"."%s"`, relation.Schema, relation.Name) + } + return fmt.Sprintf(`"%s"`, relation.Name) +} diff --git a/sqlconnect/ref_relation.go b/sqlconnect/relationref.go similarity index 100% rename from sqlconnect/ref_relation.go rename to sqlconnect/relationref.go diff --git a/sqlconnect/ref_relationopts.go b/sqlconnect/relationref_opts.go similarity index 100% rename from sqlconnect/ref_relationopts.go rename to sqlconnect/relationref_opts.go diff --git a/sqlconnect/relationref_test.go b/sqlconnect/relationref_test.go new file mode 100644 index 0000000..592c04c --- /dev/null +++ b/sqlconnect/relationref_test.go @@ -0,0 +1,60 @@ +package sqlconnect_test + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/sqlconnect-go/sqlconnect" +) + +func TestRelationRef(t *testing.T) { + t.Run("name", func(t *testing.T) { + ref := sqlconnect.NewRelationRef("table") + require.Equal(t, sqlconnect.RelationRef{Name: "table", Type: "table"}, ref) + require.Equal(t, "table", ref.String()) + + refJSON, _ := json.Marshal(ref) + var ref1 sqlconnect.RelationRef + err := ref1.UnmarshalJSON(refJSON) + require.NoError(t, err) + require.Equal(t, ref, ref1) + }) + + t.Run("name and schema", func(t *testing.T) { + ref := sqlconnect.NewRelationRef("table", sqlconnect.WithSchema("schema")) + require.Equal(t, sqlconnect.RelationRef{Name: "table", Schema: "schema", Type: "table"}, ref) + require.Equal(t, "schema.table", ref.String()) + + refJSON, _ := json.Marshal(ref) + var ref1 sqlconnect.RelationRef + err := ref1.UnmarshalJSON(refJSON) + require.NoError(t, err) + require.Equal(t, ref, ref1) + }) + + t.Run("name and schema and catalog", func(t *testing.T) { + ref := sqlconnect.NewRelationRef("table", sqlconnect.WithSchema("schema"), sqlconnect.WithCatalog("catalog")) + require.Equal(t, sqlconnect.RelationRef{Name: "table", Schema: "schema", Catalog: "catalog", Type: "table"}, ref) + require.Equal(t, "catalog.schema.table", ref.String()) + + refJSON, _ := json.Marshal(ref) + var ref1 sqlconnect.RelationRef + err := ref1.UnmarshalJSON(refJSON) + require.NoError(t, err) + require.Equal(t, ref, ref1) + }) + + t.Run("view instead of table", func(t *testing.T) { + ref := sqlconnect.NewRelationRef("view", sqlconnect.WithRelationType(sqlconnect.ViewRelation)) + require.Equal(t, sqlconnect.RelationRef{Name: "view", Type: "view"}, ref) + }) + + t.Run("unmarshal without a type", func(t *testing.T) { + var ref sqlconnect.RelationRef + err := ref.UnmarshalJSON([]byte(`{"name":"table"}`)) + require.NoError(t, err) + require.Equal(t, sqlconnect.NewRelationRef("table"), ref) + }) +} diff --git a/sqlconnect/ref_schema.go b/sqlconnect/schemaref.go similarity index 100% rename from sqlconnect/ref_schema.go rename to sqlconnect/schemaref.go diff --git a/sqlconnect/schemaref_test.go b/sqlconnect/schemaref_test.go new file mode 100644 index 0000000..19df4e0 --- /dev/null +++ b/sqlconnect/schemaref_test.go @@ -0,0 +1,16 @@ +package sqlconnect_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/sqlconnect-go/sqlconnect" +) + +func TestSchemaRef(t *testing.T) { + t.Run("string", func(t *testing.T) { + s := sqlconnect.SchemaRef{Name: "schema"} + require.Equal(t, "schema", s.String(), "schema name should be returned") + }) +}