Skip to content

Commit

Permalink
Merge branch 'master' into dependabot/go_modules/github.com/stretchr/testify-1.9.0
Browse files Browse the repository at this point in the history
  • Loading branch information
muir authored Jul 19, 2024
2 parents bde1bbe + dfca6cb commit 0c544e5
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 8 deletions.
60 changes: 52 additions & 8 deletions capture.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@ import (
"log"
"runtime"
"strings"
"sync"
"time"

pkgerrors "github.com/pkg/errors"
)

// CaptureTimeout limits how long to wait for a capture ID to be returned from a capture handler.
var CaptureTimeout = 500 * time.Millisecond

type CaptureProvider string // i.e. "sentry"

type CaptureID string // may be a URL or any string that allows a captured error to be looked up
Expand Down Expand Up @@ -117,11 +121,12 @@ func alert(exception error) error {
// infinite recursion. Here, we try to prevent that. This is relatively expensive, but we're alerting, which
// shouldn't happen often.
pc := make([]uintptr, 42)
runtime.Callers(1, pc) // skip 1 (runtime.Callers)
runtime.Callers(1, pc) // skip 1 (the one skipped is runtime.Callers)
cf := runtime.CallersFrames(pc)
us, _ := cf.Next()
for them, ok := cf.Next(); ok; them, ok = cf.Next() {
if us.Func.Name() == them.Func.Name() {
// use HasPrefix here, not simple equality, because handlers are called from goroutine (below)
if strings.HasPrefix(them.Func.Name(), us.Func.Name()) {
log.Printf("cannot alert, recursion detected (%s): %+v", us.Func.Name(), exception)
return exception // don't recurse again
}
Expand Down Expand Up @@ -149,18 +154,57 @@ func alert(exception error) error {
return true
})

// Run handlers in goroutines, so that if one handler is deadlocked
// it does not prevent others from running, or us from returning.

timer := time.NewTimer(CaptureTimeout)
defer timer.Stop()

done := make(chan struct{})
finish := func() {close(done)}
var once sync.Once
var mu sync.Mutex

// start a goroutine for each handler
for provider, handler := range capture {
defer func() {
if r := recover(); r != nil {
log.Printf("failed to capture exception (%q): %+v", provider, r)
provider := provider
handler := handler
go func() {
defer func() {
if r := recover(); r != nil {
log.Printf("failed to capture exception (%q): %+v", provider, r)
}
}()

id := handler(exception, arg...)

mu.Lock()
defer mu.Unlock()
select {
case <-done:
// we are too late
default:
e.id[provider] = id
if len(e.id) == len(capture) {
once.Do(finish)
}
}
}()
}

id := handler(exception, arg...)
if id != "" {
e.id[provider] = id
// wait until done or timed out
waitLoop:
for {
select {
case <- timer.C:
mu.Lock()
once.Do(finish)
mu.Unlock()
case <- done:
break waitLoop
}
}

return e
}

Expand Down
48 changes: 48 additions & 0 deletions capture_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package errors_test
import (
"fmt"
"strings"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -102,3 +103,50 @@ func TestCaptureRecurse(t *testing.T) {
t.Errorf("alert did not capture")
}
}

// TestCaptureTimeout registers several deliberately slow capture handlers and
// verifies that Alertf returns within roughly CaptureTimeout: every handler is
// invoked, the fastest ones are waited for, and the slowest are abandoned.
func TestCaptureTimeout(t *testing.T) {
	const handlerCount = 5 // number of slow handlers to register

	var invoked atomic.Uint64   // total handler invocations so far
	var completed atomic.Uint64 // handlers that ran to completion

	// baseDelay is the sleep unit: invocation k sleeps (k+1)*baseDelay, so the
	// quickest handlers finish inside CaptureTimeout and the slowest do not.
	baseDelay := errors.CaptureTimeout / time.Duration(handlerCount)

	laggard := func(ex error, arg ...any) errors.CaptureID {
		seq := invoked.Add(1)
		defer completed.Add(1)

		// Each successive invocation sleeps longer than the previous one, so
		// with several handlers registered the capture is forced to time out.
		time.Sleep(time.Duration(seq+1) * baseDelay)
		return errors.CaptureID(fmt.Sprintf("slowHandler %d", seq))
	}

	for i := 1; i <= handlerCount; i++ {
		provider := errors.CaptureProvider(fmt.Sprintf("slowHandler %d", i))
		errors.RegisterCapture(provider, laggard)
		defer errors.UnregisterCapture(provider)
	}

	start := time.Now()
	err := errors.Alertf(t.Name())
	elapsed := time.Since(start)

	// Alertf should give up shortly after CaptureTimeout rather than wait for
	// every handler; allow a small scheduling margin.
	if elapsed > errors.CaptureTimeout+(10*time.Millisecond) {
		t.Errorf("alert to %d handlers took longer than timeout by %s", handlerCount, elapsed-errors.CaptureTimeout)
	}

	if int(invoked.Load()) != handlerCount {
		t.Errorf("expected to call %d handlers, called %d", handlerCount, invoked.Load())
	}

	// The slowest handlers must still be running when Alertf returns.
	if completed.Load() >= invoked.Load() {
		t.Error("alert waited for all slow handlers to return")
	}

	// But at least the fastest handler should have been waited for.
	if completed.Load() == 0 {
		t.Errorf("alert did not wait for any handlers")
	}

	t.Log(err) // should show capture IDs from the faster handlers only
}

0 comments on commit 0c544e5

Please sign in to comment.