Skip to content

Commit

Permalink
Add an optional context for Ok and Err (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
FollowTheProcess authored Oct 28, 2023
1 parent d75b2bc commit 0064e31
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 6 deletions.
48 changes: 42 additions & 6 deletions test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@
package test

import (
"fmt"
"reflect"
"testing"

"github.com/google/go-cmp/cmp"
)

// Equal fails if got != want.
//
// test.Equal(t, "apples", "apples") // Passes
// test.Equal(t, "apples", "oranges") // Fails
func Equal[T comparable](t testing.TB, got, want T) {
t.Helper()
if got != want {
Expand All @@ -29,6 +33,9 @@ func EqualFunc[T any](t testing.TB, got, want T, equal func(a, b T) bool) {
}

// NotEqual fails if got == want.
//
// test.NotEqual(t, "apples", "oranges") // Passes
// test.NotEqual(t, "apples", "apples") // Fails
func NotEqual[T comparable](t testing.TB, got, want T) {
t.Helper()
if got == want {
Expand All @@ -47,27 +54,50 @@ func NotEqualFunc[T any](t testing.TB, got, want T, equal func(a, b T) bool) {
}
}

// Ok fails if err != nil.
func Ok(t testing.TB, err error) {
// Ok fails if err != nil, optionally adding context to the output.
//
// err := doSomething()
// test.Ok(t, err, "Doing something")
func Ok(t testing.TB, err error, context ...string) {
t.Helper()
var msg string
if len(context) == 0 {
msg = fmt.Sprintf("\nGot error:\t%v\nWanted:\tnil\n", err)
} else {
msg = fmt.Sprintf("\nGot error:\t%v\nWanted:\tnil\nContext:\t%s\n", err, context[0])
}
if err != nil {
t.Fatalf("\nGot error:\t%v\nWanted:\tnil\n", err)
t.Fatalf(msg, err)
}
}

// Err fails if err == nil.
func Err(t testing.TB, err error) {
//
// err := shouldReturnErr()
// test.Err(t, err, "shouldReturnErr")
func Err(t testing.TB, err error, context ...string) {
t.Helper()
var msg string
if len(context) == 0 {
msg = fmt.Sprintf("Error was not nil:\t%v\n", err)
} else {
msg = fmt.Sprintf("Error was not nil:\t%v\nContext:\t%s", err, context[0])
}
if err == nil {
t.Fatalf("Error was not nil:\t%v\n", err)
t.Fatalf(msg, err)
}
}

// ErrIsWanted fails if you got an error and didn't want it, or if you
// didn't get an error but wanted one.
//
// It simplifies checking for errors in table driven tests where on any
// iteration err could either be nil or not.
// iteration err may or may not be nil.
//
// test.ErrIsWanted(t, errors.New("uh oh"), true) // Passes, got error when we wanted one
// test.ErrIsWanted(t, errors.New("uh oh"), false) // Fails, got error but didn't want one
// test.ErrIsWanted(t, nil, true) // Fails, wanted an error but didn't get one
// test.ErrIsWanted(t, nil, false) // Passes, didn't want an error and didn't get one
func ErrIsWanted(t testing.TB, err error, want bool) {
t.Helper()
if (err != nil) != want {
Expand All @@ -76,6 +106,9 @@ func ErrIsWanted(t testing.TB, err error, want bool) {
}

// True fails if v is false.
//
// test.True(t, true) // Passes
// test.True(t, false) // Fails
func True(t testing.TB, v bool) {
t.Helper()
if !v {
Expand All @@ -84,6 +117,9 @@ func True(t testing.TB, v bool) {
}

// False fails if v is true.
//
// test.False(t, false) // Passes
// test.False(t, true) // Fails
func False(t testing.TB, v bool) {
t.Helper()
if v {
Expand Down
4 changes: 4 additions & 0 deletions test_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ func TestPass(t *testing.T) {
func(tb testing.TB) { test.NotEqual(tb, true, false) },
func(tb testing.TB) { test.NotEqual(tb, 3.14, 8.67) },
func(tb testing.TB) { test.Ok(tb, nil) },
func(tb testing.TB) { test.Ok(tb, nil, "Something") },
func(tb testing.TB) { test.Err(tb, errors.New("uh oh")) },
func(tb testing.TB) { test.Err(tb, errors.New("uh oh"), "Something") },
func(tb testing.TB) { test.True(tb, true) },
func(tb testing.TB) { test.False(tb, false) },
func(tb testing.TB) { test.Diff(tb, 42, 42) },
Expand Down Expand Up @@ -128,7 +130,9 @@ func TestFail(t *testing.T) {
func(tb testing.TB) { test.NotEqual(tb, true, true) },
func(tb testing.TB) { test.NotEqual(tb, 3.14, 3.14) },
func(tb testing.TB) { test.Ok(tb, errors.New("uh oh")) },
func(tb testing.TB) { test.Ok(tb, errors.New("uh oh"), "Something") },
func(tb testing.TB) { test.Err(tb, nilErr()) },
func(tb testing.TB) { test.Err(tb, nilErr(), "Something") },
func(tb testing.TB) { test.True(tb, false) },
func(tb testing.TB) { test.False(tb, true) },
func(tb testing.TB) { test.Diff(tb, "hello", "there") },
Expand Down

0 comments on commit 0064e31

Please sign in to comment.