diff --git a/atlasexec/atlas.go b/atlasexec/atlas.go index ec27380..92ff5d9 100644 --- a/atlasexec/atlas.go +++ b/atlasexec/atlas.go @@ -189,6 +189,8 @@ func NewClient(workingDir, execPath string) (_ *Client, err error) { // return err // }) func (c *Client) WithWorkDir(dir string, fn func(*Client) error) error { + wd := c.workingDir + defer func() { c.workingDir = wd }() c.workingDir = dir return fn(c) } diff --git a/atlasexec/atlas_models.go b/atlasexec/atlas_models.go index c1d6ec8..bbaac14 100644 --- a/atlasexec/atlas_models.go +++ b/atlasexec/atlas_models.go @@ -1,6 +1,8 @@ package atlasexec import ( + "fmt" + "strings" "time" "ariga.io/atlas/sql/sqlcheck" @@ -19,8 +21,9 @@ type ( File Start time.Time End time.Time - Skipped int // Amount of skipped SQL statements in a partially applied file. - Applied []string // SQL statements applied with success + Skipped int // Amount of skipped SQL statements in a partially applied file. + Applied []string // SQL statements applied with success + Checks []*FileChecks // Assertion checks Error *struct { Stmt string // SQL statement that failed. Text string // Error returned by the database. @@ -120,6 +123,19 @@ type ( Error string `json:"Error,omitempty"` // File specific error. } + // FileChecks represents a set of checks to run before applying a file. + FileChecks struct { + Name string `json:"Name,omitempty"` // File/group name. + Stmts []*Check `json:"Stmts,omitempty"` // Checks statements executed. + Error *StmtError `json:"Error,omitempty"` // Assertion error. + Start time.Time `json:"Start,omitempty"` // Start assertion time. + End time.Time `json:"End,omitempty"` // End assertion time. + } + // Check represents an assertion and its status. + Check struct { + Stmt string `json:"Stmt,omitempty"` // Assertion statement. + Error *string `json:"Error,omitempty"` // Assertion error, if any. + } // StmtError groups a statement with its execution error. StmtError struct { Stmt string `json:"Stmt,omitempty"` // SQL statement that failed. @@ -179,6 +195,87 @@ type ( } ) +// Summary of the migration attempt. +func (a *MigrateApply) Summary(ident string) string { + var ( + passedC, failedC int + passedS, failedS int + passedF, failedF int + lines = make([]string, 0, 3) + ) + for _, f := range a.Applied { + // For each check file, count the + // number of failed assertions. + for _, cf := range f.Checks { + for _, s := range cf.Stmts { + if s.Error != nil { + failedC++ + } else { + passedC++ + } + } + } + passedS += len(f.Applied) + if f.Error != nil { + failedF++ + // Last statement failed (not an assertion). + if len(f.Checks) == 0 || f.Checks[len(f.Checks)-1].Error == nil { + passedS-- + failedS++ + } + } else { + passedF++ + } + } + // Execution time. + lines = append(lines, a.End.Sub(a.Start).String()) + // Executed files. + switch { + case passedF > 0 && failedF > 0: + lines = append(lines, fmt.Sprintf("%d migration%s ok, %d with errors", passedF, plural(passedF), failedF)) + case passedF > 0: + lines = append(lines, fmt.Sprintf("%d migration%s", passedF, plural(passedF))) + case failedF > 0: + lines = append(lines, fmt.Sprintf("%d migration%s with errors", failedF, plural(failedF))) + } + // Executed checks. + switch { + case passedC > 0 && failedC > 0: + lines = append(lines, fmt.Sprintf("%d check%s ok, %d failure%s", passedC, plural(passedC), failedC, plural(failedC))) + case passedC > 0: + lines = append(lines, fmt.Sprintf("%d check%s", passedC, plural(passedC))) + case failedC > 0: + lines = append(lines, fmt.Sprintf("%d check error%s", failedC, plural(failedC))) + } + // Executed statements. + switch { + case passedS > 0 && failedS > 0: + lines = append(lines, fmt.Sprintf("%d sql statement%s ok, %d with errors", passedS, plural(passedS), failedS)) + case passedS > 0: + lines = append(lines, fmt.Sprintf("%d sql statement%s", passedS, plural(passedS))) + case failedS > 0: + lines = append(lines, fmt.Sprintf("%d sql statement%s with errors", failedS, plural(failedS))) + } + var b strings.Builder + for i, l := range lines { + b.WriteString("-") + b.WriteByte(' ') + b.WriteString(fmt.Sprintf("**%s**", l)) + if i < len(lines)-1 { + b.WriteByte('\n') + b.WriteString(ident) + } + } + return b.String() +} + +func plural(n int) (s string) { + if n > 1 { + s += "s" + } + return +} + // Error implements the error interface. func (e *MigrateApplyError) Error() string { return last(e.Result).Error } diff --git a/atlasexec/atlas_test.go b/atlasexec/atlas_test.go index 192e358..0f3c508 100644 --- a/atlasexec/atlas_test.go +++ b/atlasexec/atlas_test.go @@ -86,7 +86,10 @@ func Test_MigrateApply(t *testing.T) { Env: "test", }) require.NoError(t, err) - require.EqualValues(t, "20230926085734", got.Target) + require.Equal(t, "sqlite3", got.Env.Driver) + require.Equal(t, "migrations", got.Env.Dir) + require.Equal(t, "sqlite://file?_fk=1&cache=shared&mode=memory", got.Env.URL.String()) + require.Equal(t, "20230926085734", got.Target) // Add dirty changes and try again os.Setenv("DB_URL", "sqlite://test.db?_fk=1&cache=shared&mode=memory") drv, err := sql.Open("sqlite3", "test.db") diff --git a/atlasexec/working_dir_test.go b/atlasexec/working_dir_test.go index 9dcfddf..6a79522 100644 --- a/atlasexec/working_dir_test.go +++ b/atlasexec/working_dir_test.go @@ -79,3 +79,15 @@ func TestContextExecer(t *testing.T) { require.Equal(t, "atlas.hcl\nmigrations\n", buf.String()) require.NoError(t, ce.Close()) } + +func TestMaintainOriginalWorkingDir(t *testing.T) { + dir := t.TempDir() + c, err := NewClient(dir, "atlas") + require.NoError(t, err) + require.Equal(t, dir, c.workingDir) + require.NoError(t, c.WithWorkDir("bar", func(c *Client) error { + require.Equal(t, "bar", c.workingDir) + return nil + })) + require.Equal(t, dir, c.workingDir, "The working directory should not be changed") +}