From bf2a584788b20fd5cc69f2143454cf45f3556402 Mon Sep 17 00:00:00 2001 From: Peter Bourgon Date: Wed, 20 Sep 2023 01:00:34 +0200 Subject: [PATCH] Signal errors, context handler --- actors.go | 74 ++++++++++++++++++++++++++++++++++++++++++++------ actors_test.go | 59 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+), 9 deletions(-) create mode 100644 actors_test.go diff --git a/actors.go b/actors.go index d1bc754..10d50e2 100644 --- a/actors.go +++ b/actors.go @@ -2,23 +2,41 @@ package run import ( "context" + "errors" "fmt" "os" "os/signal" ) +// ContextHandler returns an actor, i.e. an execute and interrupt func, that +// terminates when the provided context is canceled. +func ContextHandler(ctx context.Context) (execute func() error, interrupt func(error)) { + ctx, cancel := context.WithCancel(ctx) + return func() error { + <-ctx.Done() + return ctx.Err() + }, func(error) { + cancel() + } +} + // SignalHandler returns an actor, i.e. an execute and interrupt func, that -// terminates with SignalError when the process receives one of the provided -// signals, or the parent context is canceled. +// terminates with ErrSignal when the process receives one of the provided +// signals, or with ctx.Error() when the parent context is canceled. If no +// signals are provided, the actor will terminate on any signal, per +// [signal.Notify]. func SignalHandler(ctx context.Context, signals ...os.Signal) (execute func() error, interrupt func(error)) { ctx, cancel := context.WithCancel(ctx) return func() error { - c := make(chan os.Signal, 1) - signal.Notify(c, signals...) - defer signal.Stop(c) + testc := getTestSigChan(ctx) + sigc := make(chan os.Signal, 1) + signal.Notify(sigc, signals...) + defer signal.Stop(sigc) select { - case sig := <-c: - return SignalError{Signal: sig} + case sig := <-testc: + return &SignalError{Signal: sig} + case sig := <-sigc: + return &SignalError{Signal: sig} case <-ctx.Done(): return ctx.Err() } @@ -27,13 +45,51 @@ func SignalHandler(ctx context.Context, signals ...os.Signal) (execute func() er } } -// SignalError is returned by the signal handler's execute function -// when it terminates due to a received signal. +type testSigChanKey struct{} + +func getTestSigChan(ctx context.Context) <-chan os.Signal { + return ctx.Value(testSigChanKey{}).(<-chan os.Signal) // can be nil +} + +func putTestSigChan(ctx context.Context, c <-chan os.Signal) context.Context { + return context.WithValue(ctx, testSigChanKey{}, c) +} + +// SignalError is returned by the signal handler's execute function when it +// terminates due to a received signal. +// +// SignalError has a design error that impacts comparison with errors.As. +// Callers should prefer using errors.Is(err, ErrSignal) to check for signal +// errors, and should only use errors.As in the rare case that they need to +// program against the specific os.Signal value. type SignalError struct { Signal os.Signal } // Error implements the error interface. +// +// It was a design error to define this method on a value receiver rather than a +// pointer receiver. For compatibility reasons it won't be changed. func (e SignalError) Error() string { return fmt.Sprintf("received signal %s", e.Signal) } + +// Is addresses a design error in the SignalError type, so that errors.Is with +// ErrSignal will return true. +func (e SignalError) Is(err error) bool { + return errors.Is(err, ErrSignal) +} + +// As fixes a design error in the SignalError type, so that errors.As with the +// literal `&SignalError{}` will return true. +func (e SignalError) As(target interface{}) bool { + switch target.(type) { + case *SignalError, SignalError: + return true + default: + return false + } +} + +// ErrSignal is returned by SignalHandler when a signal triggers termination. +var ErrSignal = errors.New("signal error") diff --git a/actors_test.go b/actors_test.go new file mode 100644 index 0000000..bbb5efa --- /dev/null +++ b/actors_test.go @@ -0,0 +1,59 @@ +package run + +import ( + "context" + "errors" + "os" + "testing" + "time" +) + +func TestContextHandler(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + var rg Group + rg.Add(ContextHandler(ctx)) + errc := make(chan error, 1) + go func() { errc <- rg.Run() }() + cancel() + select { + case err := <-errc: + if want, have := context.Canceled, err; !errors.Is(have, want) { + t.Errorf("error: want %v, have %v", want, have) + } + case <-time.After(time.Second): + t.Errorf("timeout waiting for error after cancel") + } +} + +func TestSignalError(t *testing.T) { + testc := make(chan os.Signal, 1) + ctx := putTestSigChan(context.Background(), testc) + + var rg Group + rg.Add(SignalHandler(ctx, os.Interrupt)) + testc <- os.Interrupt + err := rg.Run() + + var sigerr *SignalError + if want, have := true, errors.As(err, &sigerr); want != have { + t.Errorf("errors.As(err, &sigerr): want %v, have %v", want, have) + } + + if sigerr != nil { + if want, have := os.Interrupt, sigerr.Signal; want != have { + t.Errorf("sigerr.Signal: want %v, have %v", want, have) + } + } + + if sigerr := &(SignalError{}); !errors.As(err, &sigerr) { + t.Errorf("errors.As(err, ): failed") + } + + if want, have := true, errors.As(err, &(SignalError{})); want != have { + t.Errorf("errors.As(err, &(SignalError{})): want %v, have %v", want, have) + } + + if want, have := true, errors.Is(err, ErrSignal); want != have { + t.Errorf("errors.Is(err, ErrSignal): want %v, have %v", want, have) + } +}