Skip to content

Commit

Permalink
Merge pull request #1 from brunotm/no-transaction
Browse files Browse the repository at this point in the history
migrate: Add the ability to disable transactions for single statement migrations
  • Loading branch information
brunotm authored Oct 17, 2021
2 parents 29b03bb + be13cf7 commit 356bd07
Show file tree
Hide file tree
Showing 5 changed files with 218 additions and 96 deletions.
149 changes: 71 additions & 78 deletions migrate/migrate.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package migrate

import (
"bufio"
"context"
"database/sql"
"fmt"
Expand All @@ -14,11 +13,42 @@ import (
"time"
)

var (
// StdLog is the log.Printf function from the standard library
StdLog = log.Printf

// 0001_initial_schema.apply.sql
// 0001_initial_schema.discard.sql
migrationRegexp = regexp.MustCompile(`(\d+)_(\w+)\.(apply|discard)\.sql`)
options = &sql.TxOptions{Isolation: sql.LevelSerializable}

versionQuery = "SELECT version, date, name FROM migrations ORDER BY date DESC LIMIT 1"

migration0 = &Migration{
Version: 0,
Name: "create_migrations_table",
Apply: Statements{
NoTx: false,
Statements: []string{
`CREATE TABLE IF NOT EXISTS migrations (date timestamp NOT NULL, version bigint NOT NULL, name varchar(512) NOT NULL, PRIMARY KEY (date,version))`},
},
Discard: Statements{
NoTx: false,
Statements: []string{`DROP TABLE IF EXISTS migrations CASCADE`},
},
}
)

// Executor executes statements in a database
type Executor interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}

// Logger function signature
type Logger func(s string, args ...interface{})

// StdLog is the log.Printf function from the standard library
var StdLog = log.Printf
// nopLogger does notting
func nopLogger(_ string, _ ...interface{}) {}

// Migrate manages database migrations
type Migrate struct {
Expand All @@ -28,6 +58,27 @@ type Migrate struct {
migrations map[int64]*Migration
}

// Migration represents a database migration apply and discard statements
type Migration struct {
Version int64
Name string
Apply Statements
Discard Statements
}

// Statements are set of SQL statements that either apply or discard a migration
type Statements struct {
NoTx bool
Statements []string
}

// Version represents a migration version and its metadata
type Version struct {
Version int64
Date time.Time
Name string
}

// New creates a new Migrate with the given database and versions.
//
// If the provided logger function is not `nil` additional information will be logged during the
Expand Down Expand Up @@ -118,12 +169,12 @@ func NewWithFiles(db *sql.DB, files fs.FS, logger Logger) (m *Migrate, err error

switch match[3] {
case "apply":
mig.Apply = string(source)
mig.Apply, err = parseStatement(source)
case "discard":
mig.Discard = string(source)
mig.Discard, err = parseStatement(source)
}

return nil
return err
})

if err != nil {
Expand Down Expand Up @@ -246,70 +297,48 @@ func (m *Migrate) apply(ctx context.Context, mig *Migration, discard bool) (err
}
}

var stmt string
var raw string

var statements Statements
switch discard {
case false:
if mig.Version != current.Version+1 {
return fmt.Errorf(
"migrate: wrong sequence number, current: %d, proposed: %d, discard: %t",
current.Version, mig.Version, discard)
}
raw = mig.Apply
statements = mig.Apply

case true:
if mig.Version != current.Version {
return fmt.Errorf(
"migrate: wrong sequence number, current: %d, proposed: %d, discard: %t",
current.Version, mig.Version, discard)
}
raw = mig.Discard
}
statements = mig.Discard

if raw == "" {
return nil
}

scanner := bufio.NewScanner(strings.NewReader(raw))
for scanner.Scan() {
line := scanner.Text()
for x := 0; x < len(statements.Statements); x++ {
m.logger("migrate: %s, discard: %t, transaction: %t, statement: %s", mig.Name, discard, !statements.NoTx, statements.Statements[x])

if strings.HasPrefix(line, "--") {
continue
}

if line[len(line)-1] == ';' {
if stmt != "" {
stmt += " "
switch statements.NoTx {
case false:
if _, err := tx.ExecContext(ctx, statements.Statements[x]); err != nil {
return err
}
stmt += line[:len(line)-1]

m.logger("migrate: %s, discard: %t, statement: %s", mig.Name, discard, stmt)
if _, err := tx.ExecContext(ctx, stmt); err != nil {
case true:
if _, err := m.db.ExecContext(ctx, statements.Statements[x]); err != nil {
return err
}

stmt = ""
continue
}

if stmt != "" {
stmt += " "
}
stmt += line
}

if stmt != "" {
m.logger("migrate: %s, discard: %t, statement: %s", mig.Name, discard, stmt)
if _, err := tx.ExecContext(ctx, stmt); err != nil {
return err
}
}

// set the current version after applying the migration
mig = m.migrations[mig.Version]
if discard {
mig = m.migrations[mig.Version-1]
}

if mig != nil {
if err = m.set(ctx, tx, mig); err != nil {
return err
Expand All @@ -318,39 +347,3 @@ func (m *Migrate) apply(ctx context.Context, mig *Migration, discard bool) (err

return tx.Commit()
}

func nopLogger(_ string, _ ...interface{}) {}

type Migration struct {
Version int64
Name string
Apply string
Discard string
}

type Version struct {
Version int64
Date time.Time
Name string
}

var (
// 0001_initial_schema.apply.sql
// 0001_initial_schema.discard.sql
migrationRegexp = regexp.MustCompile(`(\d+)_(\w+)\.(apply|discard)\.sql`)
options = &sql.TxOptions{Isolation: sql.LevelSerializable}

versionQuery = "SELECT version, date, name FROM migrations ORDER BY date DESC LIMIT 1"

migration0 = &Migration{
Version: 0,
Name: "create_migrations_table",
Apply: `CREATE TABLE IF NOT EXISTS migrations (
date timestamp NOT NULL,
version bigint NOT NULL,
name varchar(512) NOT NULL,
PRIMARY KEY (date,version)
)`,
Discard: `DROP TABLE IF EXISTS migrations CASCADE`,
}
)
10 changes: 5 additions & 5 deletions migrate/migrate_down_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestMigrationDown(t *testing.T) {
sqlmock.NewRows([]string{"date", "version", "name"}).
AddRow(migration4.Version, time.Now(), migration4.Name),
)
mock.ExpectExec(migration4.Discard).WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(migration4.Discard.Statements[0]).WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(`INSERT INTO migrations(version, date, name) values(3, NOW(), 'roles_table')`).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
Expand All @@ -38,7 +38,7 @@ func TestMigrationDown(t *testing.T) {
sqlmock.NewRows([]string{"date", "version", "name"}).
AddRow(migration3.Version, time.Now(), migration3.Name),
)
mock.ExpectExec(migration3.Discard).WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(migration3.Discard.Statements[0]).WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(`INSERT INTO migrations(version, date, name) values(2, NOW(), 'users_email_index')`).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
Expand All @@ -48,7 +48,7 @@ func TestMigrationDown(t *testing.T) {
sqlmock.NewRows([]string{"date", "version", "name"}).
AddRow(migration2.Version, time.Now(), migration2.Name),
)
mock.ExpectExec(migration2.Discard).WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(migration2.Discard.Statements[0]).WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(`INSERT INTO migrations(version, date, name) values(1, NOW(), 'users_table')`).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
Expand All @@ -58,7 +58,7 @@ func TestMigrationDown(t *testing.T) {
sqlmock.NewRows([]string{"date", "version", "name"}).
AddRow(migration1.Version, time.Now(), migration1.Name),
)
mock.ExpectExec(migration1.Discard).WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(migration1.Discard.Statements[0]).WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(`INSERT INTO migrations(version, date, name) values(0, NOW(), 'create_migrations_table')`).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
Expand All @@ -68,7 +68,7 @@ func TestMigrationDown(t *testing.T) {
sqlmock.NewRows([]string{"date", "version", "name"}).
AddRow(migration0.Version, time.Now(), migration0.Name),
)
mock.ExpectExec(migration0.Discard).WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(migration0.Discard.Statements[0]).WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()

m, err := New(mdb, StdLog, migrations)
Expand Down
43 changes: 30 additions & 13 deletions migrate/migrate_up_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func TestMigrationUp(t *testing.T) {
mock.ExpectQuery(versionQuery).WillReturnError(fmt.Errorf("relation does not exist"))
mock.ExpectRollback()
mock.ExpectBegin()
mock.ExpectExec(migration0.Apply).WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(migration0.Apply.Statements[0]).WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(`INSERT INTO migrations(version, date, name) values(0, NOW(), 'create_migrations_table')`).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()
Expand All @@ -37,7 +37,7 @@ func TestMigrationUp(t *testing.T) {
sqlmock.NewRows([]string{"date", "version", "name"}).
AddRow(migration0.Version, time.Now(), migration0.Name),
)
mock.ExpectExec(migration1.Apply).
mock.ExpectExec(migration1.Apply.Statements[0]).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(`INSERT INTO migrations(version, date, name) values(1, NOW(), 'users_table')`).
WillReturnResult(sqlmock.NewResult(0, 1))
Expand All @@ -49,7 +49,7 @@ func TestMigrationUp(t *testing.T) {
sqlmock.NewRows([]string{"date", "version", "name"}).
AddRow(migration1.Version, time.Now(), migration1.Name),
)
mock.ExpectExec(migration2.Apply).
mock.ExpectExec(migration2.Apply.Statements[0]).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(`INSERT INTO migrations(version, date, name) values(2, NOW(), 'users_email_index')`).
WillReturnResult(sqlmock.NewResult(0, 1))
Expand All @@ -61,7 +61,7 @@ func TestMigrationUp(t *testing.T) {
sqlmock.NewRows([]string{"date", "version", "name"}).
AddRow(migration2.Version, time.Now(), migration2.Name),
)
mock.ExpectExec(migration3.Apply).
mock.ExpectExec(migration3.Apply.Statements[0]).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(`INSERT INTO migrations(version, date, name) values(3, NOW(), 'roles_table')`).
WillReturnResult(sqlmock.NewResult(0, 1))
Expand All @@ -73,7 +73,7 @@ func TestMigrationUp(t *testing.T) {
sqlmock.NewRows([]string{"date", "version", "name"}).
AddRow(migration3.Version, time.Now(), migration3.Name),
)
mock.ExpectExec(migration4.Apply).
mock.ExpectExec(migration4.Apply.Statements[0]).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(`INSERT INTO migrations(version, date, name) values(4, NOW(), 'user_roles_fk')`).
WillReturnResult(sqlmock.NewResult(0, 1))
Expand All @@ -99,25 +99,42 @@ var (
migration1 = &Migration{
Version: 1,
Name: "users_table",
Apply: "CREATE TABLE IF NOT EXISTS users(id text, name text, email text, role text, PRIMARY KEY (id))",
Discard: "DROP TABLE IF EXISTS users CASCADE",
Apply: Statements{
NoTx: true,
Statements: []string{"CREATE TABLE IF NOT EXISTS users(id text, name text, email text, role text, PRIMARY KEY (id))"},
},
Discard: Statements{
Statements: []string{"DROP TABLE IF EXISTS users CASCADE"},
},
}
migration2 = &Migration{
Version: 2,
Name: "users_email_index",
Apply: "CREATE INDEX IF NOT EXISTS ix_users_email ON users (email)",
Discard: "DROP INDEX IF EXISTS ix_users_email CASCADE",
Apply: Statements{
Statements: []string{"CREATE INDEX IF NOT EXISTS ix_users_email ON users (email)"},
},
Discard: Statements{
Statements: []string{"DROP INDEX IF EXISTS ix_users_email CASCADE"},
},
}
migration3 = &Migration{
Version: 3,
Name: "roles_table",
Apply: "CREATE TABLE IF NOT EXISTS roles(id text, name text, properties jsonb NOT NULL DEFAULT '{}'::jsonb, PRIMARY KEY (id))",
Discard: "DROP TABLE IF EXISTS roles CASCADE",
Apply: Statements{
Statements: []string{"CREATE TABLE IF NOT EXISTS roles(id text, name text, properties jsonb NOT NULL DEFAULT '{}'::jsonb, PRIMARY KEY (id))"},
},
Discard: Statements{
Statements: []string{"DROP TABLE IF EXISTS roles CASCADE"},
},
}
migration4 = &Migration{
Version: 4,
Name: "user_roles_fk",
Apply: "ALTER TABLE users ADD CONSTRAINT roles_fk FOREIGN KEY (role) REFERENCES roles (id)",
Discard: "ALTER TABLE users DROP CONSTRAINT roles_fk CASCADE",
Apply: Statements{
Statements: []string{"ALTER TABLE users ADD CONSTRAINT roles_fk FOREIGN KEY (role) REFERENCES roles (id)"},
},
Discard: Statements{
Statements: []string{"ALTER TABLE users DROP CONSTRAINT roles_fk CASCADE"},
},
}
)
Loading

0 comments on commit 356bd07

Please sign in to comment.