diff --git a/benchmark/bench_test.go b/benchmark/bench_test.go index c733688..5192cfa 100644 --- a/benchmark/bench_test.go +++ b/benchmark/bench_test.go @@ -1,8 +1,9 @@ package main import ( - "github.com/garlicnation/promises" "testing" + + promise "github.com/garlicnation/promises/v2" ) var values []int diff --git a/blog_example/promises_checksum/promises_checksum.go b/blog_example/promises_checksum/promises_checksum.go index edd1853..37ae4d2 100644 --- a/blog_example/promises_checksum/promises_checksum.go +++ b/blog_example/promises_checksum/promises_checksum.go @@ -4,11 +4,12 @@ import ( "crypto/sha512" "encoding/hex" "fmt" - "github.com/garlicnation/promises" "io" "io/ioutil" "net/http" "time" + + promise "github.com/garlicnation/promises/v2" ) var listOfWebsites = []string{ diff --git a/docs.go b/docs.go index d554f79..89a14c0 100644 --- a/docs.go +++ b/docs.go @@ -1,7 +1,7 @@ /* -Promises is a library that builds something similar to JS style promises, or Futures(as seen in Java and other languages) for golang. +Package promise builds something similar to JS style promises, or Futures(as seen in Java and other languages) for golang. -Promises is type-safe at runtime, and within an order of magnitude of performance of a solution built with pure channels and goroutines. +promise is type-safe at runtime, and within an order of magnitude of performance of a solution built with pure channels and goroutines. For a more thorough introduction to the library, please check out https://github.com/garlicnation/promises/blob/master/blog_example/WHY.md diff --git a/go.mod b/go.mod index 664080b..da1c1df 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,8 @@ -module github.com/garlicnation/promises +module github.com/garlicnation/promises/v2 go 1.12 require ( - github.com/campoy/embedmd v1.0.0 // indirect github.com/pkg/errors v0.8.1 github.com/stretchr/testify v1.4.0 ) diff --git a/go.sum b/go.sum index dce0524..dee7d98 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -github.com/campoy/embedmd v1.0.0 h1:V4kI2qTJJLf4J29RzI/MAt2c3Bl4dQSYPuflzwFH2hY= -github.com/campoy/embedmd v1.0.0/go.mod h1:oxyr9RCiSXg0M3VJ3ks0UGfp98BpSSGr0kpiX3MzVl8= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= diff --git a/promises.go b/promises.go index 1e09601..5f1dd46 100644 --- a/promises.go +++ b/promises.go @@ -1,17 +1,21 @@ package promise -import "reflect" -import "sync" -import "sync/atomic" -import "github.com/pkg/errors" +import ( + "fmt" + "reflect" + "sync" + "sync/atomic" + + "github.com/pkg/errors" +) type promiseType int const ( - legacyCall promiseType = iota - simpleCall + simpleCall promiseType = iota thenCall allCall + raceCall anyCall ) @@ -23,10 +27,12 @@ type Promise struct { functionRv reflect.Value results []reflect.Value resultType []reflect.Type + anyErrs []error // returnsError is true if the last value returns an error returnsError bool cond sync.Cond counter int64 + errCounter int64 noCopy } @@ -36,7 +42,7 @@ type noCopy struct{} func (*noCopy) Lock() {} func (*noCopy) Unlock() {} -func (p *Promise) anyCall(priors []*Promise, index int) (results []reflect.Value) { +func (p *Promise) raceCall(priors []*Promise, index int) (results []reflect.Value) { prior := priors[index] prior.cond.L.Lock() for !prior.complete { @@ -78,6 +84,40 @@ func (p *Promise) allCall(priors []*Promise, index int) (results []reflect.Value return nil } +// AnyErr returns when all promises passed to Any fail +type AnyErr struct { + // Errs contains the error of all passed promises + Errs []error + // LastErr contains the error of the last promise to fail. + LastErr error +} + +func (err *AnyErr) Error() string { + return fmt.Sprintf("all %d promises failed. last err=%v", len(err.Errs), err.LastErr) +} + +func (p *Promise) anyCall(priors []*Promise, index int) (results []reflect.Value) { + prior := priors[index] + prior.cond.L.Lock() + for !prior.complete { + prior.cond.Wait() + } + prior.cond.L.Unlock() + if prior.err != nil { + remaining := atomic.AddInt64(&p.errCounter, -1) + p.anyErrs[index] = prior.err + if remaining != 0 { + return nil + } + panic(AnyErr{Errs: p.anyErrs[:], LastErr: prior.err}) + } + remaining := atomic.AddInt64(&p.counter, -1) + if remaining == 0 { + return prior.results[:] + } + return nil +} + func empty() {} // All returns a promise that resolves if all of the passed promises @@ -107,10 +147,10 @@ func All(promises ...*Promise) *Promise { const anyErrorFormat = "promise %d has an unexpected return type, expected all promises passed to Any to return the same type" -// Any returns a promise that resolves if any of the passed promises +// Race returns a promise that resolves if any of the passed promises // succeed or fails if any of the passed promises panics. // All of the supplied promises must be of the same type. -func Any(promises ...*Promise) *Promise { +func Race(promises ...*Promise) *Promise { if len(promises) == 0 { return New(empty) } @@ -135,16 +175,57 @@ func Any(promises ...*Promise) *Promise { p := &Promise{ cond: sync.Cond{L: &sync.Mutex{}}, - t: anyCall, + t: raceCall, } // Extract the type - p.resultType = []reflect.Type{} - for _, prior := range promises { - p.resultType = append(p.resultType, prior.resultType...) + p.resultType = firstResultType[:] + + p.counter = int64(1) + + for i := range promises { + go p.run(reflect.Value{}, nil, promises, i, nil) + } + return p +} + +// Any returns a promise that resolves if any of the passed promises +// succeed or fails if all of the passed promises panics. +// All of the supplied promises must be of the same type. +func Any(promises ...*Promise) *Promise { + if len(promises) == 0 { + return New(empty) + } + + if len(promises) == 1 { + return promises[0] + } + + // Check that all the promises have the same return type + firstResultType := promises[0].resultType + for promiseIdx, promise := range promises[1:] { + newResultType := promise.resultType + if len(firstResultType) != len(newResultType) { + panic(errors.Errorf(anyErrorFormat, promiseIdx)) + } + for index := range firstResultType { + if firstResultType[index] != newResultType[index] { + panic(errors.Errorf(anyErrorFormat, promiseIdx)) + } + } + } + + p := &Promise{ + cond: sync.Cond{L: &sync.Mutex{}}, + t: anyCall, + anyErrs: make([]error, len(promises)), } + // Extract the type + p.resultType = firstResultType[:] + p.counter = int64(1) + p.errCounter = int64(len(promises)) for i := range promises { go p.run(reflect.Value{}, nil, promises, i, nil) @@ -225,10 +306,10 @@ func (p *Promise) thenCall(prior *Promise, functionRv reflect.Value) []reflect.V if p.err != nil { panic(errors.Wrap(p.err, "error in previous promise")) } - results := functionRv.Call(prior.results) - if prior.returnsError && prior.err != nil { + if prior.err != nil { panic(prior.err) } + results := functionRv.Call(prior.results) return results } @@ -322,6 +403,11 @@ func (p *Promise) run(functionRv reflect.Value, prior *Promise, priors []*Promis } case anyCall: results = p.anyCall(priors, index) + if results == nil { + return + } + case raceCall: + results = p.raceCall(priors, index) default: panic("unexpected call type") } @@ -417,7 +503,7 @@ func (p *Promise) Wait(out ...interface{}) error { p.cond.L.Unlock() if p.err != nil { - return errors.Wrap(p.err, "panic() during promise execution") + return errors.Wrap(p.err, "error during promise execution") } var outRvs []reflect.Value diff --git a/promises_test.go b/promises_test.go index c5bef86..d84fc69 100644 --- a/promises_test.go +++ b/promises_test.go @@ -2,9 +2,11 @@ package promise import ( "errors" - "github.com/stretchr/testify/require" + "fmt" "testing" "time" + + "github.com/stretchr/testify/require" ) func TestPromiseResolution(t *testing.T) { @@ -277,3 +279,121 @@ func TestErrorReturnExitsEarly(t *testing.T) { close(blocker) require.Error(t, err) } + +func TestPromiseRaceSucceedsIfOneSucceeds(t *testing.T) { + sleepThenPanic := func() string { + time.Sleep(100 * time.Millisecond) + panic("failed") + return "" + } + + sleepThenErr := func() (string, error) { + time.Sleep(100 * time.Millisecond) + return "", fmt.Errorf("err") + } + + success := func() string { + return "success" + } + + result := Race(New(sleepThenErr), New(sleepThenPanic), New(success)) + var retval string + err := result.Wait(&retval) + require.NoError(t, err) + require.Equal(t, "success", retval) +} + +func TestPromiseRaceFailsIfOneErrors(t *testing.T) { + sleepThenPanic := func() string { + time.Sleep(100 * time.Millisecond) + panic("failed") + return "" + } + + returnError := func() (string, error) { + return "", fmt.Errorf("err") + } + + sleepThenSuccess := func() string { + time.Sleep(100 * time.Millisecond) + return "success" + } + + result := Race(New(returnError), New(sleepThenPanic), New(sleepThenSuccess)) + var retval string + err := result.Wait(&retval) + require.Error(t, err) + require.Contains(t, err.Error(), "err") + require.Equal(t, "", retval) +} + +func TestPromiseRaceFailsIfOnePanics(t *testing.T) { + justPanic := func() string { + panic("failed") + return "" + } + + sleepThenError := func() (string, error) { + time.Sleep(100 * time.Millisecond) + return "", fmt.Errorf("err") + } + + sleepThenSuccess := func() string { + time.Sleep(100 * time.Millisecond) + return "success" + } + + result := Race(New(sleepThenError), New(justPanic), New(sleepThenSuccess)) + var retval string + err := result.Wait(&retval) + require.Error(t, err) + require.Contains(t, err.Error(), "failed") + require.Equal(t, "", retval) +} + +func TestPromiseAnySucceedsIfOneSucceeds(t *testing.T) { + sleepThenPanic := func() string { + time.Sleep(100 * time.Millisecond) + panic("failed") + return "" + } + + sleepThenErr := func() (string, error) { + time.Sleep(100 * time.Millisecond) + return "", fmt.Errorf("err") + } + + success := func() string { + return "success" + } + + result := Race(New(sleepThenErr), New(sleepThenPanic), New(success)) + var retval string + err := result.Wait(&retval) + require.NoError(t, err) + require.Equal(t, "success", retval) +} + +func TestPromiseAnyFailsIfAllFail(t *testing.T) { + sleepThenPanic := func() string { + time.Sleep(100 * time.Millisecond) + panic("failed") + return "" + } + + returnError := func() (string, error) { + return "", fmt.Errorf("err") + } + + sleepThenSuccess := func() string { + time.Sleep(100 * time.Millisecond) + return "success" + } + + result := Race(New(returnError), New(sleepThenPanic), New(sleepThenSuccess)) + var retval string + err := result.Wait(&retval) + require.Error(t, err) + require.Contains(t, err.Error(), "err") + require.Equal(t, "", retval) +} diff --git a/vendor/modules.txt b/vendor/modules.txt index a18e2cf..00713f8 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -5,7 +5,7 @@ github.com/pkg/errors # github.com/pmezard/go-difflib v1.0.0 github.com/pmezard/go-difflib/difflib # github.com/stretchr/testify v1.4.0 -github.com/stretchr/testify/require github.com/stretchr/testify/assert +github.com/stretchr/testify/require # gopkg.in/yaml.v2 v2.2.2 gopkg.in/yaml.v2