diff --git a/http/server.go b/http/server.go index 80d46c6679..f975e200ad 100644 --- a/http/server.go +++ b/http/server.go @@ -110,8 +110,10 @@ func WithTLSKeyPath(path string) ServerOpt { // Server struct holds the Handler for the HTTP API. type Server struct { - options *ServerOptions - server *http.Server + options *ServerOptions + server *http.Server + listener net.Listener + isTLS bool } // NewServer instantiates a new server with the given http.Handler. @@ -148,25 +150,34 @@ func (s *Server) Shutdown(ctx context.Context) error { return s.server.Shutdown(ctx) } -// ListenAndServe listens for and serves incoming connections. -func (s *Server) ListenAndServe() error { +// SetListener sets a new listener on the Server. +func (s *Server) SetListener() (err error) { + s.listener, err = net.Listen("tcp", s.options.Address) + return err +} + +// Serve serves incoming connections. +func (s *Server) Serve() error { if s.options.TLSCertPath == "" && s.options.TLSKeyPath == "" { - return s.listenAndServe() + return s.serve() } - return s.listenAndServeTLS() + s.isTLS = true + return s.serveTLS() } -// listenAndServe listens for and serves http connections. -func (s *Server) listenAndServe() error { - listener, err := net.Listen("tcp", s.options.Address) - if err != nil { - return err +// serve serves http connections. +func (s *Server) serve() error { + if s.listener == nil { + return ErrNoListener } - return s.server.Serve(listener) + return s.server.Serve(s.listener) } -// listenAndServeTLS listens for and serves https connections. -func (s *Server) listenAndServeTLS() error { +// serveTLS serves https connections. +func (s *Server) serveTLS() error { + if s.listener == nil { + return ErrNoListener + } cert, err := tls.LoadX509KeyPair(s.options.TLSCertPath, s.options.TLSKeyPath) if err != nil { return err @@ -177,9 +188,12 @@ func (s *Server) listenAndServeTLS() error { CipherSuites: tlsCipherSuites, Certificates: []tls.Certificate{cert}, } - listener, err := net.Listen("tcp", s.options.Address) - if err != nil { - return err + return s.server.Serve(tls.NewListener(s.listener, config)) +} + +func (s *Server) Address() string { + if s.isTLS { + return "https://" + s.listener.Addr().String() } - return s.server.Serve(tls.NewListener(listener, config)) + return "http://" + s.listener.Addr().String() } diff --git a/http/server_test.go b/http/server_test.go index 4065267c26..ec9ab8ab75 100644 --- a/http/server_test.go +++ b/http/server_test.go @@ -65,11 +65,28 @@ var testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) w.WriteHeader(http.StatusOK) }) +func TestServerServeWithNoListener(t *testing.T) { + srv, err := NewServer(testHandler) + require.NoError(t, err) + + err = srv.Serve() + require.ErrorIs(t, err, ErrNoListener) +} + +func TestServerServeWithTLSAndNoListener(t *testing.T) { + certPath, keyPath := writeTestCerts(t) + srv, err := NewServer(testHandler, WithTLSCertPath(certPath), WithTLSKeyPath(keyPath)) + require.NoError(t, err) + + err = srv.Serve() + require.ErrorIs(t, err, ErrNoListener) +} + func TestServerListenAndServeWithInvalidAddress(t *testing.T) { srv, err := NewServer(testHandler, WithAddress("invalid")) require.NoError(t, err) - err = srv.ListenAndServe() + err = srv.SetListener() require.ErrorContains(t, err, "address invalid") } @@ -78,16 +95,26 @@ func TestServerListenAndServeWithTLSAndInvalidAddress(t *testing.T) { srv, err := NewServer(testHandler, WithAddress("invalid"), WithTLSCertPath(certPath), WithTLSKeyPath(keyPath)) require.NoError(t, err) - err = srv.ListenAndServe() + err = srv.SetListener() require.ErrorContains(t, err, "address invalid") } func TestServerListenAndServeWithTLSAndInvalidCerts(t *testing.T) { - srv, err := NewServer(testHandler, WithAddress("invalid"), WithTLSCertPath("invalid.crt"), WithTLSKeyPath("invalid.key")) + srv, err := NewServer( + testHandler, + WithAddress("invalid"), + WithTLSCertPath("invalid.crt"), + WithTLSKeyPath("invalid.key"), + WithAddress("127.0.0.1:30001"), + ) require.NoError(t, err) - err = srv.ListenAndServe() + err = srv.SetListener() + require.NoError(t, err) + err = srv.Serve() require.ErrorContains(t, err, "no such file or directory") + err = srv.listener.Close() + require.NoError(t, err) } func TestServerListenAndServeWithAddress(t *testing.T) { @@ -95,7 +122,9 @@ func TestServerListenAndServeWithAddress(t *testing.T) { require.NoError(t, err) go func() { - err := srv.ListenAndServe() + err := srv.SetListener() + require.NoError(t, err) + err = srv.Serve() require.ErrorIs(t, http.ErrServerClosed, err) }() @@ -118,7 +147,9 @@ func TestServerListenAndServeWithTLS(t *testing.T) { require.NoError(t, err) go func() { - err := srv.ListenAndServe() + err := srv.SetListener() + require.NoError(t, err) + err = srv.Serve() require.ErrorIs(t, http.ErrServerClosed, err) }() @@ -140,7 +171,9 @@ func TestServerListenAndServeWithAllowedOrigins(t *testing.T) { require.NoError(t, err) go func() { - err := srv.ListenAndServe() + err := srv.SetListener() + require.NoError(t, err) + err = srv.Serve() require.ErrorIs(t, http.ErrServerClosed, err) }() diff --git a/net/peer.go b/net/peer.go index acdba2e9c8..0c456d5b18 100644 --- a/net/peer.go +++ b/net/peer.go @@ -177,6 +177,7 @@ func (p *Peer) Start() error { go p.handleBroadcastLoop() } + log.FeedbackInfo(p.ctx, "Starting P2P node", logging.NewKV("P2P addresses", p.host.Addrs())) // register the P2P gRPC server go func() { pb.RegisterServiceServer(p.p2pRPC, p.server) diff --git a/node/node.go b/node/node.go index f13fd0b1c2..89bedd56ff 100644 --- a/node/node.go +++ b/node/node.go @@ -12,6 +12,9 @@ package node import ( "context" + "errors" + "fmt" + gohttp "net/http" "github.com/libp2p/go-libp2p/core/peer" @@ -159,8 +162,13 @@ func (n *Node) Start(ctx context.Context) error { } } if n.Server != nil { + err := n.Server.SetListener() + if err != nil { + return err + } + log.FeedbackInfo(ctx, fmt.Sprintf("Providing HTTP API at %s.", n.Server.Address())) go func() { - if err := n.Server.ListenAndServe(); err != nil { + if err := n.Server.Serve(); err != nil && !errors.Is(err, gohttp.ErrServerClosed) { log.FeedbackErrorE(ctx, "HTTP server stopped", err) } }()