diff --git a/README.md b/README.md index 39aeccc..eab797d 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,27 @@ Cloud Spanner supports the following data types in combination with `gorm`. | bytes | []byte | +## AutoMigrate Dry Run +The Spanner `gorm` dialect supports dry-runs for auto-migration. Use this to get the +DDL statements that would be generated and executed by auto-migration. You can manually +verify and modify these statements to optimize your data model. + +Example: + +```go +tables := []interface{}{&singer{}, &album{}} + +// Unwrap the underlying SpannerMigrator interface. This interface supports +// the `AutoMigrateDryRun` method, which does not actually execute the +// generated statements, and instead just returns these as an array. +m := db.Migrator() +migrator, ok := m.(spannergorm.SpannerMigrator) +if !ok { + return fmt.Errorf("unexpected migrator type: %v", m) +} +statements, err := migrator.AutoMigrateDryRun(tables...) +``` + ## Limitations The Cloud Spanner `gorm` dialect has the following known limitations: diff --git a/go.mod b/go.mod index 0bbb521..caa6dae 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( cloud.google.com/go/longrunning v0.6.2 cloud.google.com/go/spanner v1.73.0 github.com/golang/protobuf v1.5.4 + github.com/google/go-cmp v0.6.0 github.com/googleapis/go-sql-spanner v1.8.0 github.com/shopspring/decimal v1.4.0 github.com/stretchr/testify v1.9.0 diff --git a/migrator.go b/migrator.go index 609adbd..d7b115c 100644 --- a/migrator.go +++ b/migrator.go @@ -17,9 +17,12 @@ package gorm import ( "database/sql" "fmt" + "slices" "sort" "strings" + "cloud.google.com/go/spanner" + spannerdriver "github.com/googleapis/go-sql-spanner" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/migrator" @@ -33,6 +36,7 @@ const ( type SpannerMigrator interface { gorm.Migrator + AutoMigrateDryRun(values ...interface{}) ([]spanner.Statement, error) StartBatchDDL() error RunBatch() error AbortBatch() error @@ -41,6 +45,7 @@ type SpannerMigrator interface { type spannerMigrator struct { migrator.Migrator Dialector + dryRun bool } type spannerColumnType struct { @@ -60,21 +65,48 @@ func (m spannerMigrator) CurrentDatabase() (name string) { return "" } +func (m spannerMigrator) AutoMigrateDryRun(values ...interface{}) ([]spanner.Statement, error) { + return m.autoMigrate( /* dryRun = */ true, values...) +} + func (m spannerMigrator) AutoMigrate(values ...interface{}) error { - if !m.Dialector.Config.DisableAutoMigrateBatching { + _, err := m.autoMigrate( /* dryRun = */ false, values...) + return err +} + +func (m spannerMigrator) autoMigrate(dryRun bool, values ...interface{}) ([]spanner.Statement, error) { + if dryRun || !m.Dialector.Config.DisableAutoMigrateBatching { if err := m.StartBatchDDL(); err != nil { - return err + return nil, err } } err := m.Migrator.AutoMigrate(values...) if err == nil { - if m.Dialector.Config.DisableAutoMigrateBatching { - return nil + if !dryRun && m.Dialector.Config.DisableAutoMigrateBatching { + return nil, nil + } else if dryRun { + connPool := m.DB.Statement.ConnPool + conn, ok := connPool.(*sql.Conn) + if !ok { + return nil, fmt.Errorf("unexpected ConnPool type") + } + var statements []spanner.Statement + if err := conn.Raw(func(driverConn any) error { + spannerConn, ok := driverConn.(spannerdriver.SpannerConn) + if !ok { + return fmt.Errorf("dry-run is only supported for Spanner") + } + statements = spannerConn.GetBatchedStatements() + return nil + }); err != nil { + return nil, err + } + return statements, m.AbortBatch() } else { - return m.RunBatch() + return nil, m.RunBatch() } } - return fmt.Errorf("unexpected return value type: %v", err) + return nil, fmt.Errorf("unexpected return value type: %v", err) } func (m spannerMigrator) StartBatchDDL() error { @@ -132,6 +164,9 @@ func (m spannerMigrator) CreateTable(values ...interface{}) error { return err } f.DefaultValue = "GET_NEXT_SEQUENCE_VALUE(Sequence " + sequence + ")" + // Reset the default value to nothing after finishing migration. + //goland:noinspection GoDeferInLoop + defer func() { f.DefaultValue = "" }() } } for _, dbName := range stmt.Schema.DBNames { @@ -145,8 +180,16 @@ func (m spannerMigrator) CreateTable(values ...interface{}) error { } // Indexes should always be created after the table, as Spanner does not support - // inline index creation. - for _, idx := range stmt.Schema.ParseIndexes() { + // inline index creation. Iterate over the indexes in a fixed order to make the + // script outcome deterministic. + indexes := stmt.Schema.ParseIndexes() + indexNames := make([]string, 0, len(indexes)) + for name := range indexes { + indexNames = append(indexNames, name) + } + slices.Sort(indexNames) + for _, name := range indexNames { + idx := indexes[name] defer func(value interface{}, name string) { if errr == nil { errr = tx.Migrator().CreateIndex(value, name) @@ -154,8 +197,15 @@ func (m spannerMigrator) CreateTable(values ...interface{}) error { }(value, idx.Name) } - for _, rel := range stmt.Schema.Relationships.Relations { + // Iterator over the relationships in a fixed order. + relationshipKeys := make([]string, 0, len(stmt.Schema.Relationships.Relations)) + for key := range stmt.Schema.Relationships.Relations { + relationshipKeys = append(relationshipKeys, key) + } + slices.Sort(relationshipKeys) + for _, key := range relationshipKeys { if !m.DB.DisableForeignKeyConstraintWhenMigrating { + rel := stmt.Schema.Relationships.Relations[key] if constraint := rel.ParseConstraint(); constraint != nil { if constraint.Schema == stmt.Schema { sql, vars := buildConstraint(constraint) diff --git a/migrator_emulator_test.go b/migrator_emulator_test.go index 4862aa9..d74f057 100644 --- a/migrator_emulator_test.go +++ b/migrator_emulator_test.go @@ -25,6 +25,7 @@ import ( "cloud.google.com/go/spanner" database "cloud.google.com/go/spanner/admin/database/apiv1" "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" + "github.com/google/go-cmp/cmp" "github.com/googleapis/go-gorm-spanner/testutil" "github.com/shopspring/decimal" "gorm.io/datatypes" @@ -107,7 +108,30 @@ func TestAutoMigrate_CreateDataModel(t *testing.T) { if err != nil { log.Fatal(err) } - err = db.Migrator().AutoMigrate(&Singer{}, &Album{}, &Track{}, &Venue{}, &Concert{}) + tables := []interface{}{&Singer{}, &Album{}, &Track{}, &Venue{}, &Concert{}} + statements, err := db.Migrator().(SpannerMigrator).AutoMigrateDryRun(tables...) + if diff := cmp.Diff(statements, []spanner.Statement{ + {SQL: `CREATE SEQUENCE IF NOT EXISTS singers_seq OPTIONS (sequence_kind = "bit_reversed_positive")`, Params: map[string]any{}}, + {SQL: "CREATE TABLE `singers` (`id` INT64 DEFAULT (GET_NEXT_SEQUENCE_VALUE(Sequence singers_seq)),`created_at` TIMESTAMP,`updated_at` TIMESTAMP,`deleted_at` TIMESTAMP,`first_name` STRING(MAX),`last_name` STRING(MAX),`full_name` STRING(MAX) AS (concat(coalesce(first_name, ''),' ',last_name)) STORED,`active` BOOL) PRIMARY KEY (`id`)", Params: map[string]any{}}, + {SQL: "CREATE INDEX `idx_singers_deleted_at` ON `singers`(`deleted_at`)", Params: map[string]any{}}, + {SQL: `CREATE SEQUENCE IF NOT EXISTS albums_seq OPTIONS (sequence_kind = "bit_reversed_positive")`, Params: map[string]any{}}, + {SQL: "CREATE TABLE `albums` (`id` INT64 DEFAULT (GET_NEXT_SEQUENCE_VALUE(Sequence albums_seq)),`created_at` TIMESTAMP,`updated_at` TIMESTAMP,`deleted_at` TIMESTAMP,`title` STRING(MAX),`marketing_budget` BOOL,`release_date` date,`cover_picture` BYTES(MAX),`singer_id` INT64,CONSTRAINT `fk_singers_albums` FOREIGN KEY (`singer_id`) REFERENCES `singers`(`id`)) PRIMARY KEY (`id`)", Params: map[string]any{}}, + {SQL: "CREATE INDEX `idx_albums_deleted_at` ON `albums`(`deleted_at`)", Params: map[string]any{}}, + {SQL: `CREATE SEQUENCE IF NOT EXISTS tracks_seq OPTIONS (sequence_kind = "bit_reversed_positive")`, Params: map[string]any{}}, + {SQL: "CREATE TABLE `tracks` (`id` INT64 DEFAULT (GET_NEXT_SEQUENCE_VALUE(Sequence tracks_seq)),`created_at` TIMESTAMP,`updated_at` TIMESTAMP,`deleted_at` TIMESTAMP,`track_number` INT64,`title` STRING(MAX),`sample_rate` FLOAT64,`album_id` INT64,CONSTRAINT `fk_albums_tracks` FOREIGN KEY (`album_id`) REFERENCES `albums`(`id`)) PRIMARY KEY (`id`)", Params: map[string]any{}}, + {SQL: "CREATE INDEX `idx_tracks_deleted_at` ON `tracks`(`deleted_at`)", Params: map[string]any{}}, + {SQL: `CREATE SEQUENCE IF NOT EXISTS venues_seq OPTIONS (sequence_kind = "bit_reversed_positive")`, Params: map[string]any{}}, + {SQL: "CREATE TABLE `venues` (`id` INT64 DEFAULT (GET_NEXT_SEQUENCE_VALUE(Sequence venues_seq)),`created_at` TIMESTAMP,`updated_at` TIMESTAMP,`deleted_at` TIMESTAMP,`name` STRING(MAX),`description` JSON) PRIMARY KEY (`id`)", Params: map[string]any{}}, + {SQL: "CREATE INDEX `idx_venues_deleted_at` ON `venues`(`deleted_at`)", Params: map[string]any{}}, + {SQL: `CREATE SEQUENCE IF NOT EXISTS concerts_seq OPTIONS (sequence_kind = "bit_reversed_positive")`, Params: map[string]any{}}, + {SQL: "CREATE TABLE `concerts` (`id` INT64 DEFAULT (GET_NEXT_SEQUENCE_VALUE(Sequence concerts_seq)),`created_at` TIMESTAMP,`updated_at` TIMESTAMP,`deleted_at` TIMESTAMP,`name` STRING(MAX),`venue_id` INT64,`singer_id` INT64,`start_time` TIMESTAMP,`end_time` TIMESTAMP,CONSTRAINT `fk_singers_concerts` FOREIGN KEY (`singer_id`) REFERENCES `singers`(`id`),CONSTRAINT `fk_venues_concerts` FOREIGN KEY (`venue_id`) REFERENCES `venues`(`id`)) PRIMARY KEY (`id`)", Params: map[string]any{}}, + {SQL: "CREATE INDEX `idx_concerts_time` ON `concerts`(`start_time`,`end_time`)", Params: map[string]any{}}, + {SQL: "CREATE INDEX `idx_concerts_deleted_at` ON `concerts`(`deleted_at`)", Params: map[string]any{}}, + }, cmp.AllowUnexported(spanner.Statement{})); diff != "" { + t.Errorf("auto-migrate statements mismatch: %v", diff) + } + + err = db.Migrator().AutoMigrate(tables...) if err != nil { t.Fatal(err) }