diff --git a/.golangci.yml b/.golangci.yml index f45fdd7..6365f55 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -23,7 +23,6 @@ linters: - nilerr - nilnil - nolintlint - - nonamedreturns - predeclared - reassign - revive diff --git a/test.go b/test.go index 157b9d3..b4f46c9 100644 --- a/test.go +++ b/test.go @@ -5,10 +5,12 @@ package test import ( "bytes" "fmt" + "io" "math" "os" "path/filepath" "reflect" + "sync" "testing" "github.com/google/go-cmp/cmp" @@ -213,3 +215,79 @@ func File(t testing.TB, file, want string) { Diff(t, string(contents), want) } + +// CaptureOutput captures and returns data printed to stdout and stderr by the provided function fn. +func CaptureOutput(t testing.TB, fn func() error) (stdout, stderr string) { + t.Helper() + + // Take copies of the original streams + oldStdout := os.Stdout + oldStderr := os.Stderr + + defer func() { + // Restore everything back to normal + os.Stdout = oldStdout + os.Stderr = oldStderr + }() + + stdoutReader, stdoutWriter, err := os.Pipe() + if err != nil { + t.Fatalf("CaptureOutput: could not construct an os.Pipe(): %v", err) + } + + stderrReader, stderrWriter, err := os.Pipe() + if err != nil { + t.Fatalf("CaptureOutput: could not construct an os.Pipe(): %v", err) + } + + // Set stdout and stderr streams to the pipe writers + os.Stdout = stdoutWriter + os.Stderr = stderrWriter + + stdoutCapture := make(chan string) + stderrCapture := make(chan string) + + var wg sync.WaitGroup + wg.Add(2) //nolint: gomnd + + // Copy in goroutines to avoid blocking + go func(wg *sync.WaitGroup) { + defer func() { + close(stdoutCapture) + wg.Done() + }() + buf := &bytes.Buffer{} + if _, err := io.Copy(buf, stdoutReader); err != nil { + t.Fatalf("CaptureOutput: failed to copy from stdout reader: %v", err) + } + stdoutCapture <- buf.String() + }(&wg) + + go func(wg *sync.WaitGroup) { + defer func() { + close(stderrCapture) + wg.Done() + }() + buf := &bytes.Buffer{} + if _, err := io.Copy(buf, stderrReader); err != nil { + t.Fatalf("CaptureOutput: failed to copy from stderr reader: %v", err) + } + stderrCapture <- buf.String() + }(&wg) + + // Call the test function that produces the output + if err := fn(); err != nil { + t.Fatalf("CaptureOutput: user function returned an error: %v", err) + } + + // Close the writers + stdoutWriter.Close() + stderrWriter.Close() + + capturedStdout := <-stdoutCapture + capturedStderr := <-stderrCapture + + wg.Wait() + + return capturedStdout, capturedStderr +} diff --git a/test_test.go b/test_test.go index 7ab3549..63ca6cb 100644 --- a/test_test.go +++ b/test_test.go @@ -191,6 +191,42 @@ func TestData(t *testing.T) { } } +func TestCapture(t *testing.T) { + t.Run("happy", func(t *testing.T) { + // Some fake user function that writes to stdout and stderr + fn := func() error { + fmt.Fprintln(os.Stdout, "hello stdout") + fmt.Fprintln(os.Stderr, "hello stderr") + + return nil + } + + stdout, stderr := test.CaptureOutput(t, fn) + + test.Equal(t, stdout, "hello stdout\n") + test.Equal(t, stderr, "hello stderr\n") + }) + + t.Run("sad", func(t *testing.T) { + // This time the user function returns an error + fn := func() error { + return errors.New("it broke") + } + + buf := &bytes.Buffer{} + testTB := &TB{out: buf} + + stdout, stderr := test.CaptureOutput(testTB, fn) + + // Test should have failed + test.True(t, testTB.failed) + + // stdout and stderr should be empty + test.Equal(t, stdout, "") + test.Equal(t, stderr, "") + }) +} + // Always returns a nil error, needed because manually constructing // nil means it's not an error type but here it is. func nilErr() error {