Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support dry-run mode for auto-migrate #124

Merged
merged 1 commit into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,27 @@ Cloud Spanner supports the following data types in combination with `gorm`.
| bytes | []byte |


## AutoMigrate Dry Run
The Spanner `gorm` dialect supports dry-runs for auto-migration. Use this to get the
DDL statements that would be generated and executed by auto-migration. You can manually
verify and modify these statements to optimize your data model.

Example:

```go
tables := []interface{}{&singer{}, &album{}}

// Unwrap the underlying SpannerMigrator interface. This interface supports
// the `AutoMigrateDryRun` method, which does not actually execute the
// generated statements, and instead just returns these as an array.
m := db.Migrator()
migrator, ok := m.(spannergorm.SpannerMigrator)
if !ok {
return fmt.Errorf("unexpected migrator type: %v", m)
}
statements, err := migrator.AutoMigrateDryRun(tables...)
```

## Limitations
The Cloud Spanner `gorm` dialect has the following known limitations:

Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ require (
cloud.google.com/go/longrunning v0.6.2
cloud.google.com/go/spanner v1.73.0
github.com/golang/protobuf v1.5.4
github.com/google/go-cmp v0.6.0
github.com/googleapis/go-sql-spanner v1.8.0
github.com/shopspring/decimal v1.4.0
github.com/stretchr/testify v1.9.0
Expand Down
68 changes: 59 additions & 9 deletions migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ package gorm
import (
"database/sql"
"fmt"
"slices"
"sort"
"strings"

"cloud.google.com/go/spanner"
spannerdriver "github.com/googleapis/go-sql-spanner"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/migrator"
Expand All @@ -33,6 +36,7 @@ const (
type SpannerMigrator interface {
gorm.Migrator

AutoMigrateDryRun(values ...interface{}) ([]spanner.Statement, error)
StartBatchDDL() error
RunBatch() error
AbortBatch() error
Expand All @@ -41,6 +45,7 @@ type SpannerMigrator interface {
type spannerMigrator struct {
migrator.Migrator
Dialector
dryRun bool
}

type spannerColumnType struct {
Expand All @@ -60,21 +65,48 @@ func (m spannerMigrator) CurrentDatabase() (name string) {
return ""
}

func (m spannerMigrator) AutoMigrateDryRun(values ...interface{}) ([]spanner.Statement, error) {
return m.autoMigrate( /* dryRun = */ true, values...)
}

func (m spannerMigrator) AutoMigrate(values ...interface{}) error {
if !m.Dialector.Config.DisableAutoMigrateBatching {
_, err := m.autoMigrate( /* dryRun = */ false, values...)
return err
}

func (m spannerMigrator) autoMigrate(dryRun bool, values ...interface{}) ([]spanner.Statement, error) {
if dryRun || !m.Dialector.Config.DisableAutoMigrateBatching {
if err := m.StartBatchDDL(); err != nil {
return err
return nil, err
}
}
err := m.Migrator.AutoMigrate(values...)
if err == nil {
if m.Dialector.Config.DisableAutoMigrateBatching {
return nil
if !dryRun && m.Dialector.Config.DisableAutoMigrateBatching {
return nil, nil
} else if dryRun {
connPool := m.DB.Statement.ConnPool
conn, ok := connPool.(*sql.Conn)
if !ok {
return nil, fmt.Errorf("unexpected ConnPool type")
}
var statements []spanner.Statement
if err := conn.Raw(func(driverConn any) error {
spannerConn, ok := driverConn.(spannerdriver.SpannerConn)
if !ok {
return fmt.Errorf("dry-run is only supported for Spanner")
}
statements = spannerConn.GetBatchedStatements()
return nil
}); err != nil {
return nil, err
}
return statements, m.AbortBatch()
} else {
return m.RunBatch()
return nil, m.RunBatch()
}
}
return fmt.Errorf("unexpected return value type: %v", err)
return nil, fmt.Errorf("unexpected return value type: %v", err)
}

func (m spannerMigrator) StartBatchDDL() error {
Expand Down Expand Up @@ -132,6 +164,9 @@ func (m spannerMigrator) CreateTable(values ...interface{}) error {
return err
}
f.DefaultValue = "GET_NEXT_SEQUENCE_VALUE(Sequence " + sequence + ")"
// Reset the default value to nothing after finishing migration.
//goland:noinspection GoDeferInLoop
defer func() { f.DefaultValue = "" }()
}
}
for _, dbName := range stmt.Schema.DBNames {
Expand All @@ -145,17 +180,32 @@ func (m spannerMigrator) CreateTable(values ...interface{}) error {
}

// Indexes should always be created after the table, as Spanner does not support
// inline index creation.
for _, idx := range stmt.Schema.ParseIndexes() {
// inline index creation. Iterate over the indexes in a fixed order to make the
// script outcome deterministic.
indexes := stmt.Schema.ParseIndexes()
indexNames := make([]string, 0, len(indexes))
for name := range indexes {
indexNames = append(indexNames, name)
}
slices.Sort(indexNames)
for _, name := range indexNames {
idx := indexes[name]
defer func(value interface{}, name string) {
if errr == nil {
errr = tx.Migrator().CreateIndex(value, name)
}
}(value, idx.Name)
}

for _, rel := range stmt.Schema.Relationships.Relations {
// Iterator over the relationships in a fixed order.
relationshipKeys := make([]string, 0, len(stmt.Schema.Relationships.Relations))
for key := range stmt.Schema.Relationships.Relations {
relationshipKeys = append(relationshipKeys, key)
}
slices.Sort(relationshipKeys)
for _, key := range relationshipKeys {
if !m.DB.DisableForeignKeyConstraintWhenMigrating {
rel := stmt.Schema.Relationships.Relations[key]
if constraint := rel.ParseConstraint(); constraint != nil {
if constraint.Schema == stmt.Schema {
sql, vars := buildConstraint(constraint)
Expand Down
26 changes: 25 additions & 1 deletion migrator_emulator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"cloud.google.com/go/spanner"
database "cloud.google.com/go/spanner/admin/database/apiv1"
"cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/go-gorm-spanner/testutil"
"github.com/shopspring/decimal"
"gorm.io/datatypes"
Expand Down Expand Up @@ -107,7 +108,30 @@ func TestAutoMigrate_CreateDataModel(t *testing.T) {
if err != nil {
log.Fatal(err)
}
err = db.Migrator().AutoMigrate(&Singer{}, &Album{}, &Track{}, &Venue{}, &Concert{})
tables := []interface{}{&Singer{}, &Album{}, &Track{}, &Venue{}, &Concert{}}
statements, err := db.Migrator().(SpannerMigrator).AutoMigrateDryRun(tables...)
if diff := cmp.Diff(statements, []spanner.Statement{
{SQL: `CREATE SEQUENCE IF NOT EXISTS singers_seq OPTIONS (sequence_kind = "bit_reversed_positive")`, Params: map[string]any{}},
{SQL: "CREATE TABLE `singers` (`id` INT64 DEFAULT (GET_NEXT_SEQUENCE_VALUE(Sequence singers_seq)),`created_at` TIMESTAMP,`updated_at` TIMESTAMP,`deleted_at` TIMESTAMP,`first_name` STRING(MAX),`last_name` STRING(MAX),`full_name` STRING(MAX) AS (concat(coalesce(first_name, ''),' ',last_name)) STORED,`active` BOOL) PRIMARY KEY (`id`)", Params: map[string]any{}},
{SQL: "CREATE INDEX `idx_singers_deleted_at` ON `singers`(`deleted_at`)", Params: map[string]any{}},
{SQL: `CREATE SEQUENCE IF NOT EXISTS albums_seq OPTIONS (sequence_kind = "bit_reversed_positive")`, Params: map[string]any{}},
{SQL: "CREATE TABLE `albums` (`id` INT64 DEFAULT (GET_NEXT_SEQUENCE_VALUE(Sequence albums_seq)),`created_at` TIMESTAMP,`updated_at` TIMESTAMP,`deleted_at` TIMESTAMP,`title` STRING(MAX),`marketing_budget` BOOL,`release_date` date,`cover_picture` BYTES(MAX),`singer_id` INT64,CONSTRAINT `fk_singers_albums` FOREIGN KEY (`singer_id`) REFERENCES `singers`(`id`)) PRIMARY KEY (`id`)", Params: map[string]any{}},
{SQL: "CREATE INDEX `idx_albums_deleted_at` ON `albums`(`deleted_at`)", Params: map[string]any{}},
{SQL: `CREATE SEQUENCE IF NOT EXISTS tracks_seq OPTIONS (sequence_kind = "bit_reversed_positive")`, Params: map[string]any{}},
{SQL: "CREATE TABLE `tracks` (`id` INT64 DEFAULT (GET_NEXT_SEQUENCE_VALUE(Sequence tracks_seq)),`created_at` TIMESTAMP,`updated_at` TIMESTAMP,`deleted_at` TIMESTAMP,`track_number` INT64,`title` STRING(MAX),`sample_rate` FLOAT64,`album_id` INT64,CONSTRAINT `fk_albums_tracks` FOREIGN KEY (`album_id`) REFERENCES `albums`(`id`)) PRIMARY KEY (`id`)", Params: map[string]any{}},
{SQL: "CREATE INDEX `idx_tracks_deleted_at` ON `tracks`(`deleted_at`)", Params: map[string]any{}},
{SQL: `CREATE SEQUENCE IF NOT EXISTS venues_seq OPTIONS (sequence_kind = "bit_reversed_positive")`, Params: map[string]any{}},
{SQL: "CREATE TABLE `venues` (`id` INT64 DEFAULT (GET_NEXT_SEQUENCE_VALUE(Sequence venues_seq)),`created_at` TIMESTAMP,`updated_at` TIMESTAMP,`deleted_at` TIMESTAMP,`name` STRING(MAX),`description` JSON) PRIMARY KEY (`id`)", Params: map[string]any{}},
{SQL: "CREATE INDEX `idx_venues_deleted_at` ON `venues`(`deleted_at`)", Params: map[string]any{}},
{SQL: `CREATE SEQUENCE IF NOT EXISTS concerts_seq OPTIONS (sequence_kind = "bit_reversed_positive")`, Params: map[string]any{}},
{SQL: "CREATE TABLE `concerts` (`id` INT64 DEFAULT (GET_NEXT_SEQUENCE_VALUE(Sequence concerts_seq)),`created_at` TIMESTAMP,`updated_at` TIMESTAMP,`deleted_at` TIMESTAMP,`name` STRING(MAX),`venue_id` INT64,`singer_id` INT64,`start_time` TIMESTAMP,`end_time` TIMESTAMP,CONSTRAINT `fk_singers_concerts` FOREIGN KEY (`singer_id`) REFERENCES `singers`(`id`),CONSTRAINT `fk_venues_concerts` FOREIGN KEY (`venue_id`) REFERENCES `venues`(`id`)) PRIMARY KEY (`id`)", Params: map[string]any{}},
{SQL: "CREATE INDEX `idx_concerts_time` ON `concerts`(`start_time`,`end_time`)", Params: map[string]any{}},
{SQL: "CREATE INDEX `idx_concerts_deleted_at` ON `concerts`(`deleted_at`)", Params: map[string]any{}},
}, cmp.AllowUnexported(spanner.Statement{})); diff != "" {
t.Errorf("auto-migrate statements mismatch: %v", diff)
}

err = db.Migrator().AutoMigrate(tables...)
if err != nil {
t.Fatal(err)
}
Expand Down
Loading