From 6462553b5ae7a8c08a16316576c2bb7cea96ffd7 Mon Sep 17 00:00:00 2001 From: Bruno Moura Date: Tue, 12 Oct 2021 14:47:52 +0100 Subject: [PATCH 1/5] migrate: allow transactions to be disabled for single statement migrations --- migrate/migrate.go | 154 +++++++++++++++++------------------ migrate/migrate_down_test.go | 10 +-- migrate/migrate_up_test.go | 43 +++++++--- migrate/parse.go | 45 ++++++++++ 4 files changed, 155 insertions(+), 97 deletions(-) create mode 100644 migrate/parse.go diff --git a/migrate/migrate.go b/migrate/migrate.go index 85f0686..52e6cee 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -1,7 +1,6 @@ package migrate import ( - "bufio" "context" "database/sql" "fmt" @@ -14,11 +13,47 @@ import ( "time" ) +var ( + // StdLog is the log.Printf function from the standard library + StdLog = log.Printf + + // 0001_initial_schema.apply.sql + // 0001_initial_schema.discard.sql + migrationRegexp = regexp.MustCompile(`(\d+)_(\w+)\.(apply|discard)\.sql`) + options = &sql.TxOptions{Isolation: sql.LevelSerializable} + + versionQuery = "SELECT version, date, name FROM migrations ORDER BY date DESC LIMIT 1" + + migration0 = &Migration{ + Version: 0, + Name: "create_migrations_table", + Apply: Statements{ + NoTx: false, + Statements: []string{ + `CREATE TABLE IF NOT EXISTS migrations ( + date timestamp NOT NULL, + version bigint NOT NULL, + name varchar(512) NOT NULL, + PRIMARY KEY (date,version) + )`}, + }, + Discard: Statements{ + NoTx: false, + Statements: []string{`DROP TABLE IF EXISTS migrations CASCADE`}, + }, + } +) + +// Executor executes statements in a database +type Executor interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) +} + // Logger function signature type Logger func(s string, args ...interface{}) -// StdLog is the log.Printf function from the standard library -var StdLog = log.Printf +// nopLogger does notting +func nopLogger(_ string, _ ...interface{}) {} // Migrate manages database migrations type Migrate struct { @@ -28,6 +63,27 @@ type Migrate struct { migrations map[int64]*Migration } +// Migration represents a database migration apply and discard statements +type Migration struct { + Version int64 + Name string + Apply Statements + Discard Statements +} + +// Statements are set of SQL statements that either apply or discard a migration +type Statements struct { + NoTx bool + Statements []string +} + +// Version represents a migration version and its metadata +type Version struct { + Version int64 + Date time.Time + Name string +} + // New creates a new Migrate with the given database and versions. // // If the provided logger function is not `nil` additional information will be logged during the @@ -118,12 +174,12 @@ func NewWithFiles(db *sql.DB, files fs.FS, logger Logger) (m *Migrate, err error switch match[3] { case "apply": - mig.Apply = string(source) + mig.Apply, err = parseStatement(source) case "discard": - mig.Discard = string(source) + mig.Discard, err = parseStatement(source) } - return nil + return err }) if err != nil { @@ -246,9 +302,7 @@ func (m *Migrate) apply(ctx context.Context, mig *Migration, discard bool) (err } } - var stmt string - var raw string - + var statements Statements switch discard { case false: if mig.Version != current.Version+1 { @@ -256,60 +310,38 @@ func (m *Migrate) apply(ctx context.Context, mig *Migration, discard bool) (err "migrate: wrong sequence number, current: %d, proposed: %d, discard: %t", current.Version, mig.Version, discard) } - raw = mig.Apply + statements = mig.Apply + case true: if mig.Version != current.Version { return fmt.Errorf( "migrate: wrong sequence number, current: %d, proposed: %d, discard: %t", current.Version, mig.Version, discard) } - raw = mig.Discard - } + statements = mig.Discard - if raw == "" { - return nil } - scanner := bufio.NewScanner(strings.NewReader(raw)) - for scanner.Scan() { - line := scanner.Text() - - if strings.HasPrefix(line, "--") { - continue - } - - if line[len(line)-1] == ';' { - if stmt != "" { - stmt += " " + for x := 0; x < len(statements.Statements); x++ { + switch statements.NoTx { + case false: + if _, err := tx.ExecContext(ctx, statements.Statements[x]); err != nil { + return err } - stmt += line[:len(line)-1] - m.logger("migrate: %s, discard: %t, statement: %s", mig.Name, discard, stmt) - if _, err := tx.ExecContext(ctx, stmt); err != nil { + case true: + if _, err := m.db.ExecContext(ctx, statements.Statements[x]); err != nil { return err } - - stmt = "" - continue - } - - if stmt != "" { - stmt += " " - } - stmt += line - } - - if stmt != "" { - m.logger("migrate: %s, discard: %t, statement: %s", mig.Name, discard, stmt) - if _, err := tx.ExecContext(ctx, stmt); err != nil { - return err } } + // set the current version after applying the migration mig = m.migrations[mig.Version] if discard { mig = m.migrations[mig.Version-1] } + if mig != nil { if err = m.set(ctx, tx, mig); err != nil { return err @@ -318,39 +350,3 @@ func (m *Migrate) apply(ctx context.Context, mig *Migration, discard bool) (err return tx.Commit() } - -func nopLogger(_ string, _ ...interface{}) {} - -type Migration struct { - Version int64 - Name string - Apply string - Discard string -} - -type Version struct { - Version int64 - Date time.Time - Name string -} - -var ( - // 0001_initial_schema.apply.sql - // 0001_initial_schema.discard.sql - migrationRegexp = regexp.MustCompile(`(\d+)_(\w+)\.(apply|discard)\.sql`) - options = &sql.TxOptions{Isolation: sql.LevelSerializable} - - versionQuery = "SELECT version, date, name FROM migrations ORDER BY date DESC LIMIT 1" - - migration0 = &Migration{ - Version: 0, - Name: "create_migrations_table", - Apply: `CREATE TABLE IF NOT EXISTS migrations ( - date timestamp NOT NULL, - version bigint NOT NULL, - name varchar(512) NOT NULL, - PRIMARY KEY (date,version) - )`, - Discard: `DROP TABLE IF EXISTS migrations CASCADE`, - } -) diff --git a/migrate/migrate_down_test.go b/migrate/migrate_down_test.go index 0ff0fa0..ad3c29b 100644 --- a/migrate/migrate_down_test.go +++ b/migrate/migrate_down_test.go @@ -28,7 +28,7 @@ func TestMigrationDown(t *testing.T) { sqlmock.NewRows([]string{"date", "version", "name"}). AddRow(migration4.Version, time.Now(), migration4.Name), ) - mock.ExpectExec(migration4.Discard).WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec(migration4.Discard.Statements[0]).WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(`INSERT INTO migrations(version, date, name) values(3, NOW(), 'roles_table')`). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() @@ -38,7 +38,7 @@ func TestMigrationDown(t *testing.T) { sqlmock.NewRows([]string{"date", "version", "name"}). AddRow(migration3.Version, time.Now(), migration3.Name), ) - mock.ExpectExec(migration3.Discard).WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec(migration3.Discard.Statements[0]).WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(`INSERT INTO migrations(version, date, name) values(2, NOW(), 'users_email_index')`). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() @@ -48,7 +48,7 @@ func TestMigrationDown(t *testing.T) { sqlmock.NewRows([]string{"date", "version", "name"}). AddRow(migration2.Version, time.Now(), migration2.Name), ) - mock.ExpectExec(migration2.Discard).WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec(migration2.Discard.Statements[0]).WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(`INSERT INTO migrations(version, date, name) values(1, NOW(), 'users_table')`). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() @@ -58,7 +58,7 @@ func TestMigrationDown(t *testing.T) { sqlmock.NewRows([]string{"date", "version", "name"}). AddRow(migration1.Version, time.Now(), migration1.Name), ) - mock.ExpectExec(migration1.Discard).WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec(migration1.Discard.Statements[0]).WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(`INSERT INTO migrations(version, date, name) values(0, NOW(), 'create_migrations_table')`). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() @@ -68,7 +68,7 @@ func TestMigrationDown(t *testing.T) { sqlmock.NewRows([]string{"date", "version", "name"}). AddRow(migration0.Version, time.Now(), migration0.Name), ) - mock.ExpectExec(migration0.Discard).WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec(migration0.Discard.Statements[0]).WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() m, err := New(mdb, StdLog, migrations) diff --git a/migrate/migrate_up_test.go b/migrate/migrate_up_test.go index fad5353..56edbb8 100644 --- a/migrate/migrate_up_test.go +++ b/migrate/migrate_up_test.go @@ -26,7 +26,7 @@ func TestMigrationUp(t *testing.T) { mock.ExpectQuery(versionQuery).WillReturnError(fmt.Errorf("relation does not exist")) mock.ExpectRollback() mock.ExpectBegin() - mock.ExpectExec(migration0.Apply).WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec(migration0.Apply.Statements[0]).WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(`INSERT INTO migrations(version, date, name) values(0, NOW(), 'create_migrations_table')`). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectCommit() @@ -37,7 +37,7 @@ func TestMigrationUp(t *testing.T) { sqlmock.NewRows([]string{"date", "version", "name"}). AddRow(migration0.Version, time.Now(), migration0.Name), ) - mock.ExpectExec(migration1.Apply). + mock.ExpectExec(migration1.Apply.Statements[0]). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(`INSERT INTO migrations(version, date, name) values(1, NOW(), 'users_table')`). WillReturnResult(sqlmock.NewResult(0, 1)) @@ -49,7 +49,7 @@ func TestMigrationUp(t *testing.T) { sqlmock.NewRows([]string{"date", "version", "name"}). AddRow(migration1.Version, time.Now(), migration1.Name), ) - mock.ExpectExec(migration2.Apply). + mock.ExpectExec(migration2.Apply.Statements[0]). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(`INSERT INTO migrations(version, date, name) values(2, NOW(), 'users_email_index')`). WillReturnResult(sqlmock.NewResult(0, 1)) @@ -61,7 +61,7 @@ func TestMigrationUp(t *testing.T) { sqlmock.NewRows([]string{"date", "version", "name"}). AddRow(migration2.Version, time.Now(), migration2.Name), ) - mock.ExpectExec(migration3.Apply). + mock.ExpectExec(migration3.Apply.Statements[0]). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(`INSERT INTO migrations(version, date, name) values(3, NOW(), 'roles_table')`). WillReturnResult(sqlmock.NewResult(0, 1)) @@ -73,7 +73,7 @@ func TestMigrationUp(t *testing.T) { sqlmock.NewRows([]string{"date", "version", "name"}). AddRow(migration3.Version, time.Now(), migration3.Name), ) - mock.ExpectExec(migration4.Apply). + mock.ExpectExec(migration4.Apply.Statements[0]). WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec(`INSERT INTO migrations(version, date, name) values(4, NOW(), 'user_roles_fk')`). WillReturnResult(sqlmock.NewResult(0, 1)) @@ -99,25 +99,42 @@ var ( migration1 = &Migration{ Version: 1, Name: "users_table", - Apply: "CREATE TABLE IF NOT EXISTS users(id text, name text, email text, role text, PRIMARY KEY (id))", - Discard: "DROP TABLE IF EXISTS users CASCADE", + Apply: Statements{ + NoTx: true, + Statements: []string{"CREATE TABLE IF NOT EXISTS users(id text, name text, email text, role text, PRIMARY KEY (id))"}, + }, + Discard: Statements{ + Statements: []string{"DROP TABLE IF EXISTS users CASCADE"}, + }, } migration2 = &Migration{ Version: 2, Name: "users_email_index", - Apply: "CREATE INDEX IF NOT EXISTS ix_users_email ON users (email)", - Discard: "DROP INDEX IF EXISTS ix_users_email CASCADE", + Apply: Statements{ + Statements: []string{"CREATE INDEX IF NOT EXISTS ix_users_email ON users (email)"}, + }, + Discard: Statements{ + Statements: []string{"DROP INDEX IF EXISTS ix_users_email CASCADE"}, + }, } migration3 = &Migration{ Version: 3, Name: "roles_table", - Apply: "CREATE TABLE IF NOT EXISTS roles(id text, name text, properties jsonb NOT NULL DEFAULT '{}'::jsonb, PRIMARY KEY (id))", - Discard: "DROP TABLE IF EXISTS roles CASCADE", + Apply: Statements{ + Statements: []string{"CREATE TABLE IF NOT EXISTS roles(id text, name text, properties jsonb NOT NULL DEFAULT '{}'::jsonb, PRIMARY KEY (id))"}, + }, + Discard: Statements{ + Statements: []string{"DROP TABLE IF EXISTS roles CASCADE"}, + }, } migration4 = &Migration{ Version: 4, Name: "user_roles_fk", - Apply: "ALTER TABLE users ADD CONSTRAINT roles_fk FOREIGN KEY (role) REFERENCES roles (id)", - Discard: "ALTER TABLE users DROP CONSTRAINT roles_fk CASCADE", + Apply: Statements{ + Statements: []string{"ALTER TABLE users ADD CONSTRAINT roles_fk FOREIGN KEY (role) REFERENCES roles (id)"}, + }, + Discard: Statements{ + Statements: []string{"ALTER TABLE users DROP CONSTRAINT roles_fk CASCADE"}, + }, } ) diff --git a/migrate/parse.go b/migrate/parse.go new file mode 100644 index 0000000..9868a58 --- /dev/null +++ b/migrate/parse.go @@ -0,0 +1,45 @@ +package migrate + +import ( + "bufio" + "bytes" + "fmt" + "regexp" + "strings" +) + +var ( + noTXRegexp = regexp.MustCompile(`--\s+migrate:\s+NoTransaction`) +) + +func parseStatement(data []byte) (s Statements, err error) { + s = Statements{} + + var stmt string + scanner := bufio.NewScanner(bytes.NewReader(data)) + for scanner.Scan() { + line := scanner.Text() + + if strings.HasPrefix(line, "--") { + if noTXRegexp.MatchString(line) { + s.NoTx = true + } + continue + } + + if line[len(line)-1] == ';' { + if stmt != "" { + stmt += " " + } + stmt += line[:len(line)-1] + s.Statements = append(s.Statements, stmt) + stmt = "" + } + } + + if s.NoTx && len(s.Statements) > 1 { + return s, fmt.Errorf("migrate: migrations that disable transactions must have only one statement") + } + + return s, nil +} From ab162eaa1bac0ceffc3ae1bd3a8c60591066b975 Mon Sep 17 00:00:00 2001 From: Bruno Moura Date: Tue, 12 Oct 2021 15:25:18 +0100 Subject: [PATCH 2/5] migrate: log statements --- migrate/migrate.go | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/migrate/migrate.go b/migrate/migrate.go index 52e6cee..887289e 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -30,12 +30,7 @@ var ( Apply: Statements{ NoTx: false, Statements: []string{ - `CREATE TABLE IF NOT EXISTS migrations ( - date timestamp NOT NULL, - version bigint NOT NULL, - name varchar(512) NOT NULL, - PRIMARY KEY (date,version) - )`}, + `CREATE TABLE IF NOT EXISTS migrations (date timestamp NOT NULL, version bigint NOT NULL, name varchar(512) NOT NULL, PRIMARY KEY (date,version))`}, }, Discard: Statements{ NoTx: false, @@ -323,6 +318,8 @@ func (m *Migrate) apply(ctx context.Context, mig *Migration, discard bool) (err } for x := 0; x < len(statements.Statements); x++ { + m.logger("migrate: %s, discard: %t, transaction: %t, statement: %s", mig.Name, discard, !statements.NoTx, statements.Statements[x]) + switch statements.NoTx { case false: if _, err := tx.ExecContext(ctx, statements.Statements[x]); err != nil { From 9346f56d03b9af113bed5985f9b78fbd0d1c8f17 Mon Sep 17 00:00:00 2001 From: Bruno Moura Date: Tue, 12 Oct 2021 15:36:41 +0100 Subject: [PATCH 3/5] migrate: fix multiline statement parsing and add tests --- migrate/parse.go | 21 +++++++++++++++++---- migrate/parse_test.go | 44 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 4 deletions(-) create mode 100644 migrate/parse_test.go diff --git a/migrate/parse.go b/migrate/parse.go index 9868a58..10fad12 100644 --- a/migrate/parse.go +++ b/migrate/parse.go @@ -17,8 +17,13 @@ func parseStatement(data []byte) (s Statements, err error) { var stmt string scanner := bufio.NewScanner(bytes.NewReader(data)) + for scanner.Scan() { - line := scanner.Text() + line := strings.TrimSpace(scanner.Text()) + + if len(line) == 0 { + continue + } if strings.HasPrefix(line, "--") { if noTXRegexp.MatchString(line) { @@ -27,14 +32,22 @@ func parseStatement(data []byte) (s Statements, err error) { continue } + if stmt != "" { + stmt += " " + } + if line[len(line)-1] == ';' { - if stmt != "" { - stmt += " " - } stmt += line[:len(line)-1] s.Statements = append(s.Statements, stmt) stmt = "" + continue } + + stmt += line + } + + if stmt != "" { + s.Statements = append(s.Statements, stmt) } if s.NoTx && len(s.Statements) > 1 { diff --git a/migrate/parse_test.go b/migrate/parse_test.go new file mode 100644 index 0000000..211e359 --- /dev/null +++ b/migrate/parse_test.go @@ -0,0 +1,44 @@ +package migrate + +import ( + "reflect" + "testing" +) + +func TestParseSimple(t *testing.T) { + stmt, err := parseStatement(statement) + if err != nil { + t.Fatalf("failed to parse statement: %s", err) + } + + if !reflect.DeepEqual(expected, stmt) { + t.Fatalf("expected: %#v got: %#v", expected, stmt) + } +} + +var statement = []byte(` +CREATE TABLE IF NOT EXISTS users ( + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + id UUID, + name text NOT NULL, + email text NOT NULL, + PRIMARY KEY (id) +); +CREATE UNIQUE INDEX IF NOT EXISTS ix_unique_users_name ON users (name); +CREATE UNIQUE INDEX IF NOT EXISTS ix_unique_users_email ON users (email); +CREATE INDEX IF NOT EXISTS ix_users_created_at ON users (created_at); +CREATE INDEX IF NOT EXISTS ix_users_updated_at ON users (updated_at); + +`) + +var expected = Statements{ + NoTx: false, + Statements: []string{ + "CREATE TABLE IF NOT EXISTS users ( created_at timestamptz NOT NULL DEFAULT now(), updated_at timestamptz NOT NULL DEFAULT now(), id UUID, name text NOT NULL, email text NOT NULL, PRIMARY KEY (id) )", + "CREATE UNIQUE INDEX IF NOT EXISTS ix_unique_users_name ON users (name)", + "CREATE UNIQUE INDEX IF NOT EXISTS ix_unique_users_email ON users (email)", + "CREATE INDEX IF NOT EXISTS ix_users_created_at ON users (created_at)", + "CREATE INDEX IF NOT EXISTS ix_users_updated_at ON users (updated_at)", + }, +} From 341f9d2301e71245d79c2331baf2282fc553ae8f Mon Sep 17 00:00:00 2001 From: Bruno Moura Date: Sun, 17 Oct 2021 08:55:19 +0100 Subject: [PATCH 4/5] migrate: add test for invalid transaction --- migrate/parse.go | 5 +++-- migrate/parse_test.go | 9 +++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/migrate/parse.go b/migrate/parse.go index 10fad12..8ea78fd 100644 --- a/migrate/parse.go +++ b/migrate/parse.go @@ -9,7 +9,8 @@ import ( ) var ( - noTXRegexp = regexp.MustCompile(`--\s+migrate:\s+NoTransaction`) + ErrInvalidNoTx = fmt.Errorf("migrate: migrations that disable transactions must have only one statement") + noTXRegexp = regexp.MustCompile(`--\s+migrate:\s+NoTransaction`) ) func parseStatement(data []byte) (s Statements, err error) { @@ -51,7 +52,7 @@ func parseStatement(data []byte) (s Statements, err error) { } if s.NoTx && len(s.Statements) > 1 { - return s, fmt.Errorf("migrate: migrations that disable transactions must have only one statement") + return s, ErrInvalidNoTx } return s, nil diff --git a/migrate/parse_test.go b/migrate/parse_test.go index 211e359..7e63ab0 100644 --- a/migrate/parse_test.go +++ b/migrate/parse_test.go @@ -16,6 +16,15 @@ func TestParseSimple(t *testing.T) { } } +func TestParseMultiNoTx(t *testing.T) { + notx := append([]byte(`-- migrate: NoTransaction`), statement...) + _, err := parseStatement(notx) + + if err != ErrInvalidNoTx { + t.Fatalf("failed to parse statement: %s", err) + } +} + var statement = []byte(` CREATE TABLE IF NOT EXISTS users ( created_at timestamptz NOT NULL DEFAULT now(), From be13cf7d1e6a50989fedde4abb9bc6716d25ea74 Mon Sep 17 00:00:00 2001 From: Bruno Moura Date: Sun, 17 Oct 2021 08:58:41 +0100 Subject: [PATCH 5/5] migrate: fix string lenght check --- migrate/parse.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/migrate/parse.go b/migrate/parse.go index 8ea78fd..39f235a 100644 --- a/migrate/parse.go +++ b/migrate/parse.go @@ -22,7 +22,7 @@ func parseStatement(data []byte) (s Statements, err error) { for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) - if len(line) == 0 { + if line == "" { continue }