diff --git a/atreugo.go b/atreugo.go index 771a900..fdbd0d1 100644 --- a/atreugo.go +++ b/atreugo.go @@ -1,6 +1,7 @@ package atreugo import ( + "context" "log" "net" "os" @@ -272,3 +273,38 @@ func (s *Atreugo) NewVirtualHost(hostnames ...string) *Router { return vHost } + +// Shutdown gracefully shuts down the server without interrupting any active connections. +// Shutdown works by first closing all open listeners and then waiting indefinitely for +// all connections to return to idle and then shut down. +// +// When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS immediately return +// nil. Make sure the program doesn't exit and waits instead for Shutdown to return. +// +// Shutdown does not close keepalive connections so it's recommended to set ReadTimeout +// and IdleTimeout to something else than 0. +func (s *Atreugo) Shutdown() (err error) { + if s.engine != nil { + err = s.engine.ShutdownWithContext(context.Background()) + } + + return +} + +// ShutdownWithContext gracefully shuts down the server without interrupting any active +// connections. ShutdownWithContext works by first closing all open listeners and then +// waiting for all connections to return to idle or context timeout and then shut down. +// +// When ShutdownWithContext is called, Serve, ListenAndServe, and ListenAndServeTLS +// immediately return nil. Make sure the program doesn't exit and waits instead for +// Shutdown to return. +// +// ShutdownWithContext does not close keepalive connections so it's recommended to set +// ReadTimeout and IdleTimeout to something else than 0. +func (s *Atreugo) ShutdownWithContext(ctx context.Context) (err error) { + if s.engine != nil { + err = s.engine.ShutdownWithContext(ctx) + } + + return +} diff --git a/atreugo_test.go b/atreugo_test.go index 0bea691..bb5869f 100644 --- a/atreugo_test.go +++ b/atreugo_test.go @@ -1,6 +1,7 @@ package atreugo import ( + "context" "crypto/tls" "errors" "fmt" @@ -591,6 +592,77 @@ func TestAtreugo_NewVirtualHost(t *testing.T) { //nolint:funlen } } +func TestAtreugo_Shutdown(t *testing.T) { + s := New(testConfig) + + ln := fasthttputil.NewInmemoryListener() + errCh := make(chan error, 1) + + go func() { + errCh <- s.Serve(ln) + }() + + time.Sleep(500 * time.Millisecond) + + if err := s.Shutdown(); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if err := <-errCh; err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if lnAddr := ln.Addr().String(); s.cfg.Addr != lnAddr { + t.Errorf("Atreugo.Config.Addr = %s, want %s", s.cfg.Addr, lnAddr) + } + + lnNetwork := ln.Addr().Network() + if s.cfg.Network != lnNetwork { + t.Errorf("Atreugo.Config.Network = %s, want %s", s.cfg.Network, lnNetwork) + } + + if s.engine.Handler == nil { + t.Error("Atreugo.engine.Handler is nil") + } +} + +func TestAtreugo_ShutdownWithContext(t *testing.T) { + s := New(testConfig) + + ln := fasthttputil.NewInmemoryListener() + errCh := make(chan error, 1) + + go func() { + errCh <- s.Serve(ln) + }() + + time.Sleep(500 * time.Millisecond) + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + if err := s.ShutdownWithContext(ctx); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if err := <-errCh; err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if lnAddr := ln.Addr().String(); s.cfg.Addr != lnAddr { + t.Errorf("Atreugo.Config.Addr = %s, want %s", s.cfg.Addr, lnAddr) + } + + lnNetwork := ln.Addr().Network() + if s.cfg.Network != lnNetwork { + t.Errorf("Atreugo.Config.Network = %s, want %s", s.cfg.Network, lnNetwork) + } + + if s.engine.Handler == nil { + t.Error("Atreugo.engine.Handler is nil") + } +} + // Benchmarks. func Benchmark_Handler(b *testing.B) { s := New(testConfig)