Skip to content

Commit

Permalink
Add a CaptureOutput function to trap stdout and stderr
Browse files Browse the repository at this point in the history
  • Loading branch information
FollowTheProcess committed Jan 18, 2024
1 parent e2333f6 commit d2584d6
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 1 deletion.
1 change: 0 additions & 1 deletion .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ linters:
- nilerr
- nilnil
- nolintlint
- nonamedreturns
- predeclared
- reassign
- revive
Expand Down
78 changes: 78 additions & 0 deletions test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ package test
import (
"bytes"
"fmt"
"io"
"math"
"os"
"path/filepath"
"reflect"
"sync"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -213,3 +215,79 @@ func File(t testing.TB, file, want string) {

Diff(t, string(contents), want)
}

// CaptureOutput captures and returns data printed to stdout and stderr by the provided function fn.
func CaptureOutput(t testing.TB, fn func() error) (stdout, stderr string) {
t.Helper()

// Take copies of the original streams
oldStdout := os.Stdout
oldStderr := os.Stderr

defer func() {
// Restore everything back to normal
os.Stdout = oldStdout
os.Stderr = oldStderr
}()

stdoutReader, stdoutWriter, err := os.Pipe()
if err != nil {
t.Fatalf("CaptureOutput: could not construct an os.Pipe(): %v", err)
}

Check warning on line 236 in test.go

View check run for this annotation

Codecov / codecov/patch

test.go#L235-L236

Added lines #L235 - L236 were not covered by tests

stderrReader, stderrWriter, err := os.Pipe()
if err != nil {
t.Fatalf("CaptureOutput: could not construct an os.Pipe(): %v", err)
}

Check warning on line 241 in test.go

View check run for this annotation

Codecov / codecov/patch

test.go#L240-L241

Added lines #L240 - L241 were not covered by tests

// Set stdout and stderr streams to the pipe writers
os.Stdout = stdoutWriter
os.Stderr = stderrWriter

stdoutCapture := make(chan string)
stderrCapture := make(chan string)

var wg sync.WaitGroup
wg.Add(2) //nolint: gomnd

// Copy in goroutines to avoid blocking
go func(wg *sync.WaitGroup) {
defer func() {
close(stdoutCapture)
wg.Done()
}()
buf := &bytes.Buffer{}
if _, err := io.Copy(buf, stdoutReader); err != nil {
t.Fatalf("CaptureOutput: failed to copy from stdout reader: %v", err)
}

Check warning on line 262 in test.go

View check run for this annotation

Codecov / codecov/patch

test.go#L261-L262

Added lines #L261 - L262 were not covered by tests
stdoutCapture <- buf.String()
}(&wg)

go func(wg *sync.WaitGroup) {
defer func() {
close(stderrCapture)
wg.Done()
}()
buf := &bytes.Buffer{}
if _, err := io.Copy(buf, stderrReader); err != nil {
t.Fatalf("CaptureOutput: failed to copy from stderr reader: %v", err)
}

Check warning on line 274 in test.go

View check run for this annotation

Codecov / codecov/patch

test.go#L273-L274

Added lines #L273 - L274 were not covered by tests
stderrCapture <- buf.String()
}(&wg)

// Call the test function that produces the output
if err := fn(); err != nil {
t.Fatalf("CaptureOutput: user function returned an error: %v", err)
}

// Close the writers
stdoutWriter.Close()
stderrWriter.Close()

capturedStdout := <-stdoutCapture
capturedStderr := <-stderrCapture

wg.Wait()

return capturedStdout, capturedStderr
}
36 changes: 36 additions & 0 deletions test_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,42 @@ func TestData(t *testing.T) {
}
}

func TestCapture(t *testing.T) {
t.Run("happy", func(t *testing.T) {
// Some fake user function that writes to stdout and stderr
fn := func() error {
fmt.Fprintln(os.Stdout, "hello stdout")
fmt.Fprintln(os.Stderr, "hello stderr")

return nil
}

stdout, stderr := test.CaptureOutput(t, fn)

test.Equal(t, stdout, "hello stdout\n")
test.Equal(t, stderr, "hello stderr\n")
})

t.Run("sad", func(t *testing.T) {
// This time the user function returns an error
fn := func() error {
return errors.New("it broke")
}

buf := &bytes.Buffer{}
testTB := &TB{out: buf}

stdout, stderr := test.CaptureOutput(testTB, fn)

// Test should have failed
test.True(t, testTB.failed)

// stdout and stderr should be empty
test.Equal(t, stdout, "")
test.Equal(t, stderr, "")
})
}

// Always returns a nil error, needed because manually constructing
// nil means it's not an error type but here it is.
func nilErr() error {
Expand Down

0 comments on commit d2584d6

Please sign in to comment.