diff --git a/donegroup.go b/donegroup.go index bf252d2..c4e6dbc 100644 --- a/donegroup.go +++ b/donegroup.go @@ -12,9 +12,7 @@ var ErrNotContainDoneGroup = errors.New("donegroup: context does not contain a d // doneGroup is cleanup function groups per Context. type doneGroup struct { - cancel context.CancelCauseFunc - // ctxw is the context used to call the cleanup functions - ctxw context.Context + cancel context.CancelCauseFunc cleanupGroups []*sync.WaitGroup errors error mu sync.Mutex @@ -178,17 +176,16 @@ func WaitWithContextAndKey(ctx, ctxw context.Context, key any) error { if !ok { return ErrNotContainDoneGroup } - dg.mu.Lock() - dg.ctxw = ctxw - dg.mu.Unlock() <-ctx.Done() wg := &sync.WaitGroup{} for _, g := range dg.cleanupGroups { wg.Add(1) + dg.mu.Lock() go func() { g.Wait() wg.Done() }() + dg.mu.Unlock() } ch := make(chan struct{}) go func() { diff --git a/donegroup_test.go b/donegroup_test.go index 4f632ed..b93ddd7 100644 --- a/donegroup_test.go +++ b/donegroup_test.go @@ -200,6 +200,7 @@ func TestNestedWithCancel(t *testing.T) { thirdCancel() <-thirdCtx.Done() + mu.Lock() if firstCleanup != 0 { t.Error("cleanup function for first called") } @@ -209,6 +210,7 @@ func TestNestedWithCancel(t *testing.T) { if thirdCleanup != 0 { t.Error("cleanup function for third called") } + mu.Unlock() secondCancel() <-secondCtx.Done() @@ -217,6 +219,7 @@ func TestNestedWithCancel(t *testing.T) { t.Error(err) } + mu.Lock() if thirdCleanup != 3 { t.Error("cleanup function for third not called") } @@ -226,6 +229,7 @@ func TestNestedWithCancel(t *testing.T) { if firstCleanup != 0 { t.Error("cleanup function for first called") } + mu.Unlock() firstCancel() <-firstCtx.Done() @@ -234,6 +238,7 @@ func TestNestedWithCancel(t *testing.T) { t.Error(err) } + mu.Lock() if thirdCleanup != 3 { t.Error("cleanup function for third not called") } @@ -243,6 +248,7 @@ func TestNestedWithCancel(t *testing.T) { if firstCleanup != 10 { t.Error("cleanup function for first not called") } + mu.Unlock() }() }