Skip to content

Commit

Permalink
fix non-tx Migrate using a txn + commit for each migration
Browse files Browse the repository at this point in the history
  • Loading branch information
bgentry committed Sep 20, 2024
1 parent e3b3b55 commit 30d0691
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions rivermigrate/river_migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,16 +297,15 @@ func (m *Migrator[TTx]) GetVersion(version int) (Migration, error) {
// // handle error
// }
func (m *Migrator[TTx]) Migrate(ctx context.Context, direction Direction, opts *MigrateOpts) (*MigrateResult, error) {
return dbutil.WithTxV(ctx, m.driver.GetExecutor(), func(ctx context.Context, exec riverdriver.ExecutorTx) (*MigrateResult, error) {
switch direction {
case DirectionDown:
return m.migrateDown(ctx, exec, direction, opts)
case DirectionUp:
return m.migrateUp(ctx, exec, direction, opts)
}
exec := m.driver.GetExecutor()
switch direction {
case DirectionDown:
return m.migrateDown(ctx, exec, direction, opts)
case DirectionUp:
return m.migrateUp(ctx, exec, direction, opts)
}

panic("invalid direction: " + direction)
})
panic("invalid direction: " + direction)
}

// Migrate migrates the database in the given direction (up or down). The opts
Expand Down Expand Up @@ -560,10 +559,22 @@ func (m *Migrator[TTx]) applyMigrations(ctx context.Context, exec riverdriver.Ex

if !opts.DryRun {
start := time.Now()
_, err := exec.Exec(ctx, sql)

// Similar to ActiveRecord migrations, we wrap each individual migration
// in its own transaction. Without this, certain migrations that require
// a commit on a preexisting operation (such as adding an enum value to be
// used in an immutable function) cannot succeed.
err := dbutil.WithTx(ctx, exec, func(ctx context.Context, exec riverdriver.ExecutorTx) error {
_, err := exec.Exec(ctx, sql)
if err != nil {
return fmt.Errorf("error applying version %03d [%s]: %w",
versionBundle.Version, strings.ToUpper(string(direction)), err)
}
return nil
})

if err != nil {
return nil, fmt.Errorf("error applying version %03d [%s]: %w",
versionBundle.Version, strings.ToUpper(string(direction)), err)
return nil, err
}
duration = time.Since(start)
}
Expand Down

0 comments on commit 30d0691

Please sign in to comment.