Skip to content

Commit

Permalink
feat: support dry-run mode for auto-migrate
Browse files Browse the repository at this point in the history
Adds a dry-run mode for AutoMigrate for Spanner databases. The output
of a dry-run can be inspected and manually modified to include specific
Spanner features, such as interleaved tables or row deletion policies.
  • Loading branch information
olavloite committed Nov 18, 2024
1 parent fdc3305 commit 48f6233
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 10 deletions.
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,27 @@ Cloud Spanner supports the following data types in combination with `gorm`.
| bytes | []byte |


## AutoMigrate Dry Run
The Spanner `gorm` dialect supports dry-runs for auto-migration. Use this to get the
DDL statements that would be generated and executed by auto-migration. You can manually
verify and modify these statements to optimize your data model.

Example:

```go
tables := []interface{}{&singer{}, &album{}}

// Unwrap the underlying SpannerMigrator interface. This interface supports
// the `AutoMigrateDryRun` method, which does not actually execute the
// generated statements, and instead just returns these as an array.
m := db.Migrator()
migrator, ok := m.(spannergorm.SpannerMigrator)
if !ok {
return fmt.Errorf("unexpected migrator type: %v", m)
}
statements, err := migrator.AutoMigrateDryRun(tables...)
```

## Limitations
The Cloud Spanner `gorm` dialect has the following known limitations:

Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ require (
cloud.google.com/go/longrunning v0.6.2
cloud.google.com/go/spanner v1.73.0
github.com/golang/protobuf v1.5.4
github.com/google/go-cmp v0.6.0
github.com/googleapis/go-sql-spanner v1.8.0
github.com/shopspring/decimal v1.4.0
github.com/stretchr/testify v1.9.0
Expand Down
68 changes: 59 additions & 9 deletions migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ package gorm
import (
"database/sql"
"fmt"
"slices"
"sort"
"strings"

"cloud.google.com/go/spanner"
spannerdriver "github.com/googleapis/go-sql-spanner"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/migrator"
Expand All @@ -33,6 +36,7 @@ const (
type SpannerMigrator interface {
gorm.Migrator

AutoMigrateDryRun(values ...interface{}) ([]spanner.Statement, error)
StartBatchDDL() error
RunBatch() error
AbortBatch() error
Expand All @@ -41,6 +45,7 @@ type SpannerMigrator interface {
type spannerMigrator struct {
migrator.Migrator
Dialector
dryRun bool
}

type spannerColumnType struct {
Expand All @@ -60,21 +65,48 @@ func (m spannerMigrator) CurrentDatabase() (name string) {
return ""
}

func (m spannerMigrator) AutoMigrateDryRun(values ...interface{}) ([]spanner.Statement, error) {
return m.autoMigrate( /* dryRun = */ true, values...)
}

func (m spannerMigrator) AutoMigrate(values ...interface{}) error {
if !m.Dialector.Config.DisableAutoMigrateBatching {
_, err := m.autoMigrate( /* dryRun = */ false, values...)
return err
}

func (m spannerMigrator) autoMigrate(dryRun bool, values ...interface{}) ([]spanner.Statement, error) {
if dryRun || !m.Dialector.Config.DisableAutoMigrateBatching {
if err := m.StartBatchDDL(); err != nil {
return err
return nil, err
}
}
err := m.Migrator.AutoMigrate(values...)
if err == nil {
if m.Dialector.Config.DisableAutoMigrateBatching {
return nil
if !dryRun && m.Dialector.Config.DisableAutoMigrateBatching {
return nil, nil
} else if dryRun {
connPool := m.DB.Statement.ConnPool
conn, ok := connPool.(*sql.Conn)
if !ok {
return nil, fmt.Errorf("unexpected ConnPool type")
}
var statements []spanner.Statement
if err := conn.Raw(func(driverConn any) error {
spannerConn, ok := driverConn.(spannerdriver.SpannerConn)
if !ok {
return fmt.Errorf("dry-run is only supported for Spanner")
}
statements = spannerConn.GetBatchedStatements()
return nil
}); err != nil {
return nil, err
}
return statements, m.AbortBatch()
} else {
return m.RunBatch()
return nil, m.RunBatch()
}
}
return fmt.Errorf("unexpected return value type: %v", err)
return nil, fmt.Errorf("unexpected return value type: %v", err)
}

func (m spannerMigrator) StartBatchDDL() error {
Expand Down Expand Up @@ -132,6 +164,9 @@ func (m spannerMigrator) CreateTable(values ...interface{}) error {
return err
}
f.DefaultValue = "GET_NEXT_SEQUENCE_VALUE(Sequence " + sequence + ")"
// Reset the default value to nothing after finishing migration.
//goland:noinspection GoDeferInLoop
defer func() { f.DefaultValue = "" }()
}
}
for _, dbName := range stmt.Schema.DBNames {
Expand All @@ -145,17 +180,32 @@ func (m spannerMigrator) CreateTable(values ...interface{}) error {
}

// Indexes should always be created after the table, as Spanner does not support
// inline index creation.
for _, idx := range stmt.Schema.ParseIndexes() {
// inline index creation. Iterate over the indexes in a fixed order to make the
// script outcome deterministic.
indexes := stmt.Schema.ParseIndexes()
indexNames := make([]string, 0, len(indexes))
for name := range indexes {
indexNames = append(indexNames, name)
}
slices.Sort(indexNames)
for _, name := range indexNames {
idx := indexes[name]
defer func(value interface{}, name string) {
if errr == nil {
errr = tx.Migrator().CreateIndex(value, name)
}
}(value, idx.Name)
}

for _, rel := range stmt.Schema.Relationships.Relations {
// Iterator over the relationships in a fixed order.
relationshipKeys := make([]string, 0, len(stmt.Schema.Relationships.Relations))
for key := range stmt.Schema.Relationships.Relations {
relationshipKeys = append(relationshipKeys, key)
}
slices.Sort(relationshipKeys)
for _, key := range relationshipKeys {
if !m.DB.DisableForeignKeyConstraintWhenMigrating {
rel := stmt.Schema.Relationships.Relations[key]
if constraint := rel.ParseConstraint(); constraint != nil {
if constraint.Schema == stmt.Schema {
sql, vars := buildConstraint(constraint)
Expand Down
26 changes: 25 additions & 1 deletion migrator_emulator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"cloud.google.com/go/spanner"
database "cloud.google.com/go/spanner/admin/database/apiv1"
"cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/go-gorm-spanner/testutil"
"github.com/shopspring/decimal"
"gorm.io/datatypes"
Expand Down Expand Up @@ -107,7 +108,30 @@ func TestAutoMigrate_CreateDataModel(t *testing.T) {
if err != nil {
log.Fatal(err)
}
err = db.Migrator().AutoMigrate(&Singer{}, &Album{}, &Track{}, &Venue{}, &Concert{})
tables := []interface{}{&Singer{}, &Album{}, &Track{}, &Venue{}, &Concert{}}
statements, err := db.Migrator().(SpannerMigrator).AutoMigrateDryRun(tables...)
if diff := cmp.Diff(statements, []spanner.Statement{
{SQL: `CREATE SEQUENCE IF NOT EXISTS singers_seq OPTIONS (sequence_kind = "bit_reversed_positive")`, Params: map[string]any{}},
{SQL: "CREATE TABLE `singers` (`id` INT64 DEFAULT (GET_NEXT_SEQUENCE_VALUE(Sequence singers_seq)),`created_at` TIMESTAMP,`updated_at` TIMESTAMP,`deleted_at` TIMESTAMP,`first_name` STRING(MAX),`last_name` STRING(MAX),`full_name` STRING(MAX) AS (concat(coalesce(first_name, ''),' ',last_name)) STORED,`active` BOOL) PRIMARY KEY (`id`)", Params: map[string]any{}},
{SQL: "CREATE INDEX `idx_singers_deleted_at` ON `singers`(`deleted_at`)", Params: map[string]any{}},
{SQL: `CREATE SEQUENCE IF NOT EXISTS albums_seq OPTIONS (sequence_kind = "bit_reversed_positive")`, Params: map[string]any{}},
{SQL: "CREATE TABLE `albums` (`id` INT64 DEFAULT (GET_NEXT_SEQUENCE_VALUE(Sequence albums_seq)),`created_at` TIMESTAMP,`updated_at` TIMESTAMP,`deleted_at` TIMESTAMP,`title` STRING(MAX),`marketing_budget` BOOL,`release_date` date,`cover_picture` BYTES(MAX),`singer_id` INT64,CONSTRAINT `fk_singers_albums` FOREIGN KEY (`singer_id`) REFERENCES `singers`(`id`)) PRIMARY KEY (`id`)", Params: map[string]any{}},
{SQL: "CREATE INDEX `idx_albums_deleted_at` ON `albums`(`deleted_at`)", Params: map[string]any{}},
{SQL: `CREATE SEQUENCE IF NOT EXISTS tracks_seq OPTIONS (sequence_kind = "bit_reversed_positive")`, Params: map[string]any{}},
{SQL: "CREATE TABLE `tracks` (`id` INT64 DEFAULT (GET_NEXT_SEQUENCE_VALUE(Sequence tracks_seq)),`created_at` TIMESTAMP,`updated_at` TIMESTAMP,`deleted_at` TIMESTAMP,`track_number` INT64,`title` STRING(MAX),`sample_rate` FLOAT64,`album_id` INT64,CONSTRAINT `fk_albums_tracks` FOREIGN KEY (`album_id`) REFERENCES `albums`(`id`)) PRIMARY KEY (`id`)", Params: map[string]any{}},
{SQL: "CREATE INDEX `idx_tracks_deleted_at` ON `tracks`(`deleted_at`)", Params: map[string]any{}},
{SQL: `CREATE SEQUENCE IF NOT EXISTS venues_seq OPTIONS (sequence_kind = "bit_reversed_positive")`, Params: map[string]any{}},
{SQL: "CREATE TABLE `venues` (`id` INT64 DEFAULT (GET_NEXT_SEQUENCE_VALUE(Sequence venues_seq)),`created_at` TIMESTAMP,`updated_at` TIMESTAMP,`deleted_at` TIMESTAMP,`name` STRING(MAX),`description` JSON) PRIMARY KEY (`id`)", Params: map[string]any{}},
{SQL: "CREATE INDEX `idx_venues_deleted_at` ON `venues`(`deleted_at`)", Params: map[string]any{}},
{SQL: `CREATE SEQUENCE IF NOT EXISTS concerts_seq OPTIONS (sequence_kind = "bit_reversed_positive")`, Params: map[string]any{}},
{SQL: "CREATE TABLE `concerts` (`id` INT64 DEFAULT (GET_NEXT_SEQUENCE_VALUE(Sequence concerts_seq)),`created_at` TIMESTAMP,`updated_at` TIMESTAMP,`deleted_at` TIMESTAMP,`name` STRING(MAX),`venue_id` INT64,`singer_id` INT64,`start_time` TIMESTAMP,`end_time` TIMESTAMP,CONSTRAINT `fk_singers_concerts` FOREIGN KEY (`singer_id`) REFERENCES `singers`(`id`),CONSTRAINT `fk_venues_concerts` FOREIGN KEY (`venue_id`) REFERENCES `venues`(`id`)) PRIMARY KEY (`id`)", Params: map[string]any{}},
{SQL: "CREATE INDEX `idx_concerts_time` ON `concerts`(`start_time`,`end_time`)", Params: map[string]any{}},
{SQL: "CREATE INDEX `idx_concerts_deleted_at` ON `concerts`(`deleted_at`)", Params: map[string]any{}},
}, cmp.AllowUnexported(spanner.Statement{})); diff != "" {
t.Errorf("auto-migrate statements mismatch: %v", diff)
}

err = db.Migrator().AutoMigrate(tables...)
if err != nil {
t.Fatal(err)
}
Expand Down

0 comments on commit 48f6233

Please sign in to comment.